You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
DroneDetector/tests/test_nn_profile_schedule.py

85 lines
2.8 KiB
Python

from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
import unittest
from common.nn_profile_schedule import (
DEFAULT_PROFILE,
format_active_profile_env,
get_available_profiles,
get_profile_model_entries,
get_requested_active_profile,
load_simple_env_file,
parse_schedule,
resolve_active_profile,
)
# Model-spec fixtures in the " && "-delimited format that these tests place under
# NN_<n> and NN_PROFILE_<NAME>_<n> config keys. The two specs differ only in their
# first and fifth fields. What each field means is defined by
# common.nn_profile_schedule (not visible here).
MODEL_2400 = " && ".join(
    (
        "model_2400", "cfg", "ex", "out", "Resnet18_1_2400",
        "build", "pre", "infer", "post", "[drone,noise]", "10", "1", "/dataset",
    )
)
MODEL_915 = " && ".join(
    (
        "model_915", "cfg", "ex", "out", "ensemble_915",
        "build", "pre", "infer", "post", "[drone,noise]", "10", "1", "/dataset",
    )
)
class NNProfileScheduleTests(unittest.TestCase):
    """Unit tests for the NN profile scheduling helpers in common.nn_profile_schedule."""

    def test_parse_schedule_supports_wraparound(self):
        """A schedule whose second window crosses midnight parses into two ordered rules."""
        rules = parse_schedule("08:00-20:00=DAY;20:00-08:00=NIGHT")
        self.assertEqual(len(rules), 2)
        self.assertEqual(rules[0].profile, "DAY")
        self.assertEqual(rules[1].profile, "NIGHT")

    def test_requested_profile_uses_local_time_schedule(self):
        """NN_SCHEDULE picks the profile from ``now``, overriding NN_ACTIVE_PROFILE."""
        config = {
            "NN_SCHEDULE": "08:00-20:00=DAY;20:00-08:00=NIGHT",
            "NN_ACTIVE_PROFILE": "DEFAULT",
        }
        # Naive datetimes are passed, i.e. the schedule is matched against
        # local wall-clock time as the test name describes.
        cases = (
            (datetime(2026, 4, 23, 9, 0), "DAY"),
            (datetime(2026, 4, 23, 21, 0), "NIGHT"),
            # After midnight: only matches if the 20:00-08:00 window
            # correctly wraps around into the next day.
            (datetime(2026, 4, 23, 3, 0), "NIGHT"),
        )
        for now, expected in cases:
            with self.subTest(now=now):
                self.assertEqual(
                    get_requested_active_profile(config, now=now),
                    expected,
                )

    def test_resolve_active_profile_falls_back_to_default(self):
        """An active profile with no NN_PROFILE_<NAME>_* entries resolves to DEFAULT_PROFILE."""
        config = {
            "NN_ACTIVE_PROFILE": "DAY",
            "NN_1": MODEL_2400,
        }
        self.assertEqual(resolve_active_profile(config), DEFAULT_PROFILE)

    def test_available_profiles_include_default_and_named(self):
        """Available profiles are DEFAULT plus each <NAME> from NN_PROFILE_<NAME>_<n> keys."""
        config = {
            "NN_1": MODEL_2400,
            "NN_PROFILE_DAY_21": MODEL_915,
            "NN_PROFILE_NIGHT_22": MODEL_2400,
        }
        self.assertEqual(get_available_profiles(config), {"DEFAULT", "DAY", "NIGHT"})

    def test_get_profile_model_entries_renumbers_to_logical_nn_keys(self):
        """Profile entries come back keyed as NN_<n>, sorted by suffix, not insertion order."""
        config = {
            "NN_PROFILE_DAY_21": MODEL_915,
            "NN_PROFILE_DAY_1": MODEL_2400,
        }
        # The profile name is passed in lower case while the config keys are upper case.
        entries = get_profile_model_entries(config, "day")
        # Compare the whole key list: a missing entry fails with a clear diff
        # instead of an IndexError, and an unexpected extra entry is caught too.
        self.assertEqual([entry[0] for entry in entries], ["NN_1", "NN_21"])

    def test_load_simple_env_file_strips_quotes(self):
        """Single quotes around a KEY='value' line's value are removed; case is preserved."""
        with TemporaryDirectory() as tmpdir:
            env_path = Path(tmpdir) / "profile.env"
            env_path.write_text("NN_ACTIVE_PROFILE='night'\n", encoding="utf-8")
            # Must be read before the temporary directory is cleaned up.
            values = load_simple_env_file(env_path)
        self.assertEqual(values["NN_ACTIVE_PROFILE"], "night")

    def test_format_active_profile_env_normalizes_value(self):
        """The profile name is upper-cased and emitted as a newline-terminated env line."""
        self.assertEqual(format_active_profile_env("night"), "NN_ACTIVE_PROFILE=NIGHT\n")
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()