You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

679 lines
25 KiB
Python

from __future__ import annotations
import argparse
import json
import math
import mimetypes
import statistics
import threading
import time
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from urllib import error, parse, request
from triangulation import (
PropagationModel,
Sphere,
rssi_to_distance_m,
send_payload_to_server,
solve_three_sphere_intersection,
)
# (x, y, z) coordinate triple of floats; used as the type of a receiver's center.
Point3D = Tuple[float, float, float]
def _load_json(path: str) -> Dict[str, object]:
    """Load the service configuration from a JSON file.

    Every failure is reported as ``SystemExit`` with a readable message so
    that a broken config file ends the program cleanly instead of with a
    traceback.

    Args:
        path: Filesystem path of the configuration file.

    Returns:
        The parsed configuration. The JSON root is guaranteed to be an object.

    Raises:
        SystemExit: If the file is missing, cannot be read, is not valid
            UTF-8 encoded JSON, or its root is not a JSON object.
    """
    file_path = Path(path)
    if not file_path.exists():
        raise SystemExit(f"Config file not found: {path}")
    try:
        with file_path.open("r", encoding="utf-8") as fh:
            data = json.load(fh)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Config file is not valid JSON: {path}: {exc}") from None
    except (OSError, UnicodeDecodeError) as exc:
        # Covers directories, permission problems and non-UTF-8 content.
        raise SystemExit(f"Cannot read config file: {path}: {exc}") from None
    if not isinstance(data, dict):
        raise SystemExit("Config root must be a JSON object.")
    return data
def _center_from_obj(obj: Dict[str, object]) -> Point3D:
    """Extract a receiver position from its configuration object.

    Args:
        obj: Receiver configuration. It must hold a ``"center"`` object with
            numeric ``"x"``, ``"y"`` and ``"z"`` entries.

    Returns:
        The ``(x, y, z)`` coordinates converted to floats.

    Raises:
        ValueError: If ``"center"`` is not an object, or a coordinate is
            missing, not numeric, or not finite.
    """
    center = obj.get("center")
    if not isinstance(center, dict):
        raise ValueError("Receiver center must be an object.")
    coordinates = []
    for axis in ("x", "y", "z"):
        # Report a missing axis explicitly instead of leaking a bare KeyError.
        if axis not in center:
            raise ValueError(f"Receiver center is missing coordinate '{axis}'.")
        raw = center[axis]
        try:
            value = float(raw)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(
                f"Receiver center coordinate '{axis}' must be numeric, got {raw!r}."
            ) from None
        if not math.isfinite(value):
            raise ValueError(
                f"Receiver center coordinate '{axis}' must be finite, got {raw!r}."
            )
        coordinates.append(value)
    return (coordinates[0], coordinates[1], coordinates[2])
def _parse_model(config: Dict[str, object]) -> PropagationModel:
    """Build the propagation model from the ``"model"`` section of *config*.

    ``tx_power_dbm`` is mandatory; the remaining parameters fall back to the
    defaults listed below when absent. Every value is converted with
    ``float``.

    Raises:
        ValueError: If ``config["model"]`` is missing or not an object.
    """
    section = config.get("model")
    if isinstance(section, dict):
        tx_power = float(section["tx_power_dbm"])
        optional_defaults = (
            ("tx_gain_dbi", 0.0),
            ("rx_gain_dbi", 0.0),
            ("path_loss_exponent", 2.0),
            ("reference_distance_m", 1.0),
            ("min_distance_m", 1e-3),
        )
        optional_values = {
            name: float(section.get(name, fallback))
            for name, fallback in optional_defaults
        }
        return PropagationModel(tx_power_dbm=tx_power, **optional_values)
    raise ValueError("Config must contain object 'model'.")
def _float_from_measurement(
    item: Dict[str, object],
    keys: Sequence[str],
    field_name: str,
    source_label: str,
    row_index: int,
) -> float:
    """Read one numeric field of a measurement row, trying several aliases.

    The first entry of *keys* that is present in *item* decides the outcome;
    later aliases are not consulted even if that value turns out invalid.

    Args:
        item: The measurement row.
        keys: Accepted key names, in priority order.
        field_name: Canonical name used in the "missing field" message.
        source_label: Prefix identifying the data source in error messages.
        row_index: Row number used in error messages.

    Returns:
        The value converted to a finite float.

    Raises:
        ValueError: If no alias is present, or the value is not numeric or
            not finite.
    """
    absent = object()
    matched_key = next((candidate for candidate in keys if candidate in item), absent)
    if matched_key is absent:
        raise ValueError(f"{source_label}: row #{row_index} missing field '{field_name}'.")

    raw = item[matched_key]
    try:
        number = float(raw)
    except (TypeError, ValueError):
        raise ValueError(
            f"{source_label}: row #{row_index} field '{matched_key}' must be numeric, got {raw!r}."
        ) from None
    if math.isfinite(number):
        return number
    raise ValueError(
        f"{source_label}: row #{row_index} field '{matched_key}' must be finite, got {raw!r}."
    )
def parse_source_payload(
    payload: object,
    source_label: str,
    expected_receiver_id: Optional[str] = None,
) -> List[Tuple[float, float]]:
    """Convert a decoded receiver payload into ``(frequency, amplitude)`` pairs.

    *payload* is either a list of rows or an object that carries the rows
    under ``"measurements"``, ``"samples"`` or ``"data"`` (first key whose
    value is not ``None``). Each row must be an object with a frequency and
    an amplitude field; several key aliases are accepted for both.

    Args:
        payload: Decoded JSON value received from a source.
        source_label: Prefix identifying the source in error messages.
        expected_receiver_id: When given and the payload object carries a
            ``"receiver_id"``, the two must match.

    Returns:
        One ``(frequency_hz, amplitude_dbm)`` tuple per row, in row order.

    Raises:
        ValueError: On a receiver id mismatch, an unsupported payload shape,
            an empty row list, or an invalid row.
    """
    if isinstance(payload, list):
        rows = payload
    elif isinstance(payload, dict):
        must_check_id = expected_receiver_id is not None and "receiver_id" in payload
        if must_check_id:
            reported_id = str(payload["receiver_id"])
            if reported_id != expected_receiver_id:
                raise ValueError(
                    f"{source_label}: payload receiver_id '{reported_id}' "
                    f"does not match expected '{expected_receiver_id}'."
                )
        rows = None
        for container_key in ("measurements", "samples", "data"):
            rows = payload.get(container_key)
            if rows is not None:
                break
    else:
        raise ValueError(f"{source_label}: payload must be list or object.")

    if not (isinstance(rows, list) and rows):
        raise ValueError(f"{source_label}: payload contains no measurements.")

    frequency_aliases = ("frequency_hz", "freq_hz", "frequency", "freq")
    amplitude_aliases = ("amplitude_dbm", "rssi_dbm", "amplitude", "rssi")
    pairs: List[Tuple[float, float]] = []
    for position, entry in enumerate(rows, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"{source_label}: row #{position} must be an object.")
        frequency = _float_from_measurement(
            entry,
            keys=frequency_aliases,
            field_name="frequency_hz",
            source_label=source_label,
            row_index=position,
        )
        amplitude = _float_from_measurement(
            entry,
            keys=amplitude_aliases,
            field_name="amplitude_dbm",
            source_label=source_label,
            row_index=position,
        )
        # Checked only after both fields parsed, so field errors take precedence.
        if frequency <= 0.0:
            raise ValueError(
                f"{source_label}: row #{position} field 'frequency_hz' must be > 0."
            )
        pairs.append((frequency, amplitude))
    return pairs
def aggregate_radius(
    measurements: Sequence[Tuple[float, float]],
    model: PropagationModel,
    method: str,
) -> float:
    """Reduce a set of measurements to a single distance estimate.

    Each ``(frequency_hz, amplitude_dbm)`` pair is converted to a distance
    with ``rssi_to_distance_m`` and the distances are then combined.

    Args:
        measurements: Non-empty sequence of ``(frequency_hz, amplitude_dbm)``.
        model: Propagation model handed to ``rssi_to_distance_m``.
        method: ``"median"`` or ``"mean"``.

    Returns:
        The aggregated distance as a float.

    Raises:
        ValueError: If *method* is not supported or *measurements* is empty.
    """
    # Validate arguments first so a bad call fails before any conversion work
    # and an empty input never surfaces as ZeroDivisionError.
    if method not in ("median", "mean"):
        raise ValueError("aggregation must be 'median' or 'mean'")
    if not measurements:
        raise ValueError("aggregation requires at least one measurement")
    distances = [
        rssi_to_distance_m(amplitude_dbm=amplitude_dbm, frequency_hz=frequency_hz, model=model)
        for frequency_hz, amplitude_dbm in measurements
    ]
    if method == "median":
        return float(statistics.median(distances))
    return float(sum(distances) / len(distances))
def _group_by_frequency(
    measurements: Sequence[Tuple[float, float]],
) -> Dict[float, List[Tuple[float, float]]]:
    """Bucket measurements by their exact frequency value.

    Returns:
        A dict keyed by frequency, in first-seen order, whose values are the
        ``(frequency_hz, amplitude_dbm)`` pairs for that frequency in input
        order.
    """
    buckets: Dict[float, List[Tuple[float, float]]] = {}
    for freq, amplitude in measurements:
        buckets.setdefault(freq, []).append((freq, amplitude))
    return buckets
def _fetch_measurements(
    url: str,
    timeout_s: float,
    expected_receiver_id: Optional[str] = None,
) -> List[Tuple[float, float]]:
    """Download and parse the measurements published at *url*.

    Every failure (bad URL, transport error, undecodable body, invalid
    payload) is reported as ``RuntimeError`` so callers handle one exception
    type. Transport and decoding errors keep the original exception as
    ``__cause__``.

    Args:
        url: Address of the receiver endpoint.
        timeout_s: Timeout in seconds passed to ``urlopen``.
        expected_receiver_id: Forwarded to ``parse_source_payload``.

    Returns:
        The ``(frequency_hz, amplitude_dbm)`` pairs parsed from the response.

    Raises:
        RuntimeError: If the URL is malformed, the request fails, or the
            response is not a valid UTF-8 JSON measurement payload.
    """
    source_label = f"source_url={url}"
    # NOTE(review): urlopen also accepts non-HTTP schemes such as file://;
    # restrict the scheme here if source URLs can come from untrusted config.
    try:
        req = request.Request(url=url, method="GET", headers={"Accept": "application/json"})
    except ValueError as exc:
        # Request() rejects URLs without a recognisable scheme.
        raise RuntimeError(f"Invalid URL '{url}': {exc}") from exc
    try:
        with request.urlopen(req, timeout=timeout_s) as response:
            raw_body = response.read()
    except error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code} for '{url}': {body}") from exc
    except error.URLError as exc:
        raise RuntimeError(f"Cannot reach '{url}': {exc.reason}") from exc
    except TimeoutError as exc:
        raise RuntimeError(f"Timeout while reading '{url}'") from exc
    except OSError as exc:
        # Connection resets and similar low-level failures during read().
        raise RuntimeError(f"I/O error while reading '{url}': {exc}") from exc
    try:
        payload = json.loads(raw_body.decode("utf-8"))
    except UnicodeDecodeError as exc:
        raise RuntimeError(f"Response from '{url}' is not valid UTF-8: {exc}") from exc
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Invalid JSON from '{url}': {exc}") from exc
    try:
        return parse_source_payload(
            payload=payload,
            source_label=source_label,
            expected_receiver_id=expected_receiver_id,
        )
    except ValueError as exc:
        raise RuntimeError(str(exc)) from None
class AutoService:
    """Poll three receiver endpoints and keep the latest trilateration result.

    A daemon thread calls :meth:`refresh_once` in a loop, waiting
    ``poll_interval_s`` seconds between attempts. The latest payload, the
    last error text and the outcome of the optional forwarding to an output
    server are stored under ``state_lock`` and read through :meth:`snapshot`.
    """

    def __init__(self, config: Dict[str, object], config_path: Optional[str] = None) -> None:
        """Validate *config* and prepare (but do not start) the poll thread.

        Args:
            config: Parsed configuration with ``model``, ``input`` and the
                optional ``solver`` and ``runtime`` sections.
            config_path: Path the configuration was loaded from, if any.

        Raises:
            ValueError: If a section has the wrong type, an option has an
                unsupported value, or a receiver lacks a required key.
                Exceptions raised by the parsing helpers propagate unchanged.
        """
        self.config = config
        self.config_path = config_path
        self.model = _parse_model(config)

        solver_obj = config.get("solver", {})
        runtime_obj = config.get("runtime", {})
        input_obj = config.get("input")
        if not isinstance(solver_obj, dict):
            raise ValueError("solver must be object.")
        if not isinstance(runtime_obj, dict):
            raise ValueError("runtime must be object.")
        if not isinstance(input_obj, dict):
            raise ValueError("input must be object.")

        self.tolerance = float(solver_obj.get("tolerance", 1e-3))
        self.z_preference = str(solver_obj.get("z_preference", "positive"))
        if self.z_preference not in ("positive", "negative"):
            raise ValueError("solver.z_preference must be 'positive' or 'negative'.")

        self.poll_interval_s = float(runtime_obj.get("poll_interval_s", 1.0))
        output_obj = runtime_obj.get("output_server", {})
        if output_obj is None:
            output_obj = {}
        if not isinstance(output_obj, dict):
            raise ValueError("runtime.output_server must be object.")
        self.output_enabled = bool(output_obj.get("enabled", False))
        self.output_ip = str(output_obj.get("ip", ""))
        self.output_port = int(output_obj.get("port", 8080))
        self.output_path = str(output_obj.get("path", "/triangulation"))
        self.output_timeout_s = float(output_obj.get("timeout_s", 3.0))
        if self.output_enabled and not self.output_ip:
            raise ValueError("runtime.output_server.ip must be non-empty when enabled=true.")

        self.source_timeout_s = float(input_obj.get("source_timeout_s", 3.0))
        self.aggregation = str(input_obj.get("aggregation", "median"))
        if self.aggregation not in ("median", "mean"):
            raise ValueError("input.aggregation must be 'median' or 'mean'.")
        input_mode = str(input_obj.get("mode", "http_sources"))
        if input_mode != "http_sources":
            raise ValueError("Automatic service requires input.mode = 'http_sources'.")

        receivers = input_obj.get("receivers")
        if not isinstance(receivers, list) or len(receivers) != 3:
            raise ValueError("input.receivers must contain exactly 3 objects.")
        parsed_receivers: List[Dict[str, object]] = []
        for receiver in receivers:
            if not isinstance(receiver, dict):
                raise ValueError("Each receiver must be object.")
            # Name the missing key instead of leaking a bare KeyError.
            for required_key in ("receiver_id", "source_url"):
                if required_key not in receiver:
                    raise ValueError(f"Each receiver must define '{required_key}'.")
            parsed_receivers.append(
                {
                    "receiver_id": str(receiver["receiver_id"]),
                    "center": _center_from_obj(receiver),
                    "source_url": str(receiver["source_url"]),
                }
            )
        self.receivers = parsed_receivers

        # Mutable state shared between the poll thread and HTTP handlers.
        self.state_lock = threading.Lock()
        self.latest_payload: Optional[Dict[str, object]] = None
        self.last_error: str = "no data yet"
        self.updated_at_utc: Optional[str] = None
        self.last_output_delivery: Dict[str, object] = {
            "enabled": self.output_enabled,
            "status": "disabled" if not self.output_enabled else "pending",
            "http_status": None,
            "response_body": "",
            "sent_at_utc": None,
        }
        self.stop_event = threading.Event()
        self.poll_thread = threading.Thread(target=self._poll_loop, daemon=True)

    def start(self) -> None:
        """Start the background poll thread."""
        self.poll_thread.start()

    def stop(self) -> None:
        """Ask the poll thread to finish and wait up to two seconds for it.

        Safe to call when the thread was never started: joining an unstarted
        thread raises ``RuntimeError``, so the join is skipped in that case.
        """
        self.stop_event.set()
        if self.poll_thread.is_alive():
            self.poll_thread.join(timeout=2.0)

    def refresh_once(self) -> None:
        """Fetch all receivers, solve per common frequency and publish the result.

        The frequency with the lowest RMSE is selected as the reported
        position. The payload is stored before forwarding, so a delivery
        failure does not discard a computed result.

        Raises:
            RuntimeError: If the receivers share no frequency or the output
                server answers with a non-2xx status. Errors from fetching,
                solving and sending propagate unchanged.
        """
        # Built but not read later in this method; kept so Sphere construction
        # still runs for every all-frequency radius.
        spheres_all: List[Sphere] = []
        receiver_payloads: List[Dict[str, object]] = []
        grouped_by_receiver: List[Dict[float, List[Tuple[float, float]]]] = []

        for receiver in self.receivers:
            receiver_id = str(receiver["receiver_id"])
            center = receiver["center"]
            source_url = str(receiver["source_url"])
            measurements = _fetch_measurements(
                source_url,
                timeout_s=self.source_timeout_s,
                expected_receiver_id=receiver_id,
            )
            grouped = _group_by_frequency(measurements)
            grouped_by_receiver.append(grouped)
            radius_m = aggregate_radius(measurements, model=self.model, method=self.aggregation)
            spheres_all.append(Sphere(center=center, radius=radius_m))
            samples = []
            for frequency_hz, amplitude_dbm in measurements:
                samples.append(
                    {
                        "frequency_hz": frequency_hz,
                        "amplitude_dbm": amplitude_dbm,
                        "distance_m": rssi_to_distance_m(
                            amplitude_dbm=amplitude_dbm,
                            frequency_hz=frequency_hz,
                            model=self.model,
                        ),
                    }
                )
            receiver_payloads.append(
                {
                    "receiver_id": receiver_id,
                    "center": {"x": center[0], "y": center[1], "z": center[2]},
                    "source_url": source_url,
                    "aggregation": self.aggregation,
                    "radius_m_all_freq": radius_m,
                    "samples": samples,
                }
            )

        # Only frequencies reported by all three receivers can be solved.
        common_frequencies = (
            set(grouped_by_receiver[0].keys())
            & set(grouped_by_receiver[1].keys())
            & set(grouped_by_receiver[2].keys())
        )
        if not common_frequencies:
            raise RuntimeError("No common frequencies across all 3 receivers.")

        frequency_rows: List[Dict[str, object]] = []
        best_row: Optional[Dict[str, object]] = None
        for frequency_hz in sorted(common_frequencies):
            spheres_for_frequency: List[Sphere] = []
            row_receivers: List[Dict[str, object]] = []
            for index, receiver in enumerate(self.receivers):
                center = receiver["center"]
                measurement_subset = grouped_by_receiver[index][frequency_hz]
                radius_m = aggregate_radius(
                    measurement_subset, model=self.model, method=self.aggregation
                )
                spheres_for_frequency.append(Sphere(center=center, radius=radius_m))
                row_receivers.append(
                    {
                        "receiver_id": str(receiver["receiver_id"]),
                        "radius_m": radius_m,
                        "samples_count": len(measurement_subset),
                    }
                )
            result = solve_three_sphere_intersection(
                spheres=spheres_for_frequency,
                tolerance=self.tolerance,
                z_preference=self.z_preference,  # type: ignore[arg-type]
            )
            for index, residual in enumerate(result.residuals):
                row_receivers[index]["residual_m"] = residual
                receiver_payloads[index].setdefault("per_frequency", []).append(
                    {
                        "frequency_hz": frequency_hz,
                        "radius_m": spheres_for_frequency[index].radius,
                        "residual_m": residual,
                        "samples_count": len(grouped_by_receiver[index][frequency_hz]),
                    }
                )
            row = {
                "frequency_hz": frequency_hz,
                "position": {
                    "x": result.point[0],
                    "y": result.point[1],
                    "z": result.point[2],
                },
                "exact": result.exact,
                "rmse_m": result.rmse,
                "receivers": row_receivers,
            }
            frequency_rows.append(row)
            if best_row is None or float(row["rmse_m"]) < float(best_row["rmse_m"]):
                best_row = row
        if best_row is None:
            raise RuntimeError("Cannot build frequency table for trilateration.")

        payload = {
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "selected_frequency_hz": best_row["frequency_hz"],
            "position": best_row["position"],
            "exact": best_row["exact"],
            "rmse_m": best_row["rmse_m"],
            "frequency_table": frequency_rows,
            "model": {
                "tx_power_dbm": self.model.tx_power_dbm,
                "tx_gain_dbi": self.model.tx_gain_dbi,
                "rx_gain_dbi": self.model.rx_gain_dbi,
                "path_loss_exponent": self.model.path_loss_exponent,
                "reference_distance_m": self.model.reference_distance_m,
            },
            "receivers": receiver_payloads,
        }
        with self.state_lock:
            self.latest_payload = payload
            self.updated_at_utc = payload["timestamp_utc"]  # type: ignore[index]
            self.last_error = ""
        if self.output_enabled:
            self._deliver_output(payload)

    def _record_delivery(
        self,
        status: str,
        http_status: Optional[int],
        response_body: str,
    ) -> None:
        """Store the outcome of the latest forwarding attempt under the lock."""
        with self.state_lock:
            self.last_output_delivery = {
                "enabled": True,
                "status": status,
                "http_status": http_status,
                "response_body": response_body,
                "sent_at_utc": datetime.now(timezone.utc).isoformat(),
                "target": {
                    "ip": self.output_ip,
                    "port": self.output_port,
                    "path": self.output_path,
                },
            }

    def _deliver_output(self, payload: Dict[str, object]) -> None:
        """Forward *payload* to the configured output server and record the outcome.

        Raises:
            RuntimeError: If the server answers with a non-2xx status.
                Exceptions raised while sending are recorded and re-raised.
        """
        try:
            status_code, response_body = send_payload_to_server(
                server_ip=self.output_ip,
                payload=payload,
                port=self.output_port,
                path=self.output_path,
                timeout_s=self.output_timeout_s,
            )
        except Exception as exc:
            # Without this the delivery status would keep showing the previous
            # attempt (or "pending") after a transport failure.
            self._record_delivery("error", None, str(exc))
            raise
        delivered = 200 <= status_code < 300
        self._record_delivery("ok" if delivered else "error", status_code, response_body)
        if not delivered:
            raise RuntimeError(
                "Output server rejected payload: "
                f"HTTP {status_code}, body={response_body}"
            )

    def _poll_loop(self) -> None:
        """Run :meth:`refresh_once` repeatedly until ``stop_event`` is set.

        Any exception from a refresh is stored as ``last_error`` so one
        failed cycle does not end the thread.
        """
        while not self.stop_event.is_set():
            try:
                self.refresh_once()
            except Exception as exc:
                with self.state_lock:
                    self.last_error = str(exc)
            # wait() doubles as the sleep and returns early when stop() is called.
            self.stop_event.wait(self.poll_interval_s)

    def snapshot(self) -> Dict[str, object]:
        """Return the current state fields, read together under the lock."""
        with self.state_lock:
            return {
                "updated_at_utc": self.updated_at_utc,
                "last_error": self.last_error,
                "payload": self.latest_payload,
                "output_delivery": self.last_output_delivery,
            }
def _make_handler(service: AutoService):
    """Build the HTTP request handler class bound to *service*.

    Routes served by the returned class:

    * ``GET /`` and ``GET /ui``: ``web/index.html``
    * ``GET /static/<path>``: files below the ``web`` directory
    * ``GET /health``, ``/result``, ``/frequencies``, ``/config``: JSON views
    * ``POST /config``: validate and store a new configuration
    * ``POST /refresh``: run one refresh cycle synchronously

    Args:
        service: The service whose state and configuration are exposed.

    Returns:
        A ``BaseHTTPRequestHandler`` subclass closed over *service*.
    """

    class ServiceHandler(BaseHTTPRequestHandler):
        """Request handler exposing the service state over HTTP."""

        def _write_bytes(
            self,
            status_code: int,
            content: bytes,
            content_type: str,
        ) -> None:
            """Send a complete response with the given body and content type."""
            self.send_response(status_code)
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", str(len(content)))
            self.end_headers()
            self.wfile.write(content)

        def _write_json(self, status_code: int, payload: Dict[str, object]) -> None:
            """Serialize *payload* as UTF-8 JSON and send it."""
            raw = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            self._write_bytes(
                status_code=status_code,
                content=raw,
                content_type="application/json; charset=utf-8",
            )

        def _write_warming_up(self, snapshot: Dict[str, object]) -> None:
            """Send the 503 reply used while no result has been computed yet."""
            self._write_json(
                503,
                {
                    "status": "warming_up",
                    "updated_at_utc": snapshot["updated_at_utc"],
                    "error": snapshot["last_error"],
                },
            )

        def _write_static(self, relative_path: str) -> None:
            """Serve a file from the ``web`` directory next to this module.

            Anything that resolves outside that directory is answered with
            404, as is a missing file or a directory.
            """
            web_root = (Path(__file__).resolve().parent / "web").resolve()
            file_path = (web_root / relative_path).resolve()
            try:
                # A plain string-prefix test would also accept sibling
                # directories such as "web_private"; relative_to() compares
                # whole path components.
                file_path.relative_to(web_root)
            except ValueError:
                self._write_json(404, {"error": "not_found"})
                return
            if not file_path.exists() or not file_path.is_file():
                self._write_json(404, {"error": "not_found"})
                return
            mime_type, _ = mimetypes.guess_type(str(file_path))
            if mime_type is None:
                mime_type = "application/octet-stream"
            self._write_bytes(200, file_path.read_bytes(), mime_type)

        def log_message(self, format: str, *args) -> None:
            """Suppress the default per-request logging."""
            return

        def do_GET(self) -> None:
            """Dispatch GET requests to the static files or the JSON views."""
            path = parse.urlparse(self.path).path
            if path == "/" or path == "/ui":
                self._write_static("index.html")
                return
            if path.startswith("/static/"):
                self._write_static(path.removeprefix("/static/"))
                return

            # Static routes above do not need the service state.
            snapshot = service.snapshot()
            if path == "/health":
                status = "ok" if snapshot["payload"] else "warming_up"
                http_code = 200 if status == "ok" else 503
                self._write_json(
                    http_code,
                    {
                        "status": status,
                        "updated_at_utc": snapshot["updated_at_utc"],
                        "error": snapshot["last_error"],
                    },
                )
                return
            if path == "/result":
                payload = snapshot["payload"]
                if payload is None:
                    self._write_warming_up(snapshot)
                    return
                self._write_json(
                    200,
                    {
                        "status": "ok",
                        "updated_at_utc": snapshot["updated_at_utc"],
                        "data": payload,
                        "output_delivery": snapshot["output_delivery"],
                    },
                )
                return
            if path == "/frequencies":
                payload = snapshot["payload"]
                if payload is None:
                    self._write_warming_up(snapshot)
                    return
                self._write_json(
                    200,
                    {
                        "status": "ok",
                        "updated_at_utc": snapshot["updated_at_utc"],
                        "selected_frequency_hz": payload.get("selected_frequency_hz"),
                        "frequency_table": payload.get("frequency_table", []),
                        "output_delivery": snapshot["output_delivery"],
                    },
                )
                return
            if path == "/config":
                self._write_json(
                    200,
                    {
                        "status": "ok",
                        "config_path": service.config_path,
                        "config": service.config,
                    },
                )
                return
            self._write_json(404, {"error": "not_found"})

        def do_POST(self) -> None:
            """Handle ``POST /config`` and ``POST /refresh``."""
            path = parse.urlparse(self.path).path
            if path == "/config":
                try:
                    content_length = int(self.headers.get("Content-Length", "0"))
                except ValueError:
                    self._write_json(400, {"status": "error", "error": "Invalid Content-Length"})
                    return
                body = self.rfile.read(content_length) if content_length > 0 else b""
                try:
                    new_config = json.loads(body.decode("utf-8"))
                except UnicodeDecodeError as exc:
                    # Answer with 400 instead of letting the handler crash.
                    self._write_json(
                        400, {"status": "error", "error": f"Body must be UTF-8: {exc}"}
                    )
                    return
                except json.JSONDecodeError as exc:
                    self._write_json(400, {"status": "error", "error": f"Invalid JSON: {exc}"})
                    return
                if not isinstance(new_config, dict):
                    self._write_json(400, {"status": "error", "error": "Config must be JSON object"})
                    return
                try:
                    # Constructed only to run the validation in __init__; the
                    # instance is discarded and its thread is never started.
                    AutoService(new_config)
                except Exception as exc:
                    self._write_json(
                        400,
                        {"status": "error", "error": f"Config validation failed: {exc}"},
                    )
                    return
                if service.config_path:
                    try:
                        Path(service.config_path).write_text(
                            json.dumps(new_config, ensure_ascii=False, indent=2),
                            encoding="utf-8",
                        )
                    except OSError as exc:
                        # Report the failure and leave the in-memory config
                        # untouched so it stays consistent with the file.
                        self._write_json(
                            500, {"status": "error", "error": f"Cannot save config: {exc}"}
                        )
                        return
                service.config = new_config
                self._write_json(
                    200,
                    {
                        "status": "ok",
                        "saved": bool(service.config_path),
                        "restart_required": True,
                        "config_path": service.config_path,
                    },
                )
                return
            if path != "/refresh":
                self._write_json(404, {"error": "not_found"})
                return
            try:
                service.refresh_once()
            except Exception as exc:
                self._write_json(500, {"status": "error", "error": str(exc)})
                return
            snapshot = service.snapshot()
            self._write_json(
                200,
                {
                    "status": "ok",
                    "updated_at_utc": snapshot["updated_at_utc"],
                },
            )

    return ServiceHandler
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the service.

    Returns:
        A namespace with ``config`` (default ``"config.json"``), ``host``
        (default ``""``) and ``port`` (default ``0``).
    """
    cli = argparse.ArgumentParser(
        description="Automatic trilateration service: polls 3 receiver servers and exposes result API."
    )
    option_table = (
        ("--config", str, "config.json"),
        ("--host", str, ""),
        ("--port", int, 0),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()
def main() -> int:
    """Run the service until interrupted.

    Loads the configuration, starts the background poller and serves the
    HTTP API. Command-line ``--host``/``--port`` take precedence over
    ``runtime.listen_host``/``runtime.listen_port``; an empty host or a port
    of 0 on the command line means "use the configured value".

    Returns:
        ``0`` after the server has been shut down.

    Raises:
        SystemExit: If the configuration cannot be loaded or ``runtime`` is
            not an object.
    """
    cli_args = parse_args()
    config = _load_json(cli_args.config)

    runtime_section = config.get("runtime", {})
    if not isinstance(runtime_section, dict):
        raise SystemExit("runtime must be object.")

    if cli_args.host:
        host = cli_args.host
    else:
        host = str(runtime_section.get("listen_host", "0.0.0.0"))
    if cli_args.port:
        port = cli_args.port
    else:
        port = int(runtime_section.get("listen_port", 8081))

    service = AutoService(config, config_path=cli_args.config)
    service.start()
    bind_address = (host, port)
    server = ThreadingHTTPServer(bind_address, _make_handler(service))
    print(f"service_listen: http://{host}:{port}")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the service.
        pass
    finally:
        server.server_close()
        service.stop()
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)