diff --git a/NN_server/server.py b/NN_server/server.py index 331e1d0..36c751b 100644 --- a/NN_server/server.py +++ b/NN_server/server.py @@ -15,6 +15,8 @@ import shutil import json import gc import logging +import time +import re TORCHSIG_PATH = "/app/torchsig" if TORCHSIG_PATH not in sys.path: @@ -51,20 +53,89 @@ config = dict(dotenv_values(ROOT_ENV)) if not config: raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed") -if not any(key.startswith("NN_") for key in config): +MODEL_ENV_RE = re.compile(r"^NN_\d+$") + +if not any(MODEL_ENV_RE.match(key) for key in config): raise RuntimeError("[NN_server/server.py] no NN_* model entries configured") logging.info("NN config loaded from %s", ROOT_ENV) gen_server_ip = config['GENERAL_SERVER_IP'] gen_server_port = config['GENERAL_SERVER_PORT'] +INFERENCE_TELEMETRY_HOST = os.getenv('telemetry_host', '127.0.0.1') +INFERENCE_TELEMETRY_PORT = os.getenv('telemetry_port', '5020') +INFERENCE_TELEMETRY_ENDPOINT = os.getenv('telemetry_inference_endpoint', 'inference/result') +INFERENCE_TELEMETRY_TIMEOUT_SEC = float(os.getenv('telemetry_inference_timeout_sec', '0.30')) +INFERENCE_IMAGE_RE = re.compile(r"_inference_(\d+)_") + + +def get_result_dir(): + return config.get('SRC_RESULT', '') + + +def collect_inference_images(result_id, model_name=''): + result_dir = get_result_dir() + if not result_dir or not os.path.isdir(result_dir): + return result_id, [] + + needle = f"_inference_{result_id}_" + model_suffix = f"_{model_name}.png" if model_name else '' + exact_images = [] + grouped_images = {} + + for name in sorted(os.listdir(result_dir)): + if not name.endswith('.png'): + continue + if model_suffix and not name.endswith(model_suffix): + continue + match = INFERENCE_IMAGE_RE.search(name) + if match is None: + continue + + image_result_id = int(match.group(1)) + grouped_images.setdefault(image_result_id, []).append(name) + if image_result_id == result_id and needle in name: + exact_images.append(name) + + 
if exact_images: + return result_id, exact_images + if not grouped_images: + return result_id, [] + + latest_result_id = max(grouped_images) + return latest_result_id, grouped_images[latest_result_id] + + +def send_inference_result(payload): + try: + requests.post( + "http://{0}:{1}/{2}".format( + INFERENCE_TELEMETRY_HOST, + INFERENCE_TELEMETRY_PORT, + INFERENCE_TELEMETRY_ENDPOINT.lstrip('/'), + ), + json=payload, + timeout=INFERENCE_TELEMETRY_TIMEOUT_SEC, + ) + except Exception as exc: + print(str(exc)) + + +def reset_directory_contents(path): + os.makedirs(path, exist_ok=True) + for name in os.listdir(path): + full_path = os.path.join(path, name) + try: + if os.path.isdir(full_path) and not os.path.islink(full_path): + shutil.rmtree(full_path) + else: + os.remove(full_path) + except FileNotFoundError: + continue + def init_data_for_inference(): try: - if os.path.isdir(config['SRC_RESULT']): - shutil.rmtree(config['SRC_RESULT']) - os.mkdir(config['SRC_RESULT']) - if os.path.isdir(config['SRC_EXAMPLE']): - shutil.rmtree(config['SRC_EXAMPLE']) - os.mkdir(config['SRC_EXAMPLE']) + reset_directory_contents(config['SRC_RESULT']) + reset_directory_contents(config['SRC_EXAMPLE']) except Exception as exc: print(str(exc)) print() @@ -72,7 +143,7 @@ def init_data_for_inference(): try: global model_list for key in config.keys(): - if key.startswith('NN_'): + if MODEL_ENV_RE.match(key): params = config[key].split(' && ') module = importlib.import_module('Models.' 
+ params[4]) classes = {} @@ -111,9 +182,12 @@ def run_example(): def receive_data(): try: print() - data = json.loads(request.json) + data = request.json + if isinstance(data, str): + data = json.loads(data) print('#' * 100) print('Получен пакет ' + str(Model.get_ind_inference())) + result_id = Model.get_ind_inference() freq = int(data['freq']) print('Частота: ' + str(freq)) # print('Канал: ' + str(data['channel'])) @@ -133,6 +207,18 @@ def receive_data(): result_msg[str(model.get_model_name())]['prediction'] = prediction result_msg[str(model.get_model_name())]['probability'] = str(probability) prediction_list.append(prediction) + image_result_id, images = collect_inference_images(result_id, model.get_model_name()) + send_inference_result({ + 'result_id': image_result_id, + 'ts': time.time(), + 'freq': str(freq), + 'model': model.get_model_name(), + 'prediction': prediction, + 'probability': float(probability), + 'drone_probability': float(probability) if prediction == 'drone' else 0.0, + 'drone_threshold': None, + 'images': images, + }) print('-' * 100) print() diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index a18eb02..9f83b36 100644 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -37,6 +37,7 @@ services: environment: - PYTHONPATH=/app:/app/NN_server - NN_HOT_RELOAD=${NN_HOT_RELOAD:-1} + - telemetry_host=dronedetector-telemetry-server working_dir: /app/NN_server command: - sh @@ -76,6 +77,7 @@ services: - ../../.env:/app/.env:ro - ../../telemetry:/app/telemetry - ../../common:/app/common + - ../../NN_server/result:/app/inference_result:ro networks: - dronedetector-net diff --git a/telemetry/telemetry_server.py b/telemetry/telemetry_server.py index d507fb4..83d7113 100644 --- a/telemetry/telemetry_server.py +++ b/telemetry/telemetry_server.py @@ -1,11 +1,13 @@ import asyncio import os +import re import time from collections import defaultdict, deque +from pathlib import Path from typing import 
Any, Deque, Dict, List, Optional -from fastapi import FastAPI, Query, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, HTMLResponse from pydantic import BaseModel, Field from common.runtime import load_root_env @@ -16,15 +18,25 @@ TELEMETRY_BIND_HOST = os.getenv('telemetry_bind_host', os.getenv('lochost', '0.0 TELEMETRY_BIND_PORT = int(os.getenv('telemetry_bind_port', os.getenv('telemetry_port', '5020'))) TELEMETRY_HISTORY_SEC = int(float(os.getenv('telemetry_history_sec', '900'))) TELEMETRY_MAX_POINTS_PER_FREQ = int(os.getenv('telemetry_max_points_per_freq', '5000')) +INFERENCE_HISTORY_SEC = int(float(os.getenv('inference_history_sec', str(TELEMETRY_HISTORY_SEC)))) +INFERENCE_MAX_RESULTS_PER_FREQ = int(os.getenv('inference_max_results_per_freq', '100')) +INFERENCE_RESULT_DIR = Path(os.getenv('inference_result_dir', '/app/inference_result')).resolve() +INFERENCE_IMAGE_RE = re.compile(r"_inference_(\d+)_") def _new_buffer() -> Deque[Dict[str, Any]]: return deque(maxlen=TELEMETRY_MAX_POINTS_PER_FREQ) +def _new_inference_buffer() -> Deque[Dict[str, Any]]: + return deque(maxlen=INFERENCE_MAX_RESULTS_PER_FREQ) + + app = FastAPI(title='DroneDetector Telemetry Server') _buffers: Dict[str, Deque[Dict[str, Any]]] = defaultdict(_new_buffer) _ws_clients: List[WebSocket] = [] +_inference_buffers: Dict[str, Deque[Dict[str, Any]]] = defaultdict(_new_inference_buffer) +_inference_ws_clients: List[WebSocket] = [] _state_lock = asyncio.Lock() @@ -41,6 +53,18 @@ class TelemetryPoint(BaseModel): alarm_channels: Optional[List[int]] = None +class InferenceResult(BaseModel): + result_id: int + ts: float = Field(default_factory=lambda: time.time()) + freq: str + model: str + prediction: str + probability: float + drone_probability: Optional[float] = None + drone_threshold: Optional[str] = None + images: List[str] = 
Field(default_factory=list) + + def _prune_freq_locked(freq: str, now_ts: float) -> None: cutoff = now_ts - TELEMETRY_HISTORY_SEC buf = _buffers[freq] @@ -62,6 +86,70 @@ def _copy_series_locked(seconds: int, freq: Optional[str] = None) -> Dict[str, L return series +def _prune_inference_freq_locked(freq: str, now_ts: float) -> None: + cutoff = now_ts - INFERENCE_HISTORY_SEC + buf = _inference_buffers[freq] + while buf and float(buf[0].get('ts', 0.0)) < cutoff: + buf.popleft() + + +def _copy_inference_series_locked(limit: int, freq: Optional[str] = None) -> Dict[str, List[Dict[str, Any]]]: + now_ts = time.time() + cutoff = now_ts - INFERENCE_HISTORY_SEC + + def _slice(buf: Deque[Dict[str, Any]]) -> List[Dict[str, Any]]: + recent = [item for item in buf if float(item.get('ts', 0.0)) >= cutoff] + return recent[-limit:] + + if freq is not None: + return {freq: _slice(_inference_buffers.get(freq, deque()))} + + series: Dict[str, List[Dict[str, Any]]] = {} + for key, buf in _inference_buffers.items(): + series[key] = _slice(buf) + return series + + +def _sanitize_image_names(names: List[str]) -> List[str]: + safe_names: List[str] = [] + for name in names: + base = Path(str(name)).name + if not base or not base.endswith('.png'): + continue + safe_names.append(base) + return safe_names + + +def _resolve_latest_images_for_model(payload: Dict[str, Any]) -> Dict[str, Any]: + model_name = str(payload.get('model', '')) + if not model_name or not INFERENCE_RESULT_DIR.is_dir(): + payload['images'] = _sanitize_image_names(payload.get('images', [])) + return payload + + model_suffix = f"_{model_name}.png" + grouped: Dict[int, List[str]] = {} + for path in INFERENCE_RESULT_DIR.iterdir(): + if not path.is_file(): + continue + name = path.name + if not name.endswith(model_suffix): + continue + match = INFERENCE_IMAGE_RE.search(name) + if match is None: + continue + grouped.setdefault(int(match.group(1)), []).append(name) + + if not grouped: + payload['images'] = 
_sanitize_image_names(payload.get('images', [])) + return payload + + current_id = int(payload.get('result_id', 0) or 0) + resolved_id = current_id if current_id in grouped else max(grouped) + payload['result_id'] = resolved_id + payload['images'] = sorted(grouped[resolved_id]) + return payload + + async def _broadcast(message: Dict[str, Any]) -> None: dead: List[WebSocket] = [] for ws in list(_ws_clients): @@ -77,6 +165,21 @@ async def _broadcast(message: Dict[str, Any]) -> None: _ws_clients.remove(ws) +async def _broadcast_inference(message: Dict[str, Any]) -> None: + dead: List[WebSocket] = [] + for ws in list(_inference_ws_clients): + try: + await ws.send_json(message) + except Exception: + dead.append(ws) + + if dead: + async with _state_lock: + for ws in dead: + if ws in _inference_ws_clients: + _inference_ws_clients.remove(ws) + + @app.post('/telemetry') async def ingest_telemetry(point: TelemetryPoint): payload = point.model_dump() @@ -102,6 +205,31 @@ async def telemetry_history( return {'seconds': seconds, 'series': series} +@app.post('/inference/result') +async def ingest_inference_result(result: InferenceResult): + payload = result.model_dump() + payload = _resolve_latest_images_for_model(payload) + freq = str(payload['freq']) + now_ts = time.time() + + async with _state_lock: + _inference_buffers[freq].append(payload) + _prune_inference_freq_locked(freq, now_ts) + + await _broadcast_inference({'type': 'inference_result', 'data': payload}) + return {'ok': True} + + +@app.get('/inference/history') +async def inference_history( + freq: Optional[str] = Query(default=None), + limit: int = Query(default=20, ge=1, le=200), +): + async with _state_lock: + series = _copy_inference_series_locked(limit=limit, freq=freq) + return {'limit': limit, 'series': series} + + @app.websocket('/telemetry/ws') async def telemetry_ws(websocket: WebSocket): await websocket.accept() @@ -123,6 +251,26 @@ async def telemetry_ws(websocket: WebSocket): _ws_clients.remove(websocket) 
+@app.websocket('/inference/ws') +async def inference_ws(websocket: WebSocket): + await websocket.accept() + async with _state_lock: + _inference_ws_clients.append(websocket) + snapshot = _copy_inference_series_locked(limit=20, freq=None) + + await websocket.send_json({'type': 'snapshot', 'data': snapshot}) + + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + pass + finally: + async with _state_lock: + if websocket in _inference_ws_clients: + _inference_ws_clients.remove(websocket) + + MONITOR_HTML = """ @@ -453,12 +601,294 @@ loadInitial().then(connectWs).catch((e) => { """ +INFERENCE_VIEWER_HTML = """ + + + + + + DroneDetector Inference Viewer + + + +
+
+
+

DroneDetector Inference Viewer

+
Latest inference card per frequency. Browser keeps last 20 results per frequency.
+
+
+
connecting...
+ +
+
+
+
+
+
+
+
+"""
+
+
 @app.get('/', response_class=HTMLResponse)
 @app.get('/monitor', response_class=HTMLResponse)
 async def monitor_page():
     return HTMLResponse(content=MONITOR_HTML)
+@app.get('/inference-viewer', response_class=HTMLResponse)
+async def inference_viewer_page():
+    return HTMLResponse(content=INFERENCE_VIEWER_HTML)
+
+
+@app.get('/inference/images/{filename}')
+async def inference_image(filename: str):
+    safe_name = Path(filename).name
+    if safe_name != filename:
+        raise HTTPException(status_code=404, detail='image not found')
+
+    image_path = (INFERENCE_RESULT_DIR / safe_name).resolve()
+    if image_path.parent != INFERENCE_RESULT_DIR or not image_path.is_file():
+        raise HTTPException(status_code=404, detail='image not found')
+
+    return FileResponse(image_path)
+
+
 if __name__ == '__main__':
     import uvicorn