You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DroneDetector/NN_server/server.py

317 lines
10 KiB
Python

from flask import Flask, request, jsonify
from dotenv import dotenv_values
from common.runtime import load_root_env, validate_env, as_int, as_str
import os
import sys
import matplotlib.pyplot as plt
from Model import Model
import numpy as np
import matplotlib
import importlib
import threading
import requests
import shutil
import json
import gc
import logging
# Local torchsig checkout, put first on sys.path so later imports resolve to it.
# NOTE(review): this runs after `from Model import Model` above, so it only
# affects imports performed after this point (e.g. the importlib loading of
# Models.* in init_data_for_inference) — confirm Model itself does not need it.
TORCHSIG_PATH = "/app/torchsig"
if TORCHSIG_PATH not in sys.path:
    sys.path.insert(0, TORCHSIG_PATH)
logging.basicConfig(level=logging.INFO)
app = Flask(__name__)
# NOTE(review): prediction_list / result_msg / results are never used as
# globals in the visible code — receive_data binds its own local
# result_msg/prediction_list, and `results` is not referenced.
prediction_list = []
result_msg = {}
results = []
# Non-interactive Agg backend with interactive mode off: figures can be
# rendered to files without a display.
matplotlib.use('Agg')
plt.ioff()
# Extra algorithms run on every packet by receive_data (nothing visible here
# appends to it).
alg_list = []
# Models built by init_data_for_inference and looked up by frequency.
model_list = []
ROOT_ENV = load_root_env(__file__)
# Validate the keys this server reads directly; presumably raises if one is
# missing or cannot be converted by its as_* helper — confirm in common.runtime.
validate_env("NN_server/server.py", {
    "GENERAL_SERVER_IP": as_str,
    "GENERAL_SERVER_PORT": as_int,
    "SERVER_IP": as_str,
    "SERVER_PORT": as_int,
    "SRC_RESULT": as_str,
    "SRC_EXAMPLE": as_str,
    "FREQS": as_str,
})
# Raw values parsed from the .env file; the helpers below convert them as needed.
config = dict(dotenv_values(ROOT_ENV))
def get_required_drone_streak(freq):
    """Return how many consecutive gated "drone" packets raise the alarm.

    Reads ``DRONE_STREAK_<freq>`` from the parsed .env; an absent or blank
    value means 1.

    dotenv values are strings, so the value is converted to ``int`` here.
    update_drone_streak compares it with an integer streak counter, and
    ``int >= str`` raises TypeError.

    Args:
        freq: frequency whose streak requirement is requested.

    Returns:
        int: required number of consecutive gated detections.

    Raises:
        ValueError: if the configured value is not an integer.
    """
    raw = config.get(f"DRONE_STREAK_{freq}")
    if raw is None or not str(raw).strip():
        return 1
    try:
        return int(str(raw).strip())
    except ValueError as exc:
        raise ValueError(
            f"[NN_server/server.py] DRONE_STREAK_{freq} must be an integer, got {raw!r}"
        ) from exc
def get_required_drone_prob(freq):
    """Return the minimum drone probability that counts towards the streak.

    Looks up ``DRONE_PROB_THRESHOLD_<freq>``, then
    ``DRONE_PROB_THRESHOLD_DEFAULT``, and finally falls back to 0.0. Absent
    and blank values are skipped.

    The value is returned as ``float``. dotenv values are strings, and
    update_drone_streak both compares the threshold with a float
    probability and formats it with ``%.3f``; both operations fail on a
    ``str``.

    Args:
        freq: frequency whose threshold is requested.

    Returns:
        float: the probability threshold.

    Raises:
        ValueError: if the selected value is not a number.
    """
    for key in (f"DRONE_PROB_THRESHOLD_{freq}", "DRONE_PROB_THRESHOLD_DEFAULT"):
        raw = config.get(key)
        if raw is None or not str(raw).strip():
            continue
        try:
            return float(str(raw).strip())
        except ValueError as exc:
            raise ValueError(
                f"[NN_server/server.py] {key} must be a number, got {raw!r}"
            ) from exc
    return 0.0
def update_drone_streak(freq, prediction, drone_probability):
    """Advance the per-frequency drone streak and decide whether to alarm.

    A packet passes the gate when ``prediction == "drone"`` and its drone
    probability is at least ``get_required_drone_prob(freq)``. A passing
    packet increments ``drone_streaks[freq]``; any other packet resets it to
    0. The alarm triggers when the gate passes and the streak has reached
    ``get_required_drone_streak(freq)``.

    Args:
        freq: frequency key for the streak counter and per-frequency config.
        prediction: class label produced by the model.
        drone_probability: probability of the drone class; ``None`` counts
            as 0.0.

    Returns:
        int: 8 when the alarm triggers, otherwise 0. receive_data forwards
        this value as ``amplitude`` to the general server.
    """
    # Coerce config values explicitly: .env values are strings, and comparing
    # a float/int against a str raises TypeError.
    required_prob = float(get_required_drone_prob(freq))
    drone_probability = 0.0 if drone_probability is None else float(drone_probability)
    passes_prob_gate = prediction == "drone" and drone_probability >= required_prob
    if passes_prob_gate:
        drone_streaks[freq] = drone_streaks.get(freq, 0) + 1
    else:
        drone_streaks[freq] = 0
    required = int(get_required_drone_streak(freq))
    triggered = passes_prob_gate and drone_streaks[freq] >= required
    logging.info(
        "NN alarm gate freq=%s prediction=%s drone_probability=%.3f threshold=%.3f streak=%s/%s triggered=%s",
        freq,
        prediction,
        drone_probability,
        required_prob,
        drone_streaks[freq],
        required,
        triggered,
    )
    return 8 if triggered else 0
def parse_freqs(raw_value):
    """Parse a comma-separated list of integer frequencies.

    Entries are whitespace-stripped and blank entries (e.g. from
    ``"2400,,5800,"``) are skipped. ``None`` is treated like an empty string.

    Raises:
        ValueError: if a non-blank entry is not an integer.
        RuntimeError: if no frequency remains after parsing.
    """
    tokens = (piece.strip() for piece in (raw_value or "").split(','))
    parsed = [int(token) for token in tokens if token]
    if parsed:
        return parsed
    raise RuntimeError("[NN_server/server.py] no NN frequencies configured in FREQS")
def parse_classes(raw_value):
    """Turn a class-list string into an ``{index: class_name}`` mapping.

    Accepts ``"a, b, c"`` or the bracketed form ``"[a, b, c]"``. Names are
    whitespace-stripped and blank entries dropped. Indices are assigned
    0, 1, 2, ... in order of appearance.

    Raises:
        RuntimeError: if ``raw_value`` is None or yields no class names.
    """
    if raw_value is None:
        raise RuntimeError("[NN_server/server.py] model classes are missing")
    body = raw_value.strip()
    is_bracketed = body.startswith('[') and body.endswith(']')
    if is_bracketed:
        body = body[1:-1]
    names = [name for name in (part.strip() for part in body.split(',')) if name]
    if not names:
        raise RuntimeError("[NN_server/server.py] no classes parsed from NN_CLASSES_*")
    return dict(enumerate(names))
def get_required_config(key):
    """Return the whitespace-stripped value of a mandatory .env key.

    Non-string values are converted with ``str`` before stripping.

    Raises:
        RuntimeError: if ``key`` is absent or maps to ``None``, or if its
            value is blank after stripping.
    """
    raw = config.get(key)
    if raw is None:
        raise RuntimeError(f"[NN_server/server.py] missing required env key: {key}")
    cleaned = str(raw).strip()
    if cleaned:
        return cleaned
    raise RuntimeError(f"[NN_server/server.py] empty required env key: {key}")
def get_optional_config(key, default=''):
    """Return the stripped string value of ``key``, or ``default`` when absent.

    ``default`` is returned exactly as given (neither stripped nor
    converted). A key that is present but blank yields ``''``, not
    ``default``.
    """
    raw = config.get(key)
    return default if raw is None else str(raw).strip()
def build_model_specs():
    """Assemble one model-spec dict per configured NN frequency.

    The following are read once and shared by every spec:
    - hook-function names (``NN_BUILD_FUNC``, ``NN_PRE_FUNC``,
      ``NN_INFERENCE_FUNC``, ``NN_POST_FUNC``);
    - example/result directories (``NN_SRC_*``, falling back to
      ``SRC_EXAMPLE`` / ``SRC_RESULT``);
    - synthetic-data counts and the dataset path.

    For every frequency, ``NN_MODEL_<freq>``, ``NN_WEIGHTS_<freq>`` and
    ``NN_CLASSES_<freq>`` are required. ``NN_CONFIG_<freq>`` falls back to
    ``NN_CONFIG``, then to ``''``.

    Returns:
        list[dict]: one spec per frequency, in the order the frequencies
        appear in ``NN_FREQS`` (or ``FREQS`` when ``NN_FREQS`` is absent).

    Raises:
        RuntimeError: if no frequency is configured or a required key is
            missing/empty.
        ValueError: if a frequency or synthetic-data count is not an integer.
    """
    build_func = get_optional_config('NN_BUILD_FUNC', 'build_func_ensemble')
    pre_func = get_optional_config('NN_PRE_FUNC', 'pre_func_ensemble')
    inference_func = get_optional_config('NN_INFERENCE_FUNC', 'inference_func_ensemble')
    post_func = get_optional_config('NN_POST_FUNC', 'post_func_ensemble')
    example_dir = get_optional_config('NN_SRC_EXAMPLE', config['SRC_EXAMPLE'])
    result_dir = get_optional_config('NN_SRC_RESULT', config['SRC_RESULT'])
    n_synthetic = int(get_optional_config('NN_SYNTHETIC_EXAMPLES', '0'))
    mix_count = int(get_optional_config('NN_SYNTHETIC_MIX_COUNT', '1'))
    dataset_dir = get_optional_config('NN_SRC_DATASET', '')
    shared_file_config = get_optional_config('NN_CONFIG', '')

    freq_list = parse_freqs(config.get('NN_FREQS', config.get('FREQS', '')))
    specs = []
    for freq in freq_list:
        module_name = get_required_config(f'NN_MODEL_{freq}')
        weights = get_required_config(f'NN_WEIGHTS_{freq}')
        classes = parse_classes(get_required_config(f'NN_CLASSES_{freq}'))
        specs.append(dict(
            freq=freq,
            module_name=module_name,
            weights=weights,
            config=get_optional_config(f'NN_CONFIG_{freq}', shared_file_config),
            classes=classes,
            src_example=example_dir,
            src_result=result_dir,
            build_func_name=build_func,
            pre_func_name=pre_func,
            inference_func_name=inference_func,
            post_func_name=post_func,
            synthetic_examples=n_synthetic,
            synthetic_mix_count=mix_count,
            src_dataset=dataset_dir,
        ))
    return specs
# An empty config means the .env file yielded no keys; abort rather than run.
if not config:
    raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
logging.info("NN config loaded from %s", ROOT_ENV)
# Address of the general server that receive_data POSTs alarm updates to
# (/process_data).
gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT']
# Per-frequency count of consecutive gated "drone" packets, maintained by
# update_drone_streak.
drone_streaks = {}
# Built at import time: a missing/empty required NN_* key raises here and
# aborts startup.
MODEL_SPECS = build_model_specs()
def recreate_directory(path):
    """Ensure ``path`` exists as a fresh, empty directory.

    An existing directory at ``path`` is deleted recursively first. Missing
    parent directories are created as needed.
    """
    stale = os.path.isdir(path)
    if stale:
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
def _load_model(spec):
    """Import ``Models.<module_name>`` and build a Model from one spec dict."""
    module = importlib.import_module('Models.' + spec['module_name'])
    return Model(
        freq=spec['freq'],
        file_model=spec['weights'],
        file_config=spec['config'],
        src_example=spec['src_example'],
        src_result=spec['src_result'],
        type_model=f"{spec['module_name']}@{spec['freq']}",
        build_model_func=getattr(module, spec['build_func_name']),
        pre_func=getattr(module, spec['pre_func_name']),
        inference_func=getattr(module, spec['inference_func_name']),
        post_func=getattr(module, spec['post_func_name']),
        classes=spec['classes'],
        number_synthetic_examples=spec['synthetic_examples'],
        number_src_data_for_one_synthetic_example=spec['synthetic_mix_count'],
        path_to_src_dataset=spec['src_dataset'],
    )


def init_data_for_inference():
    """Reset the example/result directories and (re)load every configured model.

    Both steps are best effort: errors are reported and the server keeps
    running with whichever models loaded. Each spec is loaded independently,
    so one broken model module no longer prevents later models from loading.
    """
    try:
        if MODEL_SPECS:
            # build_model_specs gives every spec the same src_result and
            # src_example, so resetting the first spec's paths covers all.
            recreate_directory(MODEL_SPECS[0]['src_result'])
            recreate_directory(MODEL_SPECS[0]['src_example'])
    except Exception as exc:
        print(str(exc))
        print()
    # Clear in place (no rebinding) so existing references to model_list
    # stay valid.
    model_list.clear()
    for spec in MODEL_SPECS:
        try:
            model_list.append(_load_model(spec))
        except Exception:
            logging.exception(
                "Failed to load NN model %s@%s",
                spec.get('module_name'),
                spec.get('freq'),
            )
def run_example():
    """Run each loaded model's test inference, in model_list order.

    The first exception stops the remaining models; it is printed, not raised.
    """
    try:
        for loaded in model_list:
            loaded.get_test_inference()
    except Exception as err:
        print(str(err))
def find_model_for_freq(freq):
    """Return the first loaded model whose ``get_freq()`` equals ``freq``, else None."""
    return next(
        (candidate for candidate in model_list if candidate.get_freq() == freq),
        None,
    )
# Seconds to wait for the general server's /process_data endpoint. Without a
# timeout, requests.post can block this handler indefinitely if that server
# stalls.
_GENERAL_SERVER_TIMEOUT_S = 10


def _decode_packet(payload):
    """Return the packet dict carried by a /receive_data request.

    The handler expects a JSON body whose value is itself a JSON-encoded
    string, and decodes it a second time here. A body that is already a
    JSON object is returned unchanged.
    """
    if isinstance(payload, (str, bytes, bytearray)):
        return json.loads(payload)
    return payload


def _iq_arrays(data):
    """Convert the packet's ``data_real``/``data_imag`` lists to float32 arrays."""
    return [
        np.asarray(data['data_real'], dtype=np.float32),
        np.asarray(data['data_imag'], dtype=np.float32),
    ]


def _report_to_general_server(freq, prediction, drone_probability):
    """Update the alarm streak for ``freq`` and POST the result to the general server.

    Best effort: any error (config, network, HTTP) is printed and swallowed
    so the inference response is still returned to the caller.
    """
    try:
        result = update_drone_streak(freq, prediction, drone_probability)
        data_to_send = {
            'freq': str(freq),
            'amplitude': result,
        }
        response = requests.post(
            "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port),
            json=data_to_send,
            timeout=_GENERAL_SERVER_TIMEOUT_S,
        )
        if response.status_code == 200:
            print("Данные успешно отправлены!")
            print("Частота: " + str(freq))
            print("Отправлено светодиодов: " + str(result))
        else:
            print("Ошибка при отправке данных: ", response.status_code)
    except Exception as exc:
        print(str(exc))


@app.route('/receive_data', methods=['POST'])
def receive_data():
    """Classify one IQ packet and forward the alarm decision.

    Expects a JSON body (see _decode_packet) with ``freq``, ``data_real``
    and ``data_imag``. The handler:
    1. runs the model registered for ``freq``;
    2. updates the drone streak and reports it to the general server;
    3. runs every entry of ``alg_list`` on the same samples.

    Returns:
        On success: 200 with ``{model_name: {freq, prediction, probability,
        drone_probability, drone_threshold}}``.
        On failure: 500 with ``{"error": message}`` when the packet could
        not be processed.
    """
    try:
        print()
        data = _decode_packet(request.json)
        print('#' * 100)
        print('Получен пакет ' + str(Model.get_ind_inference()))
        freq = int(data['freq'])
        print('Частота: ' + str(freq))
        result_msg = {}
        model = find_model_for_freq(freq)
        if model is None:
            raise RuntimeError(f"No NN model configured for freq={freq}")
        print('-' * 100)
        print(str(model))
        model_name = str(model.get_model_name())
        entry = {'freq': freq}
        result_msg[model_name] = entry
        inference_result = model.get_inference(_iq_arrays(data))
        if inference_result is None:
            raise RuntimeError(f"Inference failed for {model.get_model_name()}")
        prediction, probability = inference_result[:2]
        # Only a "drone" prediction contributes probability to the alarm gate.
        drone_probability = float(probability) if prediction == "drone" else 0.0
        entry['prediction'] = prediction
        entry['probability'] = str(probability)
        entry['drone_probability'] = str(drone_probability)
        entry['drone_threshold'] = str(get_required_drone_prob(freq))
        print('-' * 100)
        print()
        _report_to_general_server(freq, prediction, drone_probability)
        # NOTE(review): presumably advances the packet index printed above.
        Model.get_inc_ind_inference()
        print()
        print('#' * 100)
        for alg in alg_list:
            print('-' * 100)
            print(str(alg))
            alg.get_inference(_iq_arrays(data))
            print('-' * 100)
            print()
        print()
        print('#' * 100)
        del data
        gc.collect()
        return jsonify(result_msg)
    except Exception as exc:
        # Previously this path returned None, so Flask failed with "view did not
        # return a valid response"; report the real error explicitly instead.
        logging.exception("receive_data failed: %s", exc)
        return jsonify({'error': str(exc)}), 500
def run_flask():
    """Serve the Flask app on SERVER_IP:SERVER_PORT from the .env (blocks until shutdown)."""
    host = config['SERVER_IP']
    print(host)
    app.run(host=host, port=int(config['SERVER_PORT']))
if __name__ == '__main__':
    # Load every model before the HTTP server can accept packets.
    init_data_for_inference()
    # Serve from a non-daemon worker thread (the same as the inherited
    # default), so the process stays alive after the main thread finishes.
    flask_thread = threading.Thread(target=run_flask, daemon=False)
    flask_thread.start()