# NN_server/server.py — Flask-based neural-network inference server:
# receives IQ data packets, runs the configured models and forwards
# the verdicts to the general server.
from flask import Flask, request, jsonify
|
|
from dotenv import dotenv_values
|
|
from common.runtime import load_root_env, validate_env, as_int, as_str
|
|
import matplotlib.pyplot as plt
|
|
from Model import Model
|
|
import numpy as np
|
|
import matplotlib
|
|
import importlib
|
|
import threading
|
|
import requests
|
|
import asyncio
|
|
import shutil
|
|
import json
|
|
import gc
|
|
import os
|
|
import logging
|
|
|
|
# --- module-level runtime wiring ------------------------------------------
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)
# Dedicated asyncio loop/queue/semaphore.  In the active code they are never
# used; they only serve the disabled queue-based implementation kept in the
# triple-quoted string further down this file.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)

# Module-level scratch state shared with the (disabled) async helpers.
prediction_list = []
result_msg = {}
results = []
# Headless matplotlib backend: the server renders without a display.
matplotlib.use('Agg')
plt.ioff()

alg_list = []    # classical-algorithm pipelines; never populated in active code
model_list = []  # Model instances, filled by init_data_for_inference()
ROOT_ENV = load_root_env(__file__)
# Fail fast if any required .env key is missing or has the wrong type.
validate_env("NN_server/server.py", {
    "GENERAL_SERVER_IP": as_str,
    "GENERAL_SERVER_PORT": as_int,
    "SERVER_IP": as_str,
    "SERVER_PORT": as_int,
    "PATH_TO_NN": as_str,
    "SRC_RESULT": as_str,
    "SRC_EXAMPLE": as_str,
})
config = dict(dotenv_values(ROOT_ENV))

if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
if not any(key.startswith("NN_") for key in config):
    raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")
logging.info("NN config loaded from %s", ROOT_ENV)
# Destination for inference verdicts (see receive_data()).
gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']
|
|
|
|
def _reset_dir(path):
    """Remove *path* (if it exists) and recreate it as an empty directory."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    # makedirs instead of mkdir: also creates missing parent directories.
    os.makedirs(path, exist_ok=True)


def _build_model(raw_entry):
    """Build one Model from a ' && '-separated NN_* config entry.

    Field layout (positional): 0 model file, 1 config file, 2 example dir,
    3 result dir, 4 module name under Models/, 5-8 builder/pre/inference/post
    function names inside that module, 9 bracketed class list like
    "[drone,wifi]", 10-11 synthetic-example counts, 12 source dataset path.
    """
    params = raw_entry.split(' && ')
    module = importlib.import_module('Models.' + params[4])
    # params[9] looks like "[a,b,c]"; index classes in declaration order.
    classes = {}
    for value in params[9][1:-1].split(','):
        classes[len(classes)] = value
    return Model(file_model=params[0], file_config=params[1],
                 src_example=params[2], src_result=params[3],
                 type_model=params[4],
                 build_model_func=getattr(module, params[5]),
                 pre_func=getattr(module, params[6]),
                 inference_func=getattr(module, params[7]),
                 post_func=getattr(module, params[8]),
                 classes=classes,
                 number_synthetic_examples=int(params[10]),
                 number_src_data_for_one_synthetic_example=int(params[11]),
                 path_to_src_dataset=params[12])


def init_data_for_inference():
    """Reset the working directories and load all configured models.

    Recreates SRC_RESULT and SRC_EXAMPLE as empty directories, then builds a
    Model for every NN_* key in the root .env and appends it to the
    module-level ``model_list``.  A failure in one entry is logged and does
    not prevent the remaining entries from loading (the original aborted the
    whole loop on the first error).
    """
    try:
        _reset_dir(config['SRC_RESULT'])
        _reset_dir(config['SRC_EXAMPLE'])
    except Exception as exc:
        logging.error("init_data_for_inference: directory reset failed: %s", exc)

    for key in config:
        if not key.startswith('NN_'):
            continue
        try:
            model_list.append(_build_model(config[key]))
        except Exception as exc:
            logging.error("init_data_for_inference: failed to load %s: %s", key, exc)
|
|
|
|
|
|
def run_example():
    """Run each loaded model's self-test inference.

    The try/except is per model so that one failing model is logged and the
    remaining models still run (the original wrapped the whole loop, so the
    first failure skipped every model after it).
    """
    for model in model_list:
        try:
            model.get_test_inference()
        except Exception as exc:
            logging.error("run_example: %s failed: %s", model, exc)
|
|
|
|
|
|
@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Accept one IQ data packet, run the matching model and forward the verdict.

    Expects ``request.json`` to be a JSON *string* (the client double-encodes)
    describing an object with keys ``freq``, ``data_real`` and ``data_imag``.
    The first model whose name contains the numeric frequency is run; its
    prediction is mapped to an LED amplitude (8 = alarm, 0 = quiet) and POSTed
    to the general server.  Returns the per-model predictions as JSON.

    Bug fixed: on error the original returned ``None``, which makes Flask
    raise ``TypeError`` — we now return an error payload with HTTP 500.
    """
    try:
        print()
        # request.json is itself a JSON-encoded string, hence the extra loads.
        data = json.loads(request.json)
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))

        result_msg = {}
        data_to_send = {}
        prediction_list = []
        # Models are matched by frequency: the decimal frequency must appear
        # in the model's name.  Only the first match is processed (break).
        for model in model_list:
            if str(freq) not in model.get_model_name():
                continue
            print('-' * 100)
            print(str(model))
            result_msg[str(model.get_model_name())] = {'freq': freq}
            prediction, probability = model.get_inference(
                [np.asarray(data['data_real'], dtype=np.float32),
                 np.asarray(data['data_imag'], dtype=np.float32)])
            result_msg[str(model.get_model_name())]['prediction'] = prediction
            result_msg[str(model.get_model_name())]['probability'] = str(probability)
            prediction_list.append(prediction)
            print('-' * 100)
            print()

            try:
                # Map prediction to LED amplitude: 8 = alarm, 0 = quiet.
                # 2400 MHz: drone/drone_noise, or wifi with >= 0.95 confidence;
                # 1200 MHz: drone with >= 0.95 confidence.
                result = 0
                if (int(freq) == 2400 and (prediction_list[0] in ['drone', 'drone_noise'] or (prediction_list[0] == 'wifi' and float(probability) >= 0.95))) or (int(freq) == 1200 and (prediction_list[0] in ['drone'] and float(probability) >= 0.95)):
                    result += 8
                # 915 MHz is explicitly suppressed regardless of prediction.
                if int(freq) in [915]:
                    result = 0
                data_to_send = {
                    'freq': str(freq),
                    'amplitude': result
                }
                response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send)
                if response.status_code == 200:
                    print("Данные успешно отправлены!")
                    print("Частота: " + str(freq))
                    print("Отправлено светодиодов: " + str(result))
                else:
                    print("Ошибка при отправке данных: ", response.status_code)
            except Exception as exc:
                # Forwarding failures must not break the HTTP response.
                print(str(exc))
            break

        Model.get_inc_ind_inference()
        print()
        print('#' * 100)

        # Classical algorithms: alg_list is never populated in the active
        # code, so this loop is currently a no-op kept for parity.
        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference(
                [np.asarray(data['data_real'], dtype=np.float32),
                 np.asarray(data['data_imag'], dtype=np.float32)])
            print('-' * 100)
            print()

        print()
        print('#' * 100)

        # Release the (potentially large) packet buffers eagerly.
        del data
        gc.collect()

        return jsonify(result_msg)
    except Exception as exc:
        print(str(exc))
        return jsonify({'error': str(exc)}), 500
|
|
|
|
|
|
'''
|
|
def run_flask():
|
|
app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
|
|
|
|
|
|
async def process_tasks():
|
|
workers = [asyncio.create_task(worker(queue=queue, semaphore=semaphore)) for _ in range(2)]
|
|
await asyncio.gather(*workers)
|
|
|
|
|
|
async def main():
|
|
asyncio.create_task(process_tasks())
|
|
|
|
flask_thread = threading.Thread(target=run_flask)
|
|
flask_thread.start()
|
|
|
|
while True:
|
|
if queue.qsize() <= 1:
|
|
asyncio.create_task(process_tasks())
|
|
await asyncio.sleep(1)
|
|
|
|
|
|
@app.route('/receive_data', methods=['POST'])
|
|
def add_task():
|
|
queue_size = queue.qsize()
|
|
if queue_size > 1:
|
|
return {}
|
|
|
|
print()
|
|
data = json.loads(request.json)
|
|
print('#' * 100)
|
|
print('Получен пакет ' + str(Model.get_ind_inference()))
|
|
freq = int(data['freq'])
|
|
print('Частота ' + str(freq))
|
|
|
|
|
|
result_msg = {}
|
|
for model in model_list:
|
|
if str(freq) in model.get_model_name():
|
|
print('-' * 100)
|
|
print(str(model))
|
|
result_msg[str(model.get_model_name())] = {'freq': freq}
|
|
asyncio.run_coroutine_threadsafe(queue.put({'freq': freq, 'model': model, 'data': data}), loop)
|
|
do_inference(model=model, data=data, freq=freq)
|
|
break
|
|
|
|
del data
|
|
gc.collect()
|
|
return jsonify(result_msg)
|
|
|
|
|
|
async def worker(queue, semaphore):
|
|
while True:
|
|
task = await queue.get()
|
|
if task is None:
|
|
break
|
|
async with semaphore:
|
|
try:
|
|
await do_inference(model=task['model'], data=task['data'], freq=task['freq'])
|
|
except Exception as e:
|
|
print(str(e))
|
|
print(results)
|
|
queue.task_done()
|
|
|
|
|
|
async def do_inference(model=None, data=None, freq=0):
|
|
prediction_list = []
|
|
print("Длина очереди" + str(queue.qsize()))
|
|
inference(model=model, data=data, freq=freq)
|
|
|
|
try:
|
|
results = []
|
|
for pred in prediction_list:
|
|
if pred[1] == 'drone':
|
|
results.append([pred[0],8])
|
|
else:
|
|
results.append([pred[0],0])
|
|
for result in results:
|
|
try:
|
|
data_to_send={
|
|
'freq': result[0],
|
|
'amplitude': result[1],
|
|
'triggered': False if result[1] < 7 else True,
|
|
'light_len': result[1]
|
|
}
|
|
response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send)
|
|
await response.text
|
|
if response.status_code == 200:
|
|
print("Данные успешно отправлены!")
|
|
print("Отправлено светодиодов: " + str(data_to_send['light_len']))
|
|
else:
|
|
print("Ошибка при отправке данных: ", response.status_code)
|
|
except Exception as exc:
|
|
print(str(exc))
|
|
except Exception as exc:
|
|
print(str(exc))
|
|
|
|
Model.get_inc_ind_inference()
|
|
print()
|
|
print('#' * 100)
|
|
|
|
del data
|
|
gc.collect()
|
|
|
|
|
|
def inference(model=None, data=None, freq=0):
|
|
prediction, probability = model.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)])
|
|
result_msg[str(model.get_model_name())]['prediction'] = prediction
|
|
result_msg[str(model.get_model_name())]['probability'] = str(probability)
|
|
queue_size = queue.qsize()
|
|
print(queue_size)
|
|
prediction_list.append([freq, prediction])
|
|
print('-' * 100)
|
|
print()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
init_data_for_inference()
|
|
#asyncio.run(main)
|
|
loop.run_until_complete(main())
|
|
'''
|
|
|
|
def run_flask():
    """Start the Flask app on the host/port from the root .env (blocking)."""
    # logging instead of a bare print, consistent with the rest of the module.
    logging.info("Starting Flask on %s:%s", config['SERVER_IP'], config['SERVER_PORT'])
    app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
|
|
|
|
|
|
if __name__ == '__main__':
    # Reset working directories and load every configured model first.
    init_data_for_inference()

    # Serve Flask from a separate (non-daemon) thread; the process stays
    # alive after the main thread falls through this block.
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()

    #app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
|