# DroneDetector/NN_server/server.py
# Flask-based neural-network inference server for the drone detector.
from flask import Flask, request, jsonify
from dotenv import dotenv_values
from common.runtime import load_root_env, validate_env, as_int, as_str
import os
import sys
import matplotlib.pyplot as plt
from Model import Model
import numpy as np
import matplotlib
import importlib
import threading
import requests
import asyncio
import shutil
import json
import gc
import logging
import time
import re
TORCHSIG_PATH = "/app/torchsig"
if TORCHSIG_PATH not in sys.path:
    # Make `import torchsig` resolve to the vendored /app/torchsig checkout.
    sys.path.insert(0, TORCHSIG_PATH)

logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# NOTE(review): the asyncio loop/queue/semaphore belong to a disabled legacy
# implementation kept further down this file; the active code path only uses
# the lock and the module-level lists.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)
receive_data_lock = threading.Lock()
prediction_list = []
result_msg = {}
results = []

# Headless matplotlib: figures are rendered to files, never to a window.
matplotlib.use('Agg')
plt.ioff()

alg_list = []
model_list = []

# --- configuration -------------------------------------------------------
ROOT_ENV = load_root_env(__file__)
validate_env("NN_server/server.py", {
    "GENERAL_SERVER_IP": as_str,
    "GENERAL_SERVER_PORT": as_int,
    "SERVER_IP": as_str,
    "SERVER_PORT": as_int,
    "PATH_TO_NN": as_str,
    "SRC_RESULT": as_str,
    "SRC_EXAMPLE": as_str,
})
config = dict(dotenv_values(ROOT_ENV))
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")

# Model entries are named NN_0, NN_1, ... in the .env file.
MODEL_ENV_RE = re.compile(r"^NN_\d+$")
if not any(MODEL_ENV_RE.match(key) for key in config):
    raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")
logging.info("NN config loaded from %s", ROOT_ENV)

gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']

# Telemetry sink; defaults point at a local collector.
INFERENCE_TELEMETRY_HOST = os.getenv('telemetry_host', '127.0.0.1')
INFERENCE_TELEMETRY_PORT = os.getenv('telemetry_port', '5020')
INFERENCE_TELEMETRY_ENDPOINT = os.getenv('telemetry_inference_endpoint', 'inference/result')
INFERENCE_TELEMETRY_TIMEOUT_SEC = float(os.getenv('telemetry_inference_timeout_sec', '0.30'))

# Extracts the numeric run id out of result image filenames.
INFERENCE_IMAGE_RE = re.compile(r"_inference_(\d+)_")
def get_result_dir():
    """Return the configured result-image directory, or '' when the key is absent."""
    if 'SRC_RESULT' in config:
        return config['SRC_RESULT']
    return ''
def collect_inference_images(result_id, model_name=''):
    """Return ``(result_id, [png filenames])`` produced by inference run *result_id*.

    Scans the SRC_RESULT directory for files whose name contains
    ``_inference_<result_id>_`` (optionally restricted to a single model
    via the ``_<model_name>.png`` suffix). Returns an empty list when the
    directory is missing/unconfigured or nothing matches.
    """
    result_dir = get_result_dir()
    if not result_dir or not os.path.isdir(result_dir):
        return result_id, []
    needle = f"_inference_{result_id}_"
    model_suffix = f"_{model_name}.png" if model_name else ''
    matches = []
    # BUG FIX: the old loop also built a `grouped_images` dict keyed by every
    # image id it saw, but never read it — pure dead allocation; removed.
    for name in sorted(os.listdir(result_dir)):
        if not name.endswith('.png'):
            continue
        if model_suffix and not name.endswith(model_suffix):
            continue
        match = INFERENCE_IMAGE_RE.search(name)
        if match is None:
            continue
        # Numeric comparison guards against substring collisions
        # (e.g. id 1 matching "_inference_12_").
        if int(match.group(1)) == result_id and needle in name:
            matches.append(name)
    return result_id, matches
def send_inference_result(payload):
    """POST *payload* as JSON to the telemetry endpoint (best effort).

    Failures are logged and swallowed: telemetry must never break the
    inference request that triggered it.
    """
    url = "http://{0}:{1}/{2}".format(
        INFERENCE_TELEMETRY_HOST,
        INFERENCE_TELEMETRY_PORT,
        INFERENCE_TELEMETRY_ENDPOINT.lstrip('/'),
    )
    try:
        requests.post(url, json=payload, timeout=INFERENCE_TELEMETRY_TIMEOUT_SEC)
    except Exception as exc:
        # Was a bare print(); logging records level/timestamp and the target
        # URL, which makes silent telemetry loss diagnosable.
        logging.warning("telemetry post to %s failed: %s", url, exc)
def reset_directory_contents(path):
    """Ensure *path* exists as a directory and delete everything inside it.

    Real subdirectories are removed recursively; symlinks (even to
    directories) are unlinked, not followed. Entries that disappear
    between listing and removal are silently skipped.
    """
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        target = os.path.join(path, entry)
        is_real_dir = os.path.isdir(target) and not os.path.islink(target)
        remover = shutil.rmtree if is_real_dir else os.remove
        try:
            remover(target)
        except FileNotFoundError:
            pass
def init_data_for_inference():
    """Clean the work directories and instantiate one Model per NN_<n> entry.

    Each NN_<n> config value is a ' && '-separated record:
      0 file_model    1 file_config   2 src_example   3 src_result
      4 type_model (module name under Models/)
      5 build func    6 pre func      7 inference func   8 post func
      9 '[class0,class1,...]'         10 n synthetic examples
      11 n src samples per synthetic example   12 path to source dataset
    All failures are printed and skipped (best-effort startup).
    """
    try:
        reset_directory_contents(config['SRC_RESULT'])
        reset_directory_contents(config['SRC_EXAMPLE'])
    except Exception as exc:
        print(str(exc))
        print()
    global model_list
    for key in config:
        if not MODEL_ENV_RE.match(key):
            continue
        # BUG FIX: the try used to wrap the whole loop, so a single malformed
        # NN_<n> entry aborted loading of every subsequent model.
        try:
            params = config[key].split(' && ')
            module = importlib.import_module('Models.' + params[4])
            # params[9] is '[a,b,c]' -> {0: 'a', 1: 'b', 2: 'c'}.
            classes = {i: name for i, name in enumerate(params[9][1:-1].split(','))}
            model = Model(
                file_model=params[0], file_config=params[1],
                src_example=params[2], src_result=params[3],
                type_model=params[4],
                build_model_func=getattr(module, params[5]),
                pre_func=getattr(module, params[6]),
                inference_func=getattr(module, params[7]),
                post_func=getattr(module, params[8]),
                classes=classes,
                number_synthetic_examples=int(params[10]),
                number_src_data_for_one_synthetic_example=int(params[11]),
                path_to_src_dataset=params[12],
            )
            model_list.append(model)
        except Exception as exc:
            print(str(exc))
            print()
def run_example():
    """Run one smoke-test inference on every loaded model (best effort)."""
    try:
        for loaded_model in model_list:
            loaded_model.get_test_inference()
    except Exception as exc:
        # A failing warm-up example must not crash startup.
        print(str(exc))
@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Serialize /receive_data requests: inference is not re-entrant."""
    receive_data_lock.acquire()
    try:
        return _receive_data_locked()
    finally:
        receive_data_lock.release()
def _receive_data_locked():
    """Handle one /receive_data POST end-to-end.

    Expects a JSON body (possibly double-encoded as a string) with keys
    'freq', 'data_real' and 'data_imag'. Runs inference on the first model
    whose name contains the frequency, pushes a telemetry record, forwards
    an LED "amplitude" to the general server, runs any configured
    algorithms, and returns the per-model results as JSON.
    """
    try:
        print()
        data = request.json
        if isinstance(data, str):
            # Some senders double-encode the payload.
            data = json.loads(data)
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))
        result_id = Model.get_ind_inference()
        freq = int(data['freq'])
        print('Частота: ' + str(freq))
        result_msg = {}
        data_to_send = {}
        prediction_list = []
        for model in model_list:
            # Model names embed the frequency band they were trained for.
            if str(freq) not in model.get_model_name():
                continue
            print('-' * 100)
            print(str(model))
            model_name = model.get_model_name()
            result_msg[str(model_name)] = {'freq': freq}
            iq = [np.asarray(data['data_real'], dtype=np.float32),
                  np.asarray(data['data_imag'], dtype=np.float32)]
            prediction, probability = model.get_inference(iq, ind_inference=result_id)
            result_msg[str(model_name)]['prediction'] = prediction
            result_msg[str(model_name)]['probability'] = str(probability)
            prediction_list.append(prediction)
            image_result_id, images = collect_inference_images(result_id, model_name)
            send_inference_result({
                'result_id': image_result_id,
                'ts': time.time(),
                'freq': str(freq),
                'model': model_name,
                'prediction': prediction,
                'probability': float(probability),
                'drone_probability': float(probability) if prediction == 'drone' else 0.0,
                'drone_threshold': None,
                'images': images,
            })
            print('-' * 100)
            print()
            try:
                # Map (freq, prediction, probability) to an LED "amplitude".
                result = 0
                prob = float(probability)
                first_prediction = prediction_list[0]
                if freq == 2400:
                    # NOTE(review): both 2400 MHz branches add 0, so this band
                    # never triggers — looks deliberately disabled; confirm
                    # with the author before "fixing".
                    if first_prediction in ["drone", "drone_noise"]:
                        result += 0
                    elif first_prediction == "wifi" and prob >= 0.95:
                        result += 0
                elif freq == 1200:
                    if first_prediction == "drone" and prob >= 0.95:
                        result += 8
                elif freq == 915:
                    result = 0
                data_to_send = {
                    "freq": str(freq),
                    "amplitude": result,
                }
                response = requests.post(
                    "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
                    json=data_to_send)
                if response.status_code == 200:
                    print("Данные успешно отправлены!")
                    print("Частота: " + str(freq))
                    print("Отправлено светодиодов: " + str(result))
                else:
                    print("Ошибка при отправке данных: ", response.status_code)
            except Exception as exc:
                # Forwarding failures must not fail the inference response.
                print(str(exc))
            # Only the first matching model is processed.
            break
        Model.get_inc_ind_inference()
        print()
        print('#' * 100)
        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference([np.asarray(data['data_real'], dtype=np.float32),
                               np.asarray(data['data_imag'], dtype=np.float32)])
            print('-' * 100)
            print()
        print()
        print('#' * 100)
        del data
        gc.collect()
        return jsonify(result_msg)
    except Exception as exc:
        # BUG FIX: the handler used to print and fall through, returning None,
        # which makes Flask raise "view function did not return a valid
        # response". Return an explicit JSON error with HTTP 500 instead.
        print(str(exc))
        return jsonify({'error': str(exc)}), 500
# NOTE(review): the triple-quoted block below is a disabled legacy
# implementation (asyncio queue + worker pool) kept as an inert module-level
# string literal. It is never executed; consider deleting it outright.
'''
def run_flask():
app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
async def process_tasks():
workers = [asyncio.create_task(worker(queue=queue, semaphore=semaphore)) for _ in range(2)]
await asyncio.gather(*workers)
async def main():
asyncio.create_task(process_tasks())
flask_thread = threading.Thread(target=run_flask)
flask_thread.start()
while True:
if queue.qsize() <= 1:
asyncio.create_task(process_tasks())
await asyncio.sleep(1)
@app.route('/receive_data', methods=['POST'])
def add_task():
queue_size = queue.qsize()
if queue_size > 1:
return {}
print()
data = json.loads(request.json)
print('#' * 100)
print('Получен пакет ' + str(Model.get_ind_inference()))
freq = int(data['freq'])
print('Частота ' + str(freq))
result_msg = {}
for model in model_list:
if str(freq) in model.get_model_name():
print('-' * 100)
print(str(model))
result_msg[str(model.get_model_name())] = {'freq': freq}
asyncio.run_coroutine_threadsafe(queue.put({'freq': freq, 'model': model, 'data': data}), loop)
do_inference(model=model, data=data, freq=freq)
break
del data
gc.collect()
return jsonify(result_msg)
async def worker(queue, semaphore):
while True:
task = await queue.get()
if task is None:
break
async with semaphore:
try:
await do_inference(model=task['model'], data=task['data'], freq=task['freq'])
except Exception as e:
print(str(e))
print(results)
queue.task_done()
async def do_inference(model=None, data=None, freq=0):
prediction_list = []
print("Длина очереди" + str(queue.qsize()))
inference(model=model, data=data, freq=freq)
try:
results = []
for pred in prediction_list:
if pred[1] == 'drone':
results.append([pred[0],8])
else:
results.append([pred[0],0])
for result in results:
try:
data_to_send={
'freq': result[0],
'amplitude': result[1],
'triggered': False if result[1] < 7 else True,
'light_len': result[1]
}
response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send)
await response.text
if response.status_code == 200:
print("Данные успешно отправлены!")
print("Отправлено светодиодов: " + str(data_to_send['light_len']))
else:
print("Ошибка при отправке данных: ", response.status_code)
except Exception as exc:
print(str(exc))
except Exception as exc:
print(str(exc))
Model.get_inc_ind_inference()
print()
print('#' * 100)
del data
gc.collect()
def inference(model=None, data=None, freq=0):
prediction, probability = model.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)])
result_msg[str(model.get_model_name())]['prediction'] = prediction
result_msg[str(model.get_model_name())]['probability'] = str(probability)
queue_size = queue.qsize()
print(queue_size)
prediction_list.append([freq, prediction])
print('-' * 100)
print()
if __name__ == '__main__':
init_data_for_inference()
#asyncio.run(main)
loop.run_until_complete(main())
'''
def run_flask():
    """Blocking entry point: serve the Flask app on the configured address."""
    bind_ip = config['SERVER_IP']
    print(bind_ip)
    app.run(host=bind_ip, port=int(config['SERVER_PORT']))
if __name__ == '__main__':
    # Load models and clean the work directories, then serve from a worker
    # thread (non-daemon, so the process stays alive after main returns).
    init_data_for_inference()
    server_thread = threading.Thread(target=run_flask)
    server_thread.start()