You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
317 lines
10 KiB
Python
317 lines
10 KiB
Python
from flask import Flask, request, jsonify
|
|
from dotenv import dotenv_values
|
|
from common.runtime import load_root_env, validate_env, as_int, as_str
|
|
import os
|
|
import sys
|
|
import matplotlib.pyplot as plt
|
|
from Model import Model
|
|
import numpy as np
|
|
import matplotlib
|
|
import importlib
|
|
import threading
|
|
import requests
|
|
import shutil
|
|
import json
|
|
import gc
|
|
import logging
|
|
|
|
# Put the torchsig directory first on the import path so `Models.*` modules can
# import it; skip the insert when the exact entry is already present.
TORCHSIG_PATH = "/app/torchsig"
if not any(entry == TORCHSIG_PATH for entry in sys.path):
    sys.path.insert(0, TORCHSIG_PATH)
|
|
|
|
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# Module-level result holders.
# NOTE(review): receive_data binds local variables with the same names, so these
# globals look unused inside this file.
prediction_list, result_msg, results = [], {}, []

# Render figures off-screen with the non-interactive Agg backend.
matplotlib.use('Agg')
plt.ioff()

# alg_list: extra algorithms that receive_data runs after the NN model (nothing
# visible here populates it). model_list: filled by init_data_for_inference.
alg_list, model_list = [], []

ROOT_ENV = load_root_env(__file__)
validate_env(
    "NN_server/server.py",
    {
        "GENERAL_SERVER_IP": as_str,
        "GENERAL_SERVER_PORT": as_int,
        "SERVER_IP": as_str,
        "SERVER_PORT": as_int,
        "SRC_RESULT": as_str,
        "SRC_EXAMPLE": as_str,
        "FREQS": as_str,
    },
)

# Raw key -> value mapping parsed from the root .env file (values are str or None).
config = dict(dotenv_values(ROOT_ENV))
|
|
|
|
|
|
def get_required_drone_streak(freq):
    """Return the raw ``DRONE_STREAK_<freq>`` setting (``"1"`` when the key is absent).

    The value is returned exactly as stored in ``config``, normally a string.
    Callers that compare it numerically must convert it first.
    """
    streak_key = f"DRONE_STREAK_{freq}"
    return config.get(streak_key, "1")
|
|
|
|
|
|
def get_required_drone_prob(freq):
    """Return the raw drone-probability threshold configured for *freq*.

    Lookup order:

    1. ``DRONE_PROB_THRESHOLD_<freq>``.
    2. ``DRONE_PROB_THRESHOLD_DEFAULT``, when the per-frequency key is absent.
    3. ``"0"``.

    The value is not converted, so callers must parse it before comparing
    numerically.
    """
    fallback = config.get("DRONE_PROB_THRESHOLD_DEFAULT", "0")
    return config.get(f"DRONE_PROB_THRESHOLD_{freq}", fallback)
|
|
|
|
|
|
def parse_config_number(raw, cast, default):
    """Convert a config value with *cast* (``int`` or ``float``), or return *default*.

    Values read from ``config`` are strings (or ``None``), so they must be
    converted before any numeric comparison. ``None`` and blank values return
    *default*. Unparsable values log a warning and also return *default*,
    instead of raising inside the request path.
    """
    if raw is None:
        return default
    text = str(raw).strip()
    if not text:
        return default
    try:
        return cast(text)
    except (TypeError, ValueError):
        logging.warning("Invalid numeric config value %r; using default %r", raw, default)
        return default


def update_drone_streak(freq, prediction, drone_probability):
    """Advance or reset the consecutive-"drone" streak for *freq* and decide on an alarm.

    A prediction passes the gate when it is ``"drone"`` and *drone_probability*
    is at least the threshold from ``get_required_drone_prob`` (default 0).
    Passing increments ``drone_streaks[freq]``; anything else resets it to 0.
    The alarm triggers when the gate passes and the streak has reached
    ``get_required_drone_streak`` (default 1).

    Args:
        freq: frequency key (receive_data passes ``int(data['freq'])``).
        prediction: class label returned by the model.
        drone_probability: probability attributed to the drone class; ``None``
            is treated as 0.

    Returns:
        int: 8 when the alarm triggers, otherwise 0. receive_data posts this
        value as ``amplitude`` to the general server.
    """
    # The getters return raw config strings. Comparing those directly with a
    # float/int raised TypeError on every "drone" prediction, so convert first.
    required_prob = parse_config_number(get_required_drone_prob(freq), float, 0.0)
    drone_probability = 0.0 if drone_probability is None else float(drone_probability)
    passes_prob_gate = prediction == "drone" and drone_probability >= required_prob

    if passes_prob_gate:
        drone_streaks[freq] = drone_streaks.get(freq, 0) + 1
    else:
        drone_streaks[freq] = 0

    required = parse_config_number(get_required_drone_streak(freq), int, 1)
    triggered = passes_prob_gate and drone_streaks[freq] >= required
    logging.info(
        "NN alarm gate freq=%s prediction=%s drone_probability=%.3f threshold=%.3f streak=%s/%s triggered=%s",
        freq,
        prediction,
        drone_probability,
        required_prob,
        drone_streaks[freq],
        required,
        triggered,
    )
    return 8 if triggered else 0
|
|
|
|
|
|
def parse_freqs(raw_value):
    """Parse a comma-separated frequency list such as ``"2400, 5800"`` into ints.

    Blank entries are skipped and ``None`` is treated as an empty string.
    Order and duplicates are kept as written.

    Raises:
        ValueError: if a non-blank entry is not an integer literal.
        RuntimeError: if no frequencies remain after parsing.
    """
    tokens = (token.strip() for token in (raw_value or "").split(','))
    freqs = [int(token) for token in tokens if token]
    if freqs:
        return freqs
    raise RuntimeError("[NN_server/server.py] no NN frequencies configured in FREQS")
|
|
|
|
|
|
def _strip_class_token(token):
    """Strip whitespace and one pair of matching surrounding quotes from a class name."""
    token = token.strip()
    if len(token) >= 2 and token[0] == token[-1] and token[0] in ('"', "'"):
        token = token[1:-1].strip()
    return token


def parse_classes(raw_value):
    """Parse an ``NN_CLASSES_<freq>`` value into an index -> class-name mapping.

    Accepts a comma-separated list, optionally wrapped in ``[...]``, for example
    ``drone, noise`` or ``[drone, noise]``. Items may also be quoted, as in
    ``["drone", 'noise']``: a .env loader keeps such inner quotes. Without
    stripping them, a name like ``'"drone"'`` would never compare equal to
    ``"drone"`` in the alarm logic. Empty items are skipped and indices are
    assigned in order starting at 0.

    Raises:
        RuntimeError: if *raw_value* is ``None`` or contains no class names.
    """
    if raw_value is None:
        raise RuntimeError("[NN_server/server.py] model classes are missing")

    value = raw_value.strip()
    if value.startswith('[') and value.endswith(']'):
        value = value[1:-1]

    names = [name for name in map(_strip_class_token, value.split(',')) if name]
    if not names:
        raise RuntimeError("[NN_server/server.py] no classes parsed from NN_CLASSES_*")
    return dict(enumerate(names))
|
|
|
|
|
|
def get_required_config(key):
    """Return ``config[key]`` converted to ``str`` and stripped of whitespace.

    Raises:
        RuntimeError: if *key* is absent or maps to ``None``, or if the
            stripped value is empty.
    """
    raw = config.get(key)
    if raw is None:
        raise RuntimeError(f"[NN_server/server.py] missing required env key: {key}")
    cleaned = str(raw).strip()
    if cleaned:
        return cleaned
    raise RuntimeError(f"[NN_server/server.py] empty required env key: {key}")
|
|
|
|
|
|
def get_optional_config(key, default=''):
    """Return the stripped string value of *key* from ``config``, or *default*.

    *default* is returned unchanged (not stripped) only when the key is absent
    or maps to ``None``. A key that is present but blank yields ``''``.
    """
    raw = config.get(key)
    return default if raw is None else str(raw).strip()
|
|
|
|
|
|
def build_model_specs():
    """Assemble one model-specification dict per configured NN frequency.

    The frequency list comes from ``NN_FREQS``, or ``FREQS`` when that key is
    absent. For each frequency ``f``:

    - ``NN_MODEL_f``, ``NN_WEIGHTS_f`` and ``NN_CLASSES_f`` are mandatory.
    - ``NN_CONFIG_f`` falls back to ``NN_CONFIG``, then ``''``.

    The pipeline function names, the example/result directories and the
    synthetic-data settings are read once and shared by every spec.

    Returns:
        list[dict]: one spec per frequency, in configuration order.
    """
    # Shared settings are read in this order: pipeline names, directories,
    # synthetic-data options.
    pipeline = {
        'build_func_name': get_optional_config('NN_BUILD_FUNC', 'build_func_ensemble'),
        'pre_func_name': get_optional_config('NN_PRE_FUNC', 'pre_func_ensemble'),
        'inference_func_name': get_optional_config('NN_INFERENCE_FUNC', 'inference_func_ensemble'),
        'post_func_name': get_optional_config('NN_POST_FUNC', 'post_func_ensemble'),
    }
    directories = {
        'src_example': get_optional_config('NN_SRC_EXAMPLE', config['SRC_EXAMPLE']),
        'src_result': get_optional_config('NN_SRC_RESULT', config['SRC_RESULT']),
    }
    synthetic = {
        'synthetic_examples': int(get_optional_config('NN_SYNTHETIC_EXAMPLES', '0')),
        'synthetic_mix_count': int(get_optional_config('NN_SYNTHETIC_MIX_COUNT', '1')),
        'src_dataset': get_optional_config('NN_SRC_DATASET', ''),
    }
    fallback_file_config = get_optional_config('NN_CONFIG', '')

    def describe(freq):
        # Per-frequency fields first, then the shared directory, pipeline and
        # synthetic-data entries.
        return {
            'freq': freq,
            'module_name': get_required_config(f'NN_MODEL_{freq}'),
            'weights': get_required_config(f'NN_WEIGHTS_{freq}'),
            'config': get_optional_config(f'NN_CONFIG_{freq}', fallback_file_config),
            'classes': parse_classes(get_required_config(f'NN_CLASSES_{freq}')),
            **directories,
            **pipeline,
            **synthetic,
        }

    freq_source = config.get('NN_FREQS', config.get('FREQS', ''))
    return [describe(freq) for freq in parse_freqs(freq_source)]
|
|
|
|
|
|
# Abort the import when no keys could be parsed from ROOT_ENV.
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
logging.info("NN config loaded from %s", ROOT_ENV)

# Address of the general server that receives alarm updates (POST /process_data).
gen_server_ip, gen_server_port = config['GENERAL_SERVER_IP'], config['GENERAL_SERVER_PORT']

# freq -> count of consecutive gated "drone" predictions (see update_drone_streak).
drone_streaks = {}

# One spec dict per configured NN frequency, built once at import time.
MODEL_SPECS = build_model_specs()
|
|
|
|
|
|
def recreate_directory(path):
    """Make *path* an empty directory.

    An existing directory at *path* is deleted together with everything
    inside it; missing parent directories are created.
    """
    already_present = os.path.isdir(path)
    if already_present:
        shutil.rmtree(path)  # removes the whole tree rooted at path
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def _build_model(spec):
    """Import ``Models.<module_name>`` and construct the ``Model`` described by *spec*."""
    module = importlib.import_module('Models.' + spec['module_name'])
    return Model(
        freq=spec['freq'],
        file_model=spec['weights'],
        file_config=spec['config'],
        src_example=spec['src_example'],
        src_result=spec['src_result'],
        type_model=f"{spec['module_name']}@{spec['freq']}",
        build_model_func=getattr(module, spec['build_func_name']),
        pre_func=getattr(module, spec['pre_func_name']),
        inference_func=getattr(module, spec['inference_func_name']),
        post_func=getattr(module, spec['post_func_name']),
        classes=spec['classes'],
        number_synthetic_examples=spec['synthetic_examples'],
        number_src_data_for_one_synthetic_example=spec['synthetic_mix_count'],
        path_to_src_dataset=spec['src_dataset'],
    )


def init_data_for_inference():
    """Reset the working directories and (re)load one model per entry in MODEL_SPECS.

    Start-up is best-effort: errors are logged with a traceback instead of
    propagating. Each directory and each model is handled independently, so a
    failure while importing or constructing the model for one frequency does
    not prevent the remaining frequencies from loading.
    """
    if MODEL_SPECS:
        # build_model_specs gives every spec the same directories, so the
        # first spec's values cover all of them.
        first_spec = MODEL_SPECS[0]
        for directory in (first_spec['src_result'], first_spec['src_example']):
            try:
                recreate_directory(directory)
            except Exception:
                logging.exception("Failed to recreate directory %r", directory)

    # Mutate in place so every reference to model_list sees the reloaded models.
    model_list.clear()
    for spec in MODEL_SPECS:
        try:
            model_list.append(_build_model(spec))
        except Exception:
            logging.exception(
                "Failed to load NN model %s for freq=%s", spec['module_name'], spec['freq'],
            )

    if not model_list:
        logging.warning("No NN models were loaded")
|
|
|
|
|
|
def run_example():
    """Run ``get_test_inference()`` on every loaded model.

    Each model is tried independently, so one failure does not skip the
    models after it. Failures are logged with a traceback and not re-raised.
    """
    for model in model_list:
        try:
            model.get_test_inference()
        except Exception:
            logging.exception("Test inference failed for %s", model)
|
|
|
|
|
|
def find_model_for_freq(freq):
    """Return the first loaded model whose ``get_freq()`` equals *freq*, or ``None``."""
    matches = (candidate for candidate in model_list if candidate.get_freq() == freq)
    return next(matches, None)
|
|
|
|
|
|
# Seconds to wait for the general server before abandoning an alarm update.
# Without a timeout, requests.post can block the request thread indefinitely.
GENERAL_SERVER_TIMEOUT_S = 10


def _load_request_payload():
    """Return the decoded JSON payload of the current request.

    A sender may post the payload double-encoded (a JSON string whose content
    is JSON); that string is decoded a second time. An already-decoded JSON
    object is used as-is.
    """
    payload = request.get_json()
    if isinstance(payload, str):
        payload = json.loads(payload)
    return payload


def _iq_arrays(data):
    """Convert the payload's ``data_real`` and ``data_imag`` to float32 arrays, in that order."""
    return [
        np.asarray(data['data_real'], dtype=np.float32),
        np.asarray(data['data_imag'], dtype=np.float32),
    ]


def _report_alarm(freq, prediction, drone_probability):
    """Update the drone streak for *freq* and POST the resulting amplitude.

    Sends ``{'freq': str(freq), 'amplitude': <value from update_drone_streak>}``
    to the general server's ``/process_data`` endpoint. Best-effort: any
    failure is logged and does not change the calling request's response.
    """
    try:
        amplitude = update_drone_streak(freq, prediction, drone_probability)
        response = requests.post(
            "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
            json={'freq': str(freq), 'amplitude': amplitude},
            timeout=GENERAL_SERVER_TIMEOUT_S,
        )
        if response.status_code == 200:
            print("Данные успешно отправлены!")  # "Data sent successfully!"
            print("Частота: " + str(freq))  # "Frequency: ..."
            print("Отправлено светодиодов: " + str(amplitude))  # "LEDs sent: ..."
        else:
            print("Ошибка при отправке данных: ", response.status_code)  # "Error sending data"
    except Exception:
        logging.exception("Failed to report alarm state for freq=%s", freq)


@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Run the NN model for the posted frequency and forward the alarm state.

    Expects a payload with ``freq``, ``data_real`` and ``data_imag``. The model
    configured for ``freq`` is run on the samples, the drone-streak gate turns
    its prediction into an amplitude that is posted to the general server,
    and any algorithms in ``alg_list`` are then run on the same samples.

    Returns:
        On success, JSON of the form ``{<model name>: {freq, prediction,
        probability, drone_probability, drone_threshold}}``. On failure,
        ``({'error': message}, 500)``.
    """
    try:
        print()
        data = _load_request_payload()
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))  # "Packet received N"
        freq = int(data['freq'])
        print('Частота: ' + str(freq))  # "Frequency: ..."

        model = find_model_for_freq(freq)
        if model is None:
            raise RuntimeError(f"No NN model configured for freq={freq}")

        print('-' * 100)
        print(str(model))
        model_name = str(model.get_model_name())
        inference_result = model.get_inference(_iq_arrays(data))
        if inference_result is None:
            raise RuntimeError(f"Inference failed for {model.get_model_name()}")

        # The first two items are the predicted label and its probability
        # (presumably the probability of that predicted class; confirm in Model).
        prediction, probability = inference_result[:2]
        drone_probability = float(probability) if prediction == "drone" else 0.0
        result_msg = {
            model_name: {
                'freq': freq,
                'prediction': prediction,
                'probability': str(probability),
                'drone_probability': str(drone_probability),
                'drone_threshold': str(get_required_drone_prob(freq)),
            },
        }
        print('-' * 100)
        print()

        _report_alarm(freq, prediction, drone_probability)

        Model.get_inc_ind_inference()
        print()
        print('#' * 100)

        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference(_iq_arrays(data))
            print('-' * 100)
            print()

        print()
        print('#' * 100)

        # Drop the sample payload before forcing a collection.
        del data
        gc.collect()
        return jsonify(result_msg)
    except Exception as exc:
        # Returning None from a Flask view is itself an error, so answer with
        # an explicit JSON error response instead.
        logging.exception("receive_data failed")
        return jsonify({'error': str(exc)}), 500
|
|
|
|
|
|
def run_flask():
    """Print the bind host, then serve ``app`` on ``SERVER_IP``:``SERVER_PORT`` (blocks)."""
    bind_host = config['SERVER_IP']
    print(bind_host)
    app.run(host=bind_host, port=int(config['SERVER_PORT']))
|
|
|
|
|
|
if __name__ == '__main__':
    # Load every configured model before the HTTP server starts accepting requests.
    init_data_for_inference()

    # Serve from a separate thread. It is created by the main thread and so is
    # non-daemon; the interpreter keeps running until the server stops.
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()
|