You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
371 lines
12 KiB
Python
371 lines
12 KiB
Python
from flask import Flask, request, jsonify
|
|
from dotenv import dotenv_values, load_dotenv
|
|
from common.nn_profile_schedule import (
|
|
get_profile_model_entries,
|
|
normalize_profile_name,
|
|
resolve_active_profile,
|
|
)
|
|
from common.runtime import load_root_env, validate_env, as_int, as_str
|
|
import os
|
|
import sys
|
|
import matplotlib.pyplot as plt
|
|
from Model import Model
|
|
import numpy as np
|
|
import matplotlib
|
|
import importlib
|
|
import threading
|
|
import requests
|
|
import asyncio
|
|
import shutil
|
|
import json
|
|
import gc
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
# Location of the bundled torchsig checkout; the importable package itself
# lives one level down at /app/torchsig/torchsig.
TORCHSIG_PATH = "/app/torchsig"

# Put the checkout at the very front of the import path (only if it is not
# already listed anywhere) so that ``import torchsig`` resolves to it.
if not any(entry == TORCHSIG_PATH for entry in sys.path):
    sys.path.insert(0, TORCHSIG_PATH)
|
|
|
|
logging.basicConfig(level=logging.INFO)

# Flask application that hosts the /receive_data endpoint.
app = Flask(__name__)

# A fresh asyncio event loop installed for this thread, plus a queue and a
# 3-slot semaphore. The live /receive_data handler does not use these.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)

# Module-level result containers (the request handler uses its own locals).
prediction_list, result_msg, results = [], {}, []

# Headless plotting: select the non-interactive Agg backend and turn off
# pyplot's interactive mode.
matplotlib.use('Agg')
plt.ioff()

# Registries: model_list is (re)built by init_data_for_inference; alg_list is
# iterated by /receive_data, but nothing in this module appends to it.
alg_list, model_list = [], []
|
|
# Path of the project's root .env file as returned by the shared helper
# (it is fed to Path() and dotenv_values() below).
ROOT_ENV = load_root_env(__file__)
# Optional runtime override file that sits next to the root .env.
RUNTIME_ENV = Path(ROOT_ENV).parent.joinpath("runtime", "nn_active_profile.env")

if RUNTIME_ENV.exists():
    # Runtime overrides take precedence in os.environ.
    load_dotenv(RUNTIME_ENV, override=True)

# Presumably checks that each key is set and converts with the given
# as_str/as_int helper -- see common.runtime.validate_env.
validate_env(
    "NN_server/server.py",
    dict(
        GENERAL_SERVER_IP=as_str,
        GENERAL_SERVER_PORT=as_int,
        SERVER_IP=as_str,
        SERVER_PORT=as_int,
        PATH_TO_NN=as_str,
        SRC_RESULT=as_str,
        SRC_EXAMPLE=as_str,
    ),
)

# Module-wide settings: root .env values, overlaid by the runtime file.
config = {**dotenv_values(ROOT_ENV)}
if RUNTIME_ENV.exists():
    config.update(dotenv_values(RUNTIME_ENV))
|
|
|
|
|
|
def get_required_drone_streak(freq):
    """Return how many consecutive "drone" predictions on *freq* trigger an alarm.

    The threshold comes from the ``DRONE_STREAK_<freq>`` entry of ``config``
    and defaults to ``"1"`` when absent. Values below 1 are clamped to 1.
    A value that ``int()`` rejects (including ``None``) is logged as a
    warning and treated as 1.
    """
    setting = config.get(f"DRONE_STREAK_{freq}", "1")
    try:
        parsed = int(setting)
    except (TypeError, ValueError):
        logging.warning("Invalid DRONE_STREAK_%s=%r, falling back to 1", freq, setting)
        return 1
    return parsed if parsed > 1 else 1
|
|
|
|
|
|
def update_drone_streak(freq, prediction):
    """Advance the per-frequency "drone" streak and decide whether to alarm.

    A ``"drone"`` prediction increments ``drone_streaks[freq]``; any other
    prediction resets it to 0. The alarm fires when the prediction is
    ``"drone"`` and the streak has reached ``get_required_drone_streak(freq)``.
    Every call is logged at INFO level.

    Returns:
        8 if the alarm fires, otherwise 0.
    """
    is_drone = prediction == "drone"
    streak = (drone_streaks.get(freq, 0) + 1) if is_drone else 0
    drone_streaks[freq] = streak

    required = get_required_drone_streak(freq)
    triggered = is_drone and streak >= required
    logging.info(
        "NN alarm gate freq=%s prediction=%s streak=%s/%s triggered=%s",
        freq, prediction, streak, required, triggered,
    )
    return 8 if triggered else 0
|
|
|
|
|
|
# Refuse to start if neither .env file produced any settings.
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")

logging.info("NN config loaded from %s", ROOT_ENV)
if RUNTIME_ENV.exists():
    logging.info("NN runtime overrides loaded from %s", RUNTIME_ENV)

# Address of the general server that /receive_data posts alarm levels to.
gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']

# Work out which model profile to run. NN_SCHEDULE is deliberately left out
# of the mapping handed to resolve_active_profile.
requested_profile = normalize_profile_name(config.get("NN_ACTIVE_PROFILE"))
active_profile = resolve_active_profile(
    {name: setting for name, setting in config.items() if name != "NN_SCHEDULE"}
)
if requested_profile != active_profile:
    logging.warning(
        "Requested NN profile %s is not configured, falling back to %s",
        requested_profile,
        active_profile,
    )
logging.info("NN active profile: %s", active_profile)

# freq -> number of consecutive "drone" predictions (see update_drone_streak).
drone_streaks = {}
|
|
|
|
def _reset_dir(path):
    """Remove *path* if it is a directory, then recreate it (with parents) empty."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _build_model(spec):
    """Construct a ``Model`` from one ``' && '``-separated model specification.

    Fields by index, as mapped onto ``Model``'s keyword arguments:
    0 file_model, 1 file_config, 2 src_example, 3 src_result,
    4 type_model (also the module name imported from ``Models``),
    5-8 names of the build/pre/inference/post functions in that module,
    9 class list (sliced ``[1:-1]`` and split on commas -- presumably a
    bracketed list such as ``"[drone,noise]"``), 10 number of synthetic
    examples (int), 11 source items per synthetic example (int),
    12 path to the source dataset.
    """
    params = spec.split(' && ')
    module = importlib.import_module('Models.' + params[4])
    # "[a,b,c]" -> {0: 'a', 1: 'b', 2: 'c'}; entries are used verbatim.
    classes = dict(enumerate(params[9][1:-1].split(',')))
    return Model(
        file_model=params[0],
        file_config=params[1],
        src_example=params[2],
        src_result=params[3],
        type_model=params[4],
        build_model_func=getattr(module, params[5]),
        pre_func=getattr(module, params[6]),
        inference_func=getattr(module, params[7]),
        post_func=getattr(module, params[8]),
        classes=classes,
        number_synthetic_examples=int(params[10]),
        number_src_data_for_one_synthetic_example=int(params[11]),
        path_to_src_dataset=params[12],
    )


def init_data_for_inference():
    """Prepare working directories and load the models of the active profile.

    1. ``config['SRC_RESULT']`` and ``config['SRC_EXAMPLE']`` are each
       recreated as empty directories. A failure on one is logged and does
       not prevent the other from being reset.
    2. The module-level ``model_list`` is rebuilt from
       ``get_profile_model_entries(config, active_profile)``. A model whose
       specification fails to load is logged with its traceback and skipped,
       so the remaining models still load.

    Raises:
        RuntimeError: if the active profile has no model entries at all.
        Any exception raised by ``get_profile_model_entries`` propagates.
    """
    global model_list

    for key in ('SRC_RESULT', 'SRC_EXAMPLE'):
        try:
            _reset_dir(config[key])
        except Exception:
            logging.exception("Could not reset %s directory", key)

    model_list = []
    loaded_model_keys = []

    model_entries = get_profile_model_entries(config, active_profile)
    if not model_entries:
        # Deliberately outside any try/except: a profile without models must
        # stop startup instead of leaving a server that can never answer.
        raise RuntimeError(f"[NN_server/server.py] no models configured for profile {active_profile!r}")

    for key, spec in model_entries:
        try:
            model = _build_model(spec)
        except Exception:
            logging.exception("Failed to load NN model %s for profile %s", key, active_profile)
            continue
        model_list.append(model)
        loaded_model_keys.append(key)

    if not model_list:
        logging.error("No NN models could be loaded for profile %s", active_profile)

    logging.info(
        "Loaded %s NN models for profile %s: %s",
        len(model_list),
        active_profile,
        ", ".join(loaded_model_keys),
    )
|
|
|
|
|
|
def run_example():
    """Run ``get_test_inference()`` on every loaded model in ``model_list``.

    Each model is exercised independently. A failure is logged with its
    traceback and does not prevent the remaining models from running.
    """
    for model in model_list:
        try:
            model.get_test_inference()
        except Exception:
            logging.exception("Test inference failed for %s", model)
|
|
|
|
|
|
# Seconds to wait for the general server before abandoning one alarm update,
# so an unreachable server cannot block a request thread indefinitely.
_GENERAL_SERVER_TIMEOUT_S = 5


def _parse_request_payload():
    """Return the request payload as a dict.

    The handler expects the JSON body to be a *string* that itself contains
    JSON (hence the second ``json.loads``). A body that is already a JSON
    object is accepted as-is.
    """
    payload = request.get_json()
    if isinstance(payload, str):
        return json.loads(payload)
    return payload


def _signal_from_payload(data):
    """Return new ``[real, imag]`` float32 arrays built from the payload.

    A new pair is built on every call, so each consumer gets its own arrays.
    """
    return [
        np.asarray(data['data_real'], dtype=np.float32),
        np.asarray(data['data_imag'], dtype=np.float32),
    ]


def _report_prediction(freq, prediction):
    """Update the drone streak for *freq* and POST the alarm level upstream.

    Sends ``{'freq': str(freq), 'amplitude': level}`` to the general server's
    ``/process_data``. Any failure (network, timeout, bad response) is
    logged and swallowed so the inference result is still returned.
    """
    try:
        result = update_drone_streak(freq, prediction)
        if freq == 2400:
            # The 2400 band is always reported with amplitude 0. Compare as
            # int: comparing str(freq) with 2400 can never be true.
            result = 0
        data_to_send = {
            'freq': str(freq),
            'amplitude': result,
        }
        response = requests.post(
            "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
            json=data_to_send,
            timeout=_GENERAL_SERVER_TIMEOUT_S,
        )
        if response.status_code == 200:
            print("Данные успешно отправлены!")
            print("Частота: " + str(freq))
            print("Отправлено светодиодов: " + str(result))
        else:
            print("Ошибка при отправке данных: ", response.status_code)
    except Exception:
        logging.exception("Failed to report NN result for freq=%s", freq)


@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Run inference on one posted signal packet and report the alarm level.

    Payload keys: ``freq`` (int-convertible), plus ``data_real`` and
    ``data_imag`` (sequences convertible to float32 arrays).

    Processing steps:
    - Run the first model in ``model_list`` whose name contains ``str(freq)``.
    - Feed its prediction to the drone-streak gate.
    - Send the resulting alarm level to the general server.
    - Run every entry of ``alg_list`` on the same packet.

    Returns:
        JSON ``{model_name: {'freq', 'prediction', 'probability'}}`` for the
        model that ran (empty if none matched), or ``{'error': message}``
        with HTTP 500 if processing failed.
    """
    try:
        print()
        data = _parse_request_payload()
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))

        result_msg = {}
        for model in model_list:
            if str(freq) not in model.get_model_name():
                continue
            print('-' * 100)
            print(str(model))
            prediction, probability = model.get_inference(_signal_from_payload(data))
            result_msg[str(model.get_model_name())] = {
                'freq': freq,
                'prediction': prediction,
                'probability': str(probability),
            }
            print('-' * 100)
            print()
            _report_prediction(freq, prediction)
            # Only the first model matching this frequency is used.
            break

        Model.get_inc_ind_inference()
        print()
        print('#' * 100)

        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference(_signal_from_payload(data))
            print('-' * 100)
            print()

        print()
        print('#' * 100)

        # Drop the payload reference before forcing a garbage collection.
        del data
        gc.collect()

        return jsonify(result_msg)

    except Exception as exc:
        # Always return a response: returning None makes Flask fail with an
        # opaque "view function did not return a valid response" error.
        logging.exception("receive_data failed")
        return jsonify({'error': str(exc)}), 500
|
|
|
|
|
|
# NOTE: This spot used to hold a legacy asyncio/queue-based variant of this
# server inside a module-level triple-quoted string. That variant had its own
# run_flask, process_tasks, main, a second '/receive_data' handler named
# add_task, worker, do_inference, inference and a separate __main__ entry
# point. The string was never executed (it also called the coroutine
# do_inference without awaiting it), so it has been removed. Retrieve it from
# version control history if it is ever needed again.
|
|
|
|
def run_flask():
    """Print the bind address, then serve ``app`` on SERVER_IP:SERVER_PORT (blocks)."""
    host = config['SERVER_IP']
    print(host)
    app.run(host=host, port=int(config['SERVER_PORT']))
|
|
|
|
|
|
if __name__ == '__main__':
    # Reset the working directories and load the active profile's models
    # before any request can arrive.
    init_data_for_inference()

    # Serve from a separate thread. It is non-daemon (it inherits that from
    # the main thread), so the process stays alive while Flask runs.
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()
|