|
|
|
|
@ -3,6 +3,7 @@ from dotenv import dotenv_values
|
|
|
|
|
from common.runtime import load_root_env, validate_env, as_int, as_str
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import re
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
from Model import Model
|
|
|
|
|
import numpy as np
|
|
|
|
|
@ -10,19 +11,25 @@ import matplotlib
|
|
|
|
|
import importlib
|
|
|
|
|
import threading
|
|
|
|
|
import requests
|
|
|
|
|
import asyncio
|
|
|
|
|
import shutil
|
|
|
|
|
import json
|
|
|
|
|
import gc
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
TORCHSIG_PATH = "/app/torchsig"

# Make `import torchsig` resolve to the vendored copy under /app/torchsig
# instead of any site-packages installation.
if TORCHSIG_PATH not in sys.path:
    sys.path.insert(0, TORCHSIG_PATH)
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# Asyncio machinery for the (legacy, currently disabled) queued-inference path.
# NOTE(review): `queue` shadows the stdlib module name — kept because other
# code in this file refers to it.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()
semaphore = asyncio.Semaphore(3)

# Mutable module-level state shared by the request handlers.
prediction_list = []
result_msg = {}
results = []
|
|
|
|
|
@ -37,39 +44,38 @@ validate_env("NN_server/server.py", {
|
|
|
|
|
"GENERAL_SERVER_PORT": as_int,
|
|
|
|
|
"SERVER_IP": as_str,
|
|
|
|
|
"SERVER_PORT": as_int,
|
|
|
|
|
"PATH_TO_NN": as_str,
|
|
|
|
|
"SRC_RESULT": as_str,
|
|
|
|
|
"SRC_EXAMPLE": as_str,
|
|
|
|
|
"FREQS": as_str,
|
|
|
|
|
})
|
|
|
|
|
config = dict(dotenv_values(ROOT_ENV))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_required_drone_streak(freq):
    """Return how many consecutive "drone" predictions are required for *freq*.

    Fix: the original returned the raw env string (e.g. "1"), which breaks the
    `streak >= required` comparison downstream; coerce to int with a safe
    fallback of 1.

    NOTE(review): a second, later definition of this function exists below and
    shadows this one at import time — this looks like unresolved merge residue.
    """
    raw_value = config.get(f"DRONE_STREAK_{freq}", "1")
    try:
        return max(1, int(raw_value))
    except (TypeError, ValueError):
        return 1
|
|
|
|
|
|
|
|
|
|
def is_model_config_key(key, value):
    """Return True when (key, value) looks like a legacy NN_<n> model entry.

    A model entry has a key of the exact form ``NN_<digits>`` and a string
    value containing the ``" && "`` field separator.
    """
    if not isinstance(value, str) or " && " not in value:
        return False
    return bool(re.fullmatch(r"NN_\d+", key or ""))
|
|
|
|
|
|
|
|
|
|
def get_required_drone_prob(freq):
    """Return the drone-probability alarm threshold for *freq* as a float.

    Lookup order: DRONE_PROB_THRESHOLD_<freq> -> DRONE_PROB_THRESHOLD_DEFAULT
    -> 0.0.

    Fix: the original returned the raw env *string*, which is later compared
    with ``drone_probability >= threshold`` against a float (TypeError in
    Python 3). Malformed values log a warning and fall back to 0.0.
    """
    raw = config.get(f"DRONE_PROB_THRESHOLD_{freq}", config.get("DRONE_PROB_THRESHOLD_DEFAULT", "0"))
    try:
        return float(raw)
    except (TypeError, ValueError):
        logging.warning("Invalid DRONE_PROB_THRESHOLD_%s=%r, falling back to 0.0", freq, raw)
        return 0.0
|
|
|
|
|
|
|
|
|
|
def get_required_drone_streak(freq):
    """Return the consecutive-"drone" count required before the alarm fires.

    Reads DRONE_STREAK_<freq> from config (default "1"); malformed values log
    a warning and fall back to 1. The result is always an int >= 1.
    """
    configured = config.get(f"DRONE_STREAK_{freq}", "1")
    try:
        streak = int(configured)
    except (TypeError, ValueError):
        logging.warning("Invalid DRONE_STREAK_%s=%r, falling back to 1", freq, configured)
        return 1
    return max(1, streak)
|
|
|
|
|
|
|
|
|
|
def update_drone_streak(freq, prediction, drone_probability):
    """Update the per-frequency consecutive-"drone" counter and gate the alarm.

    NOTE(review): this block contained an unresolved merge of two versions
    (with and without the probability gate, including a nested duplicate
    ``def`` and duplicated ``triggered =`` / log-format lines); resolved in
    favour of the probability-gated variant.

    A prediction only counts toward the streak when it is "drone" AND its
    probability clears the per-frequency threshold. Any other outcome resets
    the streak to zero.

    Returns:
        int: 8 (LED amplitude for an alarm) when the gate triggers, else 0.
    """
    # float(...) guards against the threshold helper returning an env string.
    required_prob = float(get_required_drone_prob(freq))
    drone_probability = 0.0 if drone_probability is None else float(drone_probability)
    passes_prob_gate = prediction == "drone" and drone_probability >= required_prob

    if passes_prob_gate:
        drone_streaks[freq] = drone_streaks.get(freq, 0) + 1
    else:
        drone_streaks[freq] = 0

    required = get_required_drone_streak(freq)
    triggered = passes_prob_gate and drone_streaks[freq] >= required
    logging.info(
        "NN alarm gate freq=%s prediction=%s drone_probability=%.3f threshold=%.3f streak=%s/%s triggered=%s",
        freq,
        prediction,
        drone_probability,
        required_prob,
        drone_streaks[freq],
        required,
        triggered,
    )
    return 8 if triggered else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_freqs(raw_value):
    """Parse a comma-separated frequency string into a list of ints.

    Empty items are skipped; a None/empty input (or one with no usable items)
    raises RuntimeError, since the server cannot run without frequencies.
    """
    tokens = (token.strip() for token in (raw_value or "").split(','))
    parsed = [int(token) for token in tokens if token]
    if not parsed:
        raise RuntimeError("[NN_server/server.py] no NN frequencies configured in FREQS")
    return parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_classes(raw_value):
    """Parse an NN_CLASSES_* value into an {index: class_name} mapping.

    Accepts either a bare comma-separated list or one wrapped in brackets
    ("[a, b]"). Raises RuntimeError on a None input or when no class name
    survives parsing.
    """
    if raw_value is None:
        raise RuntimeError("[NN_server/server.py] model classes are missing")
    text = raw_value.strip()
    # Tolerate the bracketed list form used in the .env files.
    if text.startswith('[') and text.endswith(']'):
        text = text[1:-1]
    names = [part.strip() for part in text.split(',') if part.strip()]
    if not names:
        raise RuntimeError("[NN_server/server.py] no classes parsed from NN_CLASSES_*")
    return dict(enumerate(names))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_required_config(key):
    """Return a mandatory config value, stripped; raise if missing or blank."""
    raw = config.get(key)
    if raw is None:
        raise RuntimeError(f"[NN_server/server.py] missing required env key: {key}")
    stripped = str(raw).strip()
    if not stripped:
        raise RuntimeError(f"[NN_server/server.py] empty required env key: {key}")
    return stripped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optional_config(key, default=''):
    """Return config[key] stripped, or *default* when the key is absent."""
    raw = config.get(key)
    return default if raw is None else str(raw).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_model_specs():
    """Assemble one spec dict per configured NN frequency.

    Shared settings (hook-function names, example/result dirs, synthetic-data
    knobs) come from global NN_* keys; per-frequency settings (module, weights,
    classes, optional config file) come from NN_*_<freq> keys. Missing required
    per-frequency keys raise via get_required_config.
    """
    shared = {
        'src_example': get_optional_config('NN_SRC_EXAMPLE', config['SRC_EXAMPLE']),
        'src_result': get_optional_config('NN_SRC_RESULT', config['SRC_RESULT']),
        'build_func_name': get_optional_config('NN_BUILD_FUNC', 'build_func_ensemble'),
        'pre_func_name': get_optional_config('NN_PRE_FUNC', 'pre_func_ensemble'),
        'inference_func_name': get_optional_config('NN_INFERENCE_FUNC', 'inference_func_ensemble'),
        'post_func_name': get_optional_config('NN_POST_FUNC', 'post_func_ensemble'),
        'synthetic_examples': int(get_optional_config('NN_SYNTHETIC_EXAMPLES', '0')),
        'synthetic_mix_count': int(get_optional_config('NN_SYNTHETIC_MIX_COUNT', '1')),
        'src_dataset': get_optional_config('NN_SRC_DATASET', ''),
    }

    specs = []
    # NN_FREQS overrides the generic FREQS list when present.
    for freq in parse_freqs(config.get('NN_FREQS', config.get('FREQS', ''))):
        spec = {
            'freq': freq,
            'module_name': get_required_config(f'NN_MODEL_{freq}'),
            'weights': get_required_config(f'NN_WEIGHTS_{freq}'),
            'classes': parse_classes(get_required_config(f'NN_CLASSES_{freq}')),
            'config': get_optional_config(f'NN_CONFIG_{freq}', get_optional_config('NN_CONFIG', '')),
        }
        spec.update(shared)
        specs.append(spec)
    return specs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fail fast when the .env produced nothing usable.
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
if not any(is_model_config_key(key, value) for key, value in config.items()):
    raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")
logging.info("NN config loaded from %s", ROOT_ENV)

gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']

# Per-frequency consecutive-"drone" counters (see update_drone_streak).
drone_streaks = {}

MODEL_SPECS = build_model_specs()

# Telemetry sink for per-inference events; all knobs overridable via env vars.
INFERENCE_TELEMETRY_HOST = os.getenv('telemetry_host', '127.0.0.1')
INFERENCE_TELEMETRY_PORT = os.getenv('telemetry_port', '5020')
INFERENCE_TELEMETRY_ENDPOINT = os.getenv('telemetry_inference_endpoint', 'inference/result')
INFERENCE_TELEMETRY_TIMEOUT_SEC = float(os.getenv('telemetry_inference_timeout_sec', '0.30'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def recreate_directory(path):
    """Ensure *path* exists as a brand-new, empty directory.

    Any existing directory at *path* is removed first so stale files never
    survive a restart.
    """
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_result_dir():
    """Return the directory where inference images land ('' when no models)."""
    # All specs share the same src_result, so the first one is authoritative.
    return MODEL_SPECS[0]['src_result'] if MODEL_SPECS else ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_inference_images(result_id):
    """Return sorted .png filenames in the result dir tagged with *result_id*.

    Files are matched by the "_inference_<id>_" marker embedded in their
    names; a missing/unset result directory yields an empty list.
    """
    result_dir = get_result_dir()
    if not result_dir or not os.path.isdir(result_dir):
        return []
    marker = f"_inference_{result_id}_"
    return [
        name
        for name in sorted(os.listdir(result_dir))
        if marker in name and name.endswith('.png')
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def send_inference_result(payload):
    """Best-effort POST of one inference event to the telemetry endpoint.

    Failures are printed and swallowed: telemetry must never break the
    inference path.
    """
    url = "http://{0}:{1}/{2}".format(
        INFERENCE_TELEMETRY_HOST,
        INFERENCE_TELEMETRY_PORT,
        INFERENCE_TELEMETRY_ENDPOINT.lstrip('/'),
    )
    try:
        requests.post(url, json=payload, timeout=INFERENCE_TELEMETRY_TIMEOUT_SEC)
    except Exception as exc:
        print(str(exc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_data_for_inference():
    """(Re)create the working directories and load every configured model.

    NOTE(review): this function contained an unresolved merge — the new
    MODEL_SPECS-driven loader was interleaved with the legacy "NN_<n>"
    config-string loader (and a commented-out ALG_ block). Resolved in favour
    of the MODEL_SPECS path, which carries the same parameters in named form.

    Side effects: wipes/recreates the example and result directories and
    repopulates the module-level ``model_list``. Both phases are best-effort:
    failures are printed, matching the original error-handling style.
    """
    # Phase 1: wipe the shared example/result directories so stale artefacts
    # never leak between runs. All specs share these paths (see
    # build_model_specs), so the first spec is authoritative.
    try:
        if MODEL_SPECS:
            recreate_directory(MODEL_SPECS[0]['src_result'])
            recreate_directory(MODEL_SPECS[0]['src_example'])
    except Exception as exc:
        print(str(exc))
        print()

    # Phase 2: import each model module and wrap it in a Model instance.
    try:
        global model_list
        model_list.clear()
        for spec in MODEL_SPECS:
            module = importlib.import_module('Models.' + spec['module_name'])
            model = Model(
                freq=spec['freq'],
                file_model=spec['weights'],
                file_config=spec['config'],
                src_example=spec['src_example'],
                src_result=spec['src_result'],
                type_model=f"{spec['module_name']}@{spec['freq']}",
                build_model_func=getattr(module, spec['build_func_name']),
                pre_func=getattr(module, spec['pre_func_name']),
                inference_func=getattr(module, spec['inference_func_name']),
                post_func=getattr(module, spec['post_func_name']),
                classes=spec['classes'],
                number_synthetic_examples=spec['synthetic_examples'],
                number_src_data_for_one_synthetic_example=spec['synthetic_mix_count'],
                path_to_src_dataset=spec['src_dataset'],
            )
            model_list.append(model)
    except Exception as exc:
        print(str(exc))
        print()
|
|
|
|
|
@ -255,83 +142,54 @@ def run_example():
|
|
|
|
|
print(str(exc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_model_for_freq(freq):
    """Return the loaded model registered for *freq*, or None if absent."""
    return next((m for m in model_list if m.get_freq() == freq), None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Handle one IQ packet: run the model for its frequency, publish
    telemetry, and forward the LED amplitude to the general server.

    NOTE(review): this handler contained an unresolved merge — the new
    single-model path was interleaved with the legacy loop over model_list,
    a duplicated JSON decode, and a second send-block. Resolved in favour of
    the single-model path. A return value was also missing (a Flask view must
    not return None); the handler now returns the result dict, which Flask
    serializes as JSON.
    """
    result_msg = {}
    try:
        print()
        data = request.json
        if isinstance(data, str):
            # Some senders double-encode the payload as a JSON string.
            data = json.loads(data)
        print('#' * 100)
        result_id = Model.get_ind_inference()
        print('Получен пакет ' + str(result_id))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))

        model = find_model_for_freq(freq)
        if model is None:
            raise RuntimeError(f"No NN model configured for freq={freq}")

        print('-' * 100)
        print(str(model))
        model_name = str(model.get_model_name())
        result_msg[model_name] = {'freq': freq}

        inference_result = model.get_inference([
            np.asarray(data['data_real'], dtype=np.float32),
            np.asarray(data['data_imag'], dtype=np.float32),
        ])
        if inference_result is None:
            raise RuntimeError(f"Inference failed for {model.get_model_name()}")
        prediction, probability = inference_result[:2]
        # Only "drone" verdicts carry a meaningful drone probability.
        drone_probability = float(probability) if prediction == "drone" else 0.0
        result_msg[model_name]['prediction'] = prediction
        result_msg[model_name]['probability'] = str(probability)
        result_msg[model_name]['drone_probability'] = str(drone_probability)
        result_msg[model_name]['drone_threshold'] = str(get_required_drone_prob(freq))

        # Fire-and-forget telemetry for this inference.
        inference_event = {
            'result_id': result_id,
            'ts': time.time(),
            'freq': str(freq),
            'model': model.get_model_name(),
            'prediction': prediction,
            'probability': float(probability),
            'drone_probability': drone_probability,
            'drone_threshold': str(get_required_drone_prob(freq)),
            'images': collect_inference_images(result_id),
        }
        send_inference_result(inference_event)
        print('-' * 100)
        print()

        # Forward the alarm amplitude to the general server; its failure must
        # not fail the request.
        try:
            result = update_drone_streak(freq, prediction, drone_probability)
            data_to_send = {
                'freq': str(freq),
                'amplitude': result,
            }
            response = requests.post(
                "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
                json=data_to_send,
            )
            if response.status_code == 200:
                print("Данные успешно отправлены!")
                print("Частота: " + str(freq))
                print("Отправлено светодиодов: " + str(result))
            else:
                print("Ошибка при отправке данных: ", response.status_code)
        except Exception as exc:
            print(str(exc))

        Model.get_inc_ind_inference()
        print()
        print('#' * 100)

        del data
        gc.collect()
    except Exception as exc:
        print(str(exc))
    # Flask 1.1+ serializes a returned dict as a JSON response.
    return result_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def run_flask():
|
|
|
|
|
app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_tasks():
|
|
|
|
|
workers = [asyncio.create_task(worker(queue=queue, semaphore=semaphore)) for _ in range(2)]
|
|
|
|
|
await asyncio.gather(*workers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
|
|
|
|
|
asyncio.create_task(process_tasks())
|
|
|
|
|
|
|
|
|
|
flask_thread = threading.Thread(target=run_flask)
|
|
|
|
|
flask_thread.start()
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
if queue.qsize() <= 1:
|
|
|
|
|
asyncio.create_task(process_tasks())
|
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route('/receive_data', methods=['POST'])
|
|
|
|
|
def add_task():
|
|
|
|
|
queue_size = queue.qsize()
|
|
|
|
|
if queue_size > 1:
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
print()
|
|
|
|
|
data = json.loads(request.json)
|
|
|
|
|
print('#' * 100)
|
|
|
|
|
print('Получен пакет ' + str(Model.get_ind_inference()))
|
|
|
|
|
freq = int(data['freq'])
|
|
|
|
|
print('Частота ' + str(freq))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result_msg = {}
|
|
|
|
|
for model in model_list:
|
|
|
|
|
if str(freq) in model.get_model_name():
|
|
|
|
|
print('-' * 100)
|
|
|
|
|
print(str(model))
|
|
|
|
|
result_msg[str(model.get_model_name())] = {'freq': freq}
|
|
|
|
|
asyncio.run_coroutine_threadsafe(queue.put({'freq': freq, 'model': model, 'data': data}), loop)
|
|
|
|
|
do_inference(model=model, data=data, freq=freq)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
del data
|
|
|
|
|
gc.collect()
|
|
|
|
|
return jsonify(result_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def worker(queue, semaphore):
|
|
|
|
|
while True:
|
|
|
|
|
task = await queue.get()
|
|
|
|
|
if task is None:
|
|
|
|
|
break
|
|
|
|
|
async with semaphore:
|
|
|
|
|
try:
|
|
|
|
|
await do_inference(model=task['model'], data=task['data'], freq=task['freq'])
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(str(e))
|
|
|
|
|
print(results)
|
|
|
|
|
queue.task_done()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def do_inference(model=None, data=None, freq=0):
|
|
|
|
|
prediction_list = []
|
|
|
|
|
print("Длина очереди" + str(queue.qsize()))
|
|
|
|
|
inference(model=model, data=data, freq=freq)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
results = []
|
|
|
|
|
for pred in prediction_list:
|
|
|
|
|
if pred[1] == 'drone':
|
|
|
|
|
results.append([pred[0],8])
|
|
|
|
|
else:
|
|
|
|
|
results.append([pred[0],0])
|
|
|
|
|
for result in results:
|
|
|
|
|
try:
|
|
|
|
|
data_to_send={
|
|
|
|
|
'freq': result[0],
|
|
|
|
|
'amplitude': result[1],
|
|
|
|
|
'triggered': False if result[1] < 7 else True,
|
|
|
|
|
'light_len': result[1]
|
|
|
|
|
}
|
|
|
|
|
response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send)
|
|
|
|
|
await response.text
|
|
|
|
|
if response.status_code == 200:
|
|
|
|
|
print("Данные успешно отправлены!")
|
|
|
|
|
print("Отправлено светодиодов: " + str(data_to_send['light_len']))
|
|
|
|
|
else:
|
|
|
|
|
print("Ошибка при отправке данных: ", response.status_code)
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
print(str(exc))
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
print(str(exc))
|
|
|
|
|
|
|
|
|
|
Model.get_inc_ind_inference()
|
|
|
|
|
print()
|
|
|
|
|
print('#' * 100)
|
|
|
|
|
|
|
|
|
|
del data
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(model=None, data=None, freq=0):
|
|
|
|
|
prediction, probability = model.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)])
|
|
|
|
|
result_msg[str(model.get_model_name())]['prediction'] = prediction
|
|
|
|
|
result_msg[str(model.get_model_name())]['probability'] = str(probability)
|
|
|
|
|
queue_size = queue.qsize()
|
|
|
|
|
print(queue_size)
|
|
|
|
|
prediction_list.append([freq, prediction])
|
|
|
|
|
print('-' * 100)
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
init_data_for_inference()
|
|
|
|
|
#asyncio.run(main)
|
|
|
|
|
loop.run_until_complete(main())
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def run_flask():
    """Run the blocking Flask server on the configured bind address/port."""
    bind_ip = config['SERVER_IP']
    bind_port = int(config['SERVER_PORT'])
    print(bind_ip)
    app.run(host=bind_ip, port=bind_port)
|
|
|
|
|
@ -368,3 +346,5 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
flask_thread = threading.Thread(target=run_flask)
|
|
|
|
|
flask_thread.start()
|
|
|
|
|
|
|
|
|
|
#app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))
|
|
|
|
|
|