from flask import Flask, request, jsonify
from dotenv import dotenv_values
from common.runtime import load_root_env, validate_env, as_int, as_str
import os
import sys
import re
import matplotlib.pyplot as plt
from Model import Model
import numpy as np
import matplotlib
import importlib
import threading
import requests
import asyncio
import shutil
import json
import gc
import logging

TORCHSIG_PATH = "/app/torchsig"
if TORCHSIG_PATH not in sys.path:
    # Ensure import torchsig resolves to /app/torchsig/torchsig package.
    sys.path.insert(0, TORCHSIG_PATH)

logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# Async plumbing and shared buffers kept for the queued-inference mode
# (currently disabled); removing them would change the module's interface.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)
prediction_list = []
result_msg = {}
results = []

# Headless plotting: figures are rendered to files only, never shown.
matplotlib.use('Agg')
plt.ioff()

alg_list = []
model_list = []

ROOT_ENV = load_root_env(__file__)
validate_env("NN_server/server.py", {
    "GENERAL_SERVER_IP": as_str,
    "GENERAL_SERVER_PORT": as_int,
    "SERVER_IP": as_str,
    "SERVER_PORT": as_int,
    "PATH_TO_NN": as_str,
    "SRC_RESULT": as_str,
    "SRC_EXAMPLE": as_str,
})
config = dict(dotenv_values(ROOT_ENV))


def is_model_config_key(key, value):
    """Return True when (key, value) looks like an NN_<n> model definition."""
    return bool(re.fullmatch(r"NN_\d+", key or "")) and isinstance(value, str) and " && " in value


def get_required_drone_streak(freq):
    """Consecutive drone detections required for ``freq`` (raw config string)."""
    return config.get(f"DRONE_STREAK_{freq}", "1")


def get_required_drone_prob(freq):
    """Drone-probability threshold for ``freq``, with default fallback (raw config string)."""
    return config.get(f"DRONE_PROB_THRESHOLD_{freq}",
                      config.get("DRONE_PROB_THRESHOLD_DEFAULT", "0"))


def update_drone_streak(freq, prediction, drone_probability):
    """Update the per-frequency detection streak and return the LED amplitude.

    Returns 8 when the alarm gate fires (prediction is "drone", the
    probability meets the configured threshold, and the streak of such
    detections is long enough), otherwise 0.
    """
    # dotenv values are strings; convert BEFORE comparing — otherwise
    # `float >= str` / `int >= str` raise TypeError on Python 3, and
    # `%.3f` formatting of the threshold fails inside logging.
    required_prob = float(get_required_drone_prob(freq))
    drone_probability = 0.0 if drone_probability is None else float(drone_probability)
    passes_prob_gate = prediction == "drone" and drone_probability >= required_prob
    if passes_prob_gate:
        drone_streaks[freq] = drone_streaks.get(freq, 0) + 1
    else:
        drone_streaks[freq] = 0
    required = int(get_required_drone_streak(freq))
    triggered = passes_prob_gate and drone_streaks[freq] >= required
    logging.info(
        "NN alarm gate freq=%s prediction=%s drone_probability=%.3f "
        "threshold=%.3f streak=%s/%s triggered=%s",
        freq, prediction, drone_probability, required_prob,
        drone_streaks[freq], required, triggered,
    )
    return 8 if triggered else 0


if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
if not any(is_model_config_key(key, value) for key, value in config.items()):
    raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")
logging.info("NN config loaded from %s", ROOT_ENV)

gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']
# Per-frequency count of consecutive "drone" predictions above threshold.
drone_streaks = {}


def init_data_for_inference():
    """Reset working directories and build ``model_list`` from NN_* config entries."""
    try:
        # Start from empty result/example directories on every launch.
        if os.path.isdir(config['SRC_RESULT']):
            shutil.rmtree(config['SRC_RESULT'])
        os.mkdir(config['SRC_RESULT'])
        if os.path.isdir(config['SRC_EXAMPLE']):
            shutil.rmtree(config['SRC_EXAMPLE'])
        os.mkdir(config['SRC_EXAMPLE'])
    except Exception as exc:
        print(str(exc))
        print()
    try:
        global model_list
        for key, value in config.items():
            if not is_model_config_key(key, value):
                continue
            # An NN_<n> entry is 13 ' && '-separated fields consumed
            # positionally by Model() below.
            params = value.split(' && ')
            module = importlib.import_module('Models.' + params[4])
            # params[9] is "[name0,name1,...]" — map class index -> class name.
            classes = dict(enumerate(params[9][1:-1].split(',')))
            model = Model(file_model=params[0], file_config=params[1],
                          src_example=params[2], src_result=params[3],
                          type_model=params[4],
                          build_model_func=getattr(module, params[5]),
                          pre_func=getattr(module, params[6]),
                          inference_func=getattr(module, params[7]),
                          post_func=getattr(module, params[8]),
                          classes=classes,
                          number_synthetic_examples=int(params[10]),
                          number_src_data_for_one_synthetic_example=int(params[11]),
                          path_to_src_dataset=params[12])
            model_list.append(model)
    except Exception as exc:
        print(str(exc))
        print()


def run_example():
    """Run each configured model's self-test inference (smoke test)."""
    try:
        for model in model_list:
            model.get_test_inference()
    except Exception as exc:
        print(str(exc))


@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Accept one IQ packet, run the matching model, push the alarm downstream.

    The request body is a JSON-encoded *string* holding 'freq', 'data_real'
    and 'data_imag'.  Returns per-model prediction details as JSON, or an
    error payload with HTTP 500 on failure.
    """
    try:
        print()
        # The client posts a JSON string, so request.json itself still
        # needs a json.loads() pass — TODO confirm against the sender.
        data = json.loads(request.json)
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))
        result_msg = {}
        data_to_send = {}
        prediction_list = []
        for model in model_list:
            # Model names embed the centre frequency they serve.
            if str(freq) not in model.get_model_name():
                continue
            print('-' * 100)
            print(str(model))
            model_name = str(model.get_model_name())
            result_msg[model_name] = {'freq': freq}
            inference_result = model.get_inference([
                np.asarray(data['data_real'], dtype=np.float32),
                np.asarray(data['data_imag'], dtype=np.float32),
            ])
            if inference_result is None:
                raise RuntimeError(f"Inference failed for {model.get_model_name()}")
            prediction, probability = inference_result[:2]
            drone_probability = float(probability) if prediction == "drone" else 0.0
            result_msg[model_name]['prediction'] = prediction
            result_msg[model_name]['probability'] = str(probability)
            result_msg[model_name]['drone_probability'] = str(drone_probability)
            result_msg[model_name]['drone_threshold'] = str(get_required_drone_prob(freq))
            prediction_list.append(prediction)
            print('-' * 100)
            print()
            try:
                result = update_drone_streak(freq, prediction, drone_probability)
                data_to_send = {
                    'freq': str(freq),
                    'amplitude': result,
                }
                response = requests.post(
                    "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
                    json=data_to_send)
                if response.status_code == 200:
                    print("Данные успешно отправлены!")
                    print("Частота: " + str(freq))
                    print("Отправлено светодиодов: " + str(result))
                else:
                    print("Ошибка при отправке данных: ", response.status_code)
            except Exception as exc:
                # Best effort: a downstream delivery failure must not fail
                # the inference request itself.
                print(str(exc))
            break
        Model.get_inc_ind_inference()
        print()
        print('#' * 100)
        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference([
                np.asarray(data['data_real'], dtype=np.float32),
                np.asarray(data['data_imag'], dtype=np.float32),
            ])
            print('-' * 100)
            print()
        print()
        print('#' * 100)
        del data
        gc.collect()
        return jsonify(result_msg)
    except Exception as exc:
        print(str(exc))
        # Previously fell through returning None, which made Flask raise a
        # secondary "did not return a valid response" error; report the
        # failure to the caller explicitly instead.
        return jsonify({'error': str(exc)}), 500


def run_flask():
    """Start the Flask HTTP server on the configured bind address."""
    print(config['SERVER_IP'])
    app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))


if __name__ == '__main__':
    init_data_for_inference()
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()