добавил drone_streaks

automatica-1
Sergey Revyakin 4 weeks ago
parent 42c724f227
commit 523dbf77d1

@ -3,6 +3,7 @@ from dotenv import dotenv_values
from common.runtime import load_root_env, validate_env, as_int, as_str from common.runtime import load_root_env, validate_env, as_int, as_str
import os import os
import sys import sys
import re
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from Model import Model from Model import Model
import numpy as np import numpy as np
@ -49,13 +50,47 @@ validate_env("NN_server/server.py", {
}) })
config = dict(dotenv_values(ROOT_ENV)) config = dict(dotenv_values(ROOT_ENV))
def is_model_config_key(key, value):
return bool(re.fullmatch(r"NN_\d+", key or "")) and isinstance(value, str) and " && " in value
def get_required_drone_streak(freq):
raw_value = config.get(f"DRONE_STREAK_{freq}", "1")
try:
return max(1, int(raw_value))
except (TypeError, ValueError):
logging.warning("Invalid DRONE_STREAK_%s=%r, falling back to 1", freq, raw_value)
return 1
def update_drone_streak(freq, prediction):
if prediction == "drone":
drone_streaks[freq] = drone_streaks.get(freq, 0) + 1
else:
drone_streaks[freq] = 0
required = get_required_drone_streak(freq)
triggered = prediction == "drone" and drone_streaks[freq] >= required
logging.info(
"NN alarm gate freq=%s prediction=%s streak=%s/%s triggered=%s",
freq,
prediction,
drone_streaks[freq],
required,
triggered,
)
return 8 if triggered else 0
if not config: if not config:
raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed") raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed")
if not any(key.startswith("NN_") for key in config): if not any(is_model_config_key(key, value) for key, value in config.items()):
raise RuntimeError("[NN_server/server.py] no NN_* model entries configured") raise RuntimeError("[NN_server/server.py] no NN_* model entries configured")
logging.info("NN config loaded from %s", ROOT_ENV) logging.info("NN config loaded from %s", ROOT_ENV)
gen_server_ip = config['GENERAL_SERVER_IP'] gen_server_ip = config['GENERAL_SERVER_IP']
gen_server_port = config['GENERAL_SERVER_PORT'] gen_server_port = config['GENERAL_SERVER_PORT']
drone_streaks = {}
def init_data_for_inference(): def init_data_for_inference():
try: try:
@ -71,9 +106,9 @@ def init_data_for_inference():
try: try:
global model_list global model_list
for key in config.keys(): for key, value in config.items():
if key.startswith('NN_'): if is_model_config_key(key, value):
params = config[key].split(' && ') params = value.split(' && ')
module = importlib.import_module('Models.' + params[4]) module = importlib.import_module('Models.' + params[4])
classes = {} classes = {}
for value in params[9][1:-1].split(','): for value in params[9][1:-1].split(','):
@ -137,13 +172,7 @@ def receive_data():
print() print()
try: try:
result = 0 result = update_drone_streak(freq, prediction_list[0])
if (int(freq) == 2400 and (prediction_list[0] in ['drone', 'drone_noise'] or (prediction_list[0] == 'wifi' and float(probability) >= 0.95))) or (int(freq) == 1200 and (prediction_list[0] in ['drone'] and float(probability) >= 0.95)):
result += 8
if int(freq) in [915]:
result = 0
if int(freq) in []:
result = 8
data_to_send={ data_to_send={
'freq': str(freq), 'freq': str(freq),
'amplitude': result 'amplitude': result

Loading…
Cancel
Save