You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
418 lines
14 KiB
Python
418 lines
14 KiB
Python
from flask import Flask, request, jsonify
|
|
from dotenv import dotenv_values
|
|
from common.runtime import load_root_env, validate_env, as_int, as_str
|
|
import os
|
|
import sys
|
|
import matplotlib.pyplot as plt
|
|
from Model import Model
|
|
import numpy as np
|
|
import matplotlib
|
|
import importlib
|
|
import threading
|
|
import requests
|
|
import asyncio
|
|
import shutil
|
|
import json
|
|
import gc
|
|
import logging
|
|
import time
|
|
import re
|
|
|
|
TORCHSIG_PATH = "/app/torchsig"

# `import torchsig` has to resolve to the vendored package living in
# /app/torchsig/torchsig, so the checkout root is placed at the very front of
# the module search path. The membership check keeps sys.path free of a
# duplicate entry when the root is already listed.
if TORCHSIG_PATH not in sys.path:
    sys.path.insert(0, TORCHSIG_PATH)
|
|
|
|
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# A dedicated asyncio event loop, installed as the current loop for this
# thread, followed by the queue / semaphore primitives created after it.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)

# Module-level result holders.
# NOTE(review): receive_data() shadows prediction_list and result_msg with
# its own locals, and no active code path in this module references
# loop/queue/semaphore/results. These look like leftovers of an earlier
# design; confirm that nothing imports them before removing.
prediction_list, result_msg, results = [], {}, []

# Render figures off-screen with the non-GUI Agg backend and keep pyplot out
# of interactive mode.
matplotlib.use('Agg')
plt.ioff()
|
|
|
|
# Registries filled by init_data_for_inference(). In the visible code only
# model_list is ever populated; alg_list stays empty.
alg_list = []
model_list = []

# Locate the project's root .env and validate it against the key -> caster
# schema below (NOTE(review): exact validation semantics live in
# common.runtime.validate_env — presumably missing/unconvertible keys fail).
ROOT_ENV = load_root_env(__file__)
validate_env("NN_server/server.py", {
    "GENERAL_SERVER_IP": as_str,
    "GENERAL_SERVER_PORT": as_int,
    "SERVER_IP": as_str,
    "SERVER_PORT": as_int,
    "PATH_TO_NN": as_str,
    "SRC_RESULT": as_str,
    "SRC_EXAMPLE": as_str,
})

# Raw key/value view of the same .env file; must contain at least one key.
config = dict(dotenv_values(ROOT_ENV))
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")

# Model definitions are stored under keys of the form NN_<digits>; at least
# one of them must be present.
MODEL_ENV_RE = re.compile(r"^NN_\d+$")
if not any(map(MODEL_ENV_RE.match, config)):
    raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")

logging.info("NN config loaded from %s", ROOT_ENV)
|
|
# Host and port of the general server whose /process_data endpoint receives
# the per-frequency amplitude computed in receive_data().
gen_server_ip, gen_server_port = config['GENERAL_SERVER_IP'], config['GENERAL_SERVER_PORT']

# Inference telemetry sink. Each setting can be overridden through a
# lower-case environment variable; the timeout is parsed as float seconds.
INFERENCE_TELEMETRY_HOST = os.getenv('telemetry_host', '127.0.0.1')
INFERENCE_TELEMETRY_PORT = os.getenv('telemetry_port', '5020')
INFERENCE_TELEMETRY_ENDPOINT = os.getenv('telemetry_inference_endpoint', 'inference/result')
INFERENCE_TELEMETRY_TIMEOUT_SEC = float(
    os.getenv('telemetry_inference_timeout_sec', '0.30')
)

# Captures the numeric result id embedded in names like "..._inference_<id>_...".
INFERENCE_IMAGE_RE = re.compile(r"_inference_(\d+)_")
|
|
|
|
|
|
def get_result_dir():
    """Return the configured SRC_RESULT directory, or '' when the key is missing."""
    result_dir = config.get('SRC_RESULT', '')
    return result_dir
|
|
|
|
|
|
def collect_inference_images(result_id, model_name=''):
    """Find the PNG images written for one inference result.

    Scans ``get_result_dir()`` for ``*.png`` files whose names contain
    ``_inference_<id>_``. When *model_name* is non-empty, only files ending
    in ``_<model_name>.png`` are considered.

    Returns:
        ``(result_id, names)`` with the sorted names carrying exactly
        *result_id*, when there are any. Otherwise the highest result id among
        the candidate images is used instead: ``(that_id, its_names)``. If the
        result directory is unset/missing or holds no candidates at all,
        ``(result_id, [])``.
    """
    result_dir = get_result_dir()
    if not result_dir or not os.path.isdir(result_dir):
        return result_id, []

    marker = f"_inference_{result_id}_"
    wanted_suffix = f"_{model_name}.png" if model_name else ''

    def embedded_id(file_name):
        # Inference id parsed from a candidate name, or None when the file is
        # not a PNG, lacks the requested model suffix, or carries no id.
        if not file_name.endswith('.png'):
            return None
        if wanted_suffix and not file_name.endswith(wanted_suffix):
            return None
        hit = INFERENCE_IMAGE_RE.search(file_name)
        return None if hit is None else int(hit.group(1))

    by_result = {}
    same_result = []
    for file_name in sorted(os.listdir(result_dir)):
        file_id = embedded_id(file_name)
        if file_id is None:
            continue
        by_result.setdefault(file_id, []).append(file_name)
        # The literal marker check also rejects zero-padded ids such as
        # "_inference_007_", which int() alone would equate with 7.
        if file_id == result_id and marker in file_name:
            same_result.append(file_name)

    if same_result:
        return result_id, same_result
    if not by_result:
        return result_id, []
    newest = max(by_result)
    return newest, by_result[newest]
|
|
|
|
|
|
def send_inference_result(payload):
    """Best-effort POST of *payload* as JSON to the inference telemetry endpoint.

    The target URL is built from the INFERENCE_TELEMETRY_* settings and the
    request uses INFERENCE_TELEMETRY_TIMEOUT_SEC. Any exception raised while
    building or sending the request is printed and swallowed, so telemetry
    problems never interrupt inference. The HTTP status is not inspected and
    nothing is returned.
    """
    try:
        endpoint = INFERENCE_TELEMETRY_ENDPOINT.lstrip('/')
        url = "http://{0}:{1}/{2}".format(
            INFERENCE_TELEMETRY_HOST,
            INFERENCE_TELEMETRY_PORT,
            endpoint,
        )
        requests.post(url, json=payload, timeout=INFERENCE_TELEMETRY_TIMEOUT_SEC)
    except Exception as exc:
        print(str(exc))
|
|
|
|
|
|
def reset_directory_contents(path):
    """Ensure *path* exists as a directory and remove everything inside it.

    Real sub-directories are deleted recursively. Files and symlinks,
    including symlinks to directories, are unlinked without following them.
    Entries that disappear before they can be removed are ignored. Any other
    OSError propagates to the caller.
    """
    os.makedirs(path, exist_ok=True)
    for entry_name in os.listdir(path):
        target = os.path.join(path, entry_name)
        is_real_dir = os.path.isdir(target) and not os.path.islink(target)
        remover = shutil.rmtree if is_real_dir else os.remove
        try:
            remover(target)
        except FileNotFoundError:
            # Removed concurrently between listing and deletion; nothing to do.
            pass
|
|
|
|
|
|
def _build_model(spec):
    """Create a Model from one NN_<n> .env value.

    *spec* is split on ' && ' into positional fields:
    0 file_model, 1 file_config, 2 src_example, 3 src_result,
    4 module name under ``Models`` (also passed as type_model),
    5-8 names of the build / pre / inference / post functions in that module,
    9 bracketed comma-separated class labels, e.g. "[a,b]",
    10 number_synthetic_examples (int),
    11 number_src_data_for_one_synthetic_example (int),
    12 path_to_src_dataset.
    """
    params = spec.split(' && ')
    module = importlib.import_module('Models.' + params[4])
    # Drop the surrounding brackets and number the labels in order:
    # "[a,b]" -> {0: 'a', 1: 'b'}.
    # NOTE(review): labels are not whitespace-stripped, so "[a, b]" yields ' b'.
    classes = dict(enumerate(params[9][1:-1].split(',')))
    return Model(
        file_model=params[0],
        file_config=params[1],
        src_example=params[2],
        src_result=params[3],
        type_model=params[4],
        build_model_func=getattr(module, params[5]),
        pre_func=getattr(module, params[6]),
        inference_func=getattr(module, params[7]),
        post_func=getattr(module, params[8]),
        classes=classes,
        number_synthetic_examples=int(params[10]),
        number_src_data_for_one_synthetic_example=int(params[11]),
        path_to_src_dataset=params[12],
    )


def init_data_for_inference():
    """Reset the output directories and load every NN_<n> model from config.

    The SRC_RESULT and SRC_EXAMPLE directories are created if needed and
    emptied. Each config key matching MODEL_ENV_RE is then parsed by
    _build_model() and appended to the module-level model_list.

    Errors are printed and skipped per directory and per model entry. A
    failing SRC_RESULT reset does not prevent the SRC_EXAMPLE reset, and a
    malformed NN_<n> value does not stop the remaining models from loading.
    """
    for dir_key in ('SRC_RESULT', 'SRC_EXAMPLE'):
        try:
            reset_directory_contents(config[dir_key])
        except Exception as exc:
            print(str(exc))
            print()

    for key, spec in config.items():
        # ALG_<n> (Algorithm) entries are not loaded; the disabled loader is
        # kept here for reference:
        # if key.startswith('ALG_'):
        #     params = config[key].split(' && ')
        #     module = importlib.import_module('Algorithms.' + params[2])
        #     classes = {}
        #     for value in params[6][1:-1].split(','):
        #         classes[len(classes)] = value
        #     alg = Algorithm(src_example=params[0], src_result=params[1], type_alg=params[2],
        #                     pre_func=getattr(module, params[3]), inference_func=getattr(module, params[4]),
        #                     post_func=getattr(module, params[5]), classes=classes,
        #                     number_synthetic_examples=int(params[7]),
        #                     number_src_data_for_one_synthetic_example=int(params[8]),
        #                     path_to_src_dataset=params[9])
        #     alg_list.append(alg)
        if not MODEL_ENV_RE.match(key):
            continue
        try:
            model_list.append(_build_model(spec))
        except Exception as exc:
            # Name the failing entry, then continue with the next one.
            print(key + ': ' + str(exc))
            print()
|
|
|
|
|
|
def run_example(models=None):
    """Run each model's built-in test inference (``get_test_inference()``).

    Args:
        models: iterable of model objects to exercise. Defaults to the
            module-level model_list filled by init_data_for_inference().

    Each model is handled independently. An exception from one model is
    printed and the remaining models still run.
    """
    for model in (model_list if models is None else models):
        try:
            model.get_test_inference()
        except Exception as exc:
            print(str(exc))
|
|
|
|
|
|
# Seconds to wait for the general server's /process_data endpoint. Without a
# timeout, a stalled general server would block the request handler forever.
GENERAL_SERVER_TIMEOUT_SEC = float(os.getenv('general_server_timeout_sec', '5'))


def _led_amplitude(freq, prediction, probability):
    """Return the amplitude (number of LEDs) reported for one prediction.

    Only freq == 1200 with prediction "drone" and probability >= 0.95 lights
    LEDs (amplitude 8). Every other case maps to 0, including all freq == 2400
    predictions (drone / drone_noise / wifi) and freq == 915.
    """
    if freq == 1200 and prediction == "drone" and probability >= 0.95:
        return 8
    return 0


def _send_led_amplitude(freq, prediction, probability):
    """POST the LED amplitude for a prediction to the general server.

    Best-effort: any failure (an unconvertible probability, a network error, a
    timeout) is printed and swallowed, so the classification response is still
    returned to the client.
    """
    try:
        amplitude = _led_amplitude(int(freq), prediction, float(probability))
        data_to_send = {
            "freq": str(freq),
            "amplitude": amplitude,
        }
        response = requests.post(
            "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
            json=data_to_send,
            timeout=GENERAL_SERVER_TIMEOUT_SEC,
        )
        if response.status_code == 200:
            print("Данные успешно отправлены!")
            print("Частота: " + str(freq))
            print("Отправлено светодиодов: " + str(amplitude))
        else:
            print("Ошибка при отправке данных: ", response.status_code)
    except Exception as exc:
        print(str(exc))


@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Classify one IQ packet with the model registered for its frequency.

    Expects a JSON body, or a JSON-encoded string containing it, with keys
    ``freq``, ``data_real`` and ``data_imag``. The first model in model_list
    whose name contains ``str(freq)`` runs inference. Its result and matching
    images are sent to the telemetry endpoint, and an LED amplitude is posted
    to the general server.

    Returns:
        JSON ``{model_name: {'freq', 'prediction', 'probability'}}``, which is
        empty when no model matches. On an unexpected error, a JSON
        ``{'error': message}`` with HTTP 500. The view must not return None,
        because Flask rejects that as an invalid response.
    """
    try:
        print()
        data = request.json
        # Some clients post a JSON string that itself encodes the payload.
        if isinstance(data, str):
            data = json.loads(data)
        print('#' * 100)
        result_id = Model.get_ind_inference()
        print('Получен пакет ' + str(result_id))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))

        result_msg = {}
        prediction_list = []
        for model in model_list:
            model_name = model.get_model_name()
            if str(freq) not in model_name:
                continue

            print('-' * 100)
            print(str(model))
            samples = [
                np.asarray(data['data_real'], dtype=np.float32),
                np.asarray(data['data_imag'], dtype=np.float32),
            ]
            prediction, probability = model.get_inference(samples)
            result_msg[str(model_name)] = {
                'freq': freq,
                'prediction': prediction,
                'probability': str(probability),
            }
            prediction_list.append(prediction)

            image_result_id, images = collect_inference_images(result_id, model_name)
            send_inference_result({
                'result_id': image_result_id,
                'ts': time.time(),
                'freq': str(freq),
                'model': model_name,
                'prediction': prediction,
                'probability': float(probability),
                'drone_probability': float(probability) if prediction == 'drone' else 0.0,
                'drone_threshold': None,
                'images': images,
            })
            print('-' * 100)
            print()

            _send_led_amplitude(freq, prediction_list[0], probability)
            # Only the first matching model handles the packet.
            break

        Model.get_inc_ind_inference()
        print()
        print('#' * 100)

        # Classical algorithms (alg_list is not populated in the visible code).
        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference([np.asarray(data['data_real'], dtype=np.float32),
                               np.asarray(data['data_imag'], dtype=np.float32)])
            print('-' * 100)
            print()

        print()
        print('#' * 100)

        # Drop the (potentially large) sample payload eagerly.
        del data
        gc.collect()

        return jsonify(result_msg)

    except Exception as exc:
        print(str(exc))
        return jsonify({'error': str(exc)}), 500
|
|
|
|
|
|
# NOTE: an earlier asyncio queue/worker variant of /receive_data used to sit
# here inside a module-level string literal. It covered run_flask,
# process_tasks, main, add_task, worker, do_inference and inference.
#
# It was removed because:
#   - it never executed;
#   - it registered a second '/receive_data' route;
#   - it could not have worked as written:
#       * add_task called the coroutine do_inference without awaiting it;
#       * do_inference awaited the non-awaitable requests `response.text`;
#       * inference() indexed the module-level result_msg, which add_task
#         never filled.
#
# Recover it from version control if the queue-based design is revived.
|
|
|
|
def run_flask():
    """Print the bind host, then serve the Flask app on SERVER_IP:SERVER_PORT (blocks)."""
    bind_host = config['SERVER_IP']
    print(bind_host)
    app.run(host=bind_host, port=int(config['SERVER_PORT']))
|
|
|
|
|
|
if __name__ == '__main__':
    # Reset output directories and load the configured models before any
    # request can arrive.
    init_data_for_inference()

    # Serve from a separate, non-daemon thread. That thread keeps the process
    # alive once the main thread reaches the end of the module.
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()
|