From 53539df2fdcbe3d8e71707b9ac2ce6e6c96eeea2 Mon Sep 17 00:00:00 2001 From: Sergey Revyakin Date: Mon, 27 Apr 2026 15:28:25 +0700 Subject: [PATCH] =?UTF-8?q?=D0=98=D0=B7=D0=BC=D0=B5=D0=BD=D0=B8=D0=BB=20?= =?UTF-8?q?=D1=81=D1=82=D1=80=D1=83=D0=BA=D1=82=D1=83=D1=80=D0=B0=20=D0=BF?= =?UTF-8?q?=D0=B0=D1=80=D1=81=D0=B8=D0=BD=D0=B3=D0=B0=20=D0=BF=D0=B5=D1=80?= =?UTF-8?q?=D0=B5=D0=BC=D0=B5=D0=BD=D0=BD=D1=8B=D1=85=20=D1=81=D1=80=D0=B5?= =?UTF-8?q?=D0=B4=D1=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 20 +- NN_server/Model.py | 6 +- NN_server/Models/ensemble_1200_2pic.py | 198 +------------ NN_server/Models/ensemble_2400_2pic.py | 198 +------------ NN_server/Models/ensemble_2_pic.py | 197 +++++++++++++ NN_server/Models/ensemble_915_2pic.py | 198 +------------ NN_server/server.py | 380 +++++++++++-------------- 7 files changed, 389 insertions(+), 808 deletions(-) create mode 100644 NN_server/Models/ensemble_2_pic.py diff --git a/.env.example b/.env.example index b9117bf..e7989d7 100644 --- a/.env.example +++ b/.env.example @@ -253,15 +253,29 @@ SERVER_PORT_2=8080 ################# # NN_SERVER +NN_MODEL_2400=ensemble_2_pic +NN_WEIGHTS_2400=${PATH_TO_NN}ensemble2400_2pic.pth +NN_CLASSES_2400=drone,noise +NN_MODEL_1200=ensemble_2_pic +NN_WEIGHTS_1200=${PATH_TO_NN}ensemble1200_2pic.pth +NN_CLASSES_1200=drone,noise +NN_MODEL_915=ensemble_2_pic +NN_WEIGHTS_915=${PATH_TO_NN}ensemble915_2pic.pth +NN_CLASSES_915=drone,noise +NN_BUILD_FUNC=build_func_ensemble +NN_PRE_FUNC=pre_func_ensemble +NN_INFERENCE_FUNC=inference_func_ensemble +NN_POST_FUNC=post_func_ensemble +NN_SYNTHETIC_EXAMPLES=10 +NN_SYNTHETIC_MIX_COUNT=1 +NN_SRC_DATASET=/app/NN_server/datasets/full_dataset/ + ################# FREQS=915,1200,2400 PATH_TO_NN=/app/NN_server/NN/ SRC_RESULT=/app/NN_server/result/ SRC_EXAMPLE=${PATH_TO_NN}example/ -NN_1='${PATH_TO_NN}resnet18_1.pth && ${PATH_TO_NN}config_resnet18.yaml && ${SRC_EXAMPLE} && ${SRC_RESULT} && Resnet18_1_2400 && build_func_resnet18 && pre_func_resnet18 && inference_func_resnet18 && post_func_resnet18 && [drone,noise,wifi] && 10 && 1 && /app/NN_server/datasets/full_dataset_pic/' -NN_21='${PATH_TO_NN}ensemble_1.2.pth && ${PATH_TO_NN}config_ensemble.yaml && ${SRC_EXAMPLE} && ${SRC_RESULT} && ensemble_1200 && build_func_ensemble && pre_func_ensemble && inference_func_ensemble && post_func_ensemble && [drone,noise] && 10 && 1 && /app/NN_server/datasets/full_dataset/' -NN_22='${PATH_TO_NN}ensemble_915.pth && ${PATH_TO_NN}config_ensemble.yaml && ${SRC_EXAMPLE} && ${SRC_RESULT} && ensemble_915 && build_func_ensemble && pre_func_ensemble && inference_func_ensemble && post_func_ensemble && [drone,noise] && 10 && 1 && /app/NN_server/datasets/full_dataset/' GENERAL_SERVER_IP=dronedetector-server-to-master GENERAL_SERVER_PORT=5010 diff --git a/NN_server/Model.py b/NN_server/Model.py index 776a658..fe495b3 100644 --- a/NN_server/Model.py +++ b/NN_server/Model.py @@ -99,10 +99,11 @@ class Model(object): except Exception as exc: print(str(exc)) - def __init__(self, file_model='', file_config='', src_example='', src_result='', type_model='', + def __init__(self, freq=0, file_model='', file_config='', src_example='', src_result='', type_model='', build_model_func=None, pre_func=None, inference_func=None, post_func=None, classes=None, number_synthetic_examples=0, number_src_data_for_one_synthetic_example=0, path_to_src_dataset=''): try: + self._freq = int(freq) self._file_model = file_model self._file_config = file_config self._src_example = src_example @@ -137,6 +138,9 @@ class Model(object): def get_mapping(self): return list(self._classes.values()) + def get_freq(self): + return self._freq + def get_model_name(self): return self._type_model diff --git a/NN_server/Models/ensemble_1200_2pic.py b/NN_server/Models/ensemble_1200_2pic.py index d805a4f..7e54253 100644 --- a/NN_server/Models/ensemble_1200_2pic.py +++ b/NN_server/Models/ensemble_1200_2pic.py @@ -1,197 +1 @@ -from torchvision import models -import torch.nn as nn -import matplotlib -import numpy as np -import torch -import cv2 -import gc -import io - - -def _render_plot(values, figsize=(16, 16), dpi=16): - import matplotlib.pyplot as plt - - fig = plt.figure(figsize=figsize) - plt.axes(ylim=(-1, 1)) - plt.plot(values, color="black") - plt.gca().set_axis_off() - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - plt.margins(0, 0) - - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=dpi) - buf.seek(0) - img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) - buf.close() - - img = cv2.imdecode(img_arr, 1) - if img is None: - raise RuntimeError("failed to decode plot image") - - plt.clf() - plt.cla() - plt.close() - plt.close(fig) - - return np.asarray(cv2.split(img), dtype=np.float32) - - -def pre_func_ensemble(data=None, src="", ind_inference=0): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - real = np.asarray(data[0], dtype=np.float32) - imag = np.asarray(data[1], dtype=np.float32) - signal = real + 1j * imag - - img_real = _render_plot(signal.real) - img_mag = _render_plot(np.abs(signal)) - - cv2.destroyAllWindows() - gc.collect() - - print("Подготовка данных завершена") - print() - return [img_real, img_mag] - - except Exception as exc: - print(str(exc)) - return None - - -def build_func_ensemble(file_model="", file_config="", num_classes=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - torch.cuda.empty_cache() - - num_classes = 2 - model1 = models.resnet18(pretrained=False) - model2 = models.resnet50(pretrained=False) - - model1.fc = nn.Linear(model1.fc.in_features, num_classes) - model2.fc = nn.Linear(model2.fc.in_features, num_classes) - - class Ensemble(nn.Module): - def __init__(self, model1, model2): - super().__init__() - self.model1 = model1 - self.model2 = model2 - self.fc = nn.Linear(2 * num_classes, num_classes) - - def forward(self, x): - if isinstance(x, (list, tuple)): - x1 = x[0] - x2 = x[1] if len(x) > 1 else x[0] - else: - x1 = x - x2 = x - y1 = self.model1(x1) - y2 = self.model2(x2) - y = torch.cat((y1, y2), dim=1) - return self.fc(y) - - model = Ensemble(model1, model2) - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device != "cpu": - model = model.to(device) - model.load_state_dict(torch.load(file_model, map_location=device)) - model.eval() - - cv2.destroyAllWindows() - gc.collect() - - print("Инициализация модели завершена") - print() - return model - - except Exception as exc: - print(str(exc)) - return None - - -def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""): - try: - cv2.destroyAllWindows() - gc.collect() - torch.cuda.empty_cache() - - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(data, (list, tuple)) and len(data) >= 2: - inputs = [ - torch.unsqueeze(torch.tensor(data[0]).cpu(), 0).to(device).float(), - torch.unsqueeze(torch.tensor(data[1]).cpu(), 0).to(device).float(), - ] - else: - tensor = torch.unsqueeze(torch.tensor(data).cpu(), 0).to(device).float() - inputs = [tensor, tensor] - - with torch.no_grad(): - output = model(inputs) - _, predict = torch.max(output.data, 1) - - prediction = mapping[int(np.asarray(predict.cpu())[0])] - print("PREDICTION" + shablon + ": " + prediction) - - output = output.cpu() - label = np.asarray(np.argmax(output, axis=1))[0] - output = np.asarray(torch.squeeze(output, 0)) - expon = np.exp(output - np.max(output)) - probability = round((expon / expon.sum())[label], 2) - - cv2.destroyAllWindows() - gc.collect() - print("Уверенность" + shablon + " в предсказании: " + str(probability)) - print("Инференс завершен") - print() - return [prediction, probability] - - except Exception as exc: - print(str(exc)) - return None - - -def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2: - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - plt.clf() - plt.cla() - plt.close() - cv2.destroyAllWindows() - gc.collect() - - print("Постобработка завершена") - print() - - except Exception as exc: - print(str(exc)) - return None +from Models.ensemble_2_pic import * diff --git a/NN_server/Models/ensemble_2400_2pic.py b/NN_server/Models/ensemble_2400_2pic.py index d805a4f..7e54253 100644 --- a/NN_server/Models/ensemble_2400_2pic.py +++ b/NN_server/Models/ensemble_2400_2pic.py @@ -1,197 +1 @@ -from torchvision import models -import torch.nn as nn -import matplotlib -import numpy as np -import torch -import cv2 -import gc -import io - - -def _render_plot(values, figsize=(16, 16), dpi=16): - import matplotlib.pyplot as plt - - fig = plt.figure(figsize=figsize) - plt.axes(ylim=(-1, 1)) - plt.plot(values, color="black") - plt.gca().set_axis_off() - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - plt.margins(0, 0) - - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=dpi) - buf.seek(0) - img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) - buf.close() - - img = cv2.imdecode(img_arr, 1) - if img is None: - raise RuntimeError("failed to decode plot image") - - plt.clf() - plt.cla() - plt.close() - plt.close(fig) - - return np.asarray(cv2.split(img), dtype=np.float32) - - -def pre_func_ensemble(data=None, src="", ind_inference=0): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - real = np.asarray(data[0], dtype=np.float32) - imag = np.asarray(data[1], dtype=np.float32) - signal = real + 1j * imag - - img_real = _render_plot(signal.real) - img_mag = _render_plot(np.abs(signal)) - - cv2.destroyAllWindows() - gc.collect() - - print("Подготовка данных завершена") - print() - return [img_real, img_mag] - - except Exception as exc: - print(str(exc)) - return None - - -def build_func_ensemble(file_model="", file_config="", num_classes=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - torch.cuda.empty_cache() - - num_classes = 2 - model1 = models.resnet18(pretrained=False) - model2 = models.resnet50(pretrained=False) - - model1.fc = nn.Linear(model1.fc.in_features, num_classes) - model2.fc = nn.Linear(model2.fc.in_features, num_classes) - - class Ensemble(nn.Module): - def __init__(self, model1, model2): - super().__init__() - self.model1 = model1 - self.model2 = model2 - self.fc = nn.Linear(2 * num_classes, num_classes) - - def forward(self, x): - if isinstance(x, (list, tuple)): - x1 = x[0] - x2 = x[1] if len(x) > 1 else x[0] - else: - x1 = x - x2 = x - y1 = self.model1(x1) - y2 = self.model2(x2) - y = torch.cat((y1, y2), dim=1) - return self.fc(y) - - model = Ensemble(model1, model2) - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device != "cpu": - model = model.to(device) - model.load_state_dict(torch.load(file_model, map_location=device)) - model.eval() - - cv2.destroyAllWindows() - gc.collect() - - print("Инициализация модели завершена") - print() - return model - - except Exception as exc: - print(str(exc)) - return None - - -def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""): - try: - cv2.destroyAllWindows() - gc.collect() - torch.cuda.empty_cache() - - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(data, (list, tuple)) and len(data) >= 2: - inputs = [ - torch.unsqueeze(torch.tensor(data[0]).cpu(), 0).to(device).float(), - torch.unsqueeze(torch.tensor(data[1]).cpu(), 0).to(device).float(), - ] - else: - tensor = torch.unsqueeze(torch.tensor(data).cpu(), 0).to(device).float() - inputs = [tensor, tensor] - - with torch.no_grad(): - output = model(inputs) - _, predict = torch.max(output.data, 1) - - prediction = mapping[int(np.asarray(predict.cpu())[0])] - print("PREDICTION" + shablon + ": " + prediction) - - output = output.cpu() - label = np.asarray(np.argmax(output, axis=1))[0] - output = np.asarray(torch.squeeze(output, 0)) - expon = np.exp(output - np.max(output)) - probability = round((expon / expon.sum())[label], 2) - - cv2.destroyAllWindows() - gc.collect() - print("Уверенность" + shablon + " в предсказании: " + str(probability)) - print("Инференс завершен") - print() - return [prediction, probability] - - except Exception as exc: - print(str(exc)) - return None - - -def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2: - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - plt.clf() - plt.cla() - plt.close() - cv2.destroyAllWindows() - gc.collect() - - print("Постобработка завершена") - print() - - except Exception as exc: - print(str(exc)) - return None +from Models.ensemble_2_pic import * diff --git a/NN_server/Models/ensemble_2_pic.py b/NN_server/Models/ensemble_2_pic.py new file mode 100644 index 0000000..d805a4f --- /dev/null +++ b/NN_server/Models/ensemble_2_pic.py @@ -0,0 +1,197 @@ +from torchvision import models +import torch.nn as nn +import matplotlib +import numpy as np +import torch +import cv2 +import gc +import io + + +def _render_plot(values, figsize=(16, 16), dpi=16): + import matplotlib.pyplot as plt + + fig = plt.figure(figsize=figsize) + plt.axes(ylim=(-1, 1)) + plt.plot(values, color="black") + plt.gca().set_axis_off() + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.margins(0, 0) + + buf = io.BytesIO() + fig.savefig(buf, format="png", dpi=dpi) + buf.seek(0) + img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) + buf.close() + + img = cv2.imdecode(img_arr, 1) + if img is None: + raise RuntimeError("failed to decode plot image") + + plt.clf() + plt.cla() + plt.close() + plt.close(fig) + + return np.asarray(cv2.split(img), dtype=np.float32) + + +def pre_func_ensemble(data=None, src="", ind_inference=0): + try: + import matplotlib.pyplot as plt + + matplotlib.use("Agg") + plt.ioff() + + real = np.asarray(data[0], dtype=np.float32) + imag = np.asarray(data[1], dtype=np.float32) + signal = real + 1j * imag + + img_real = _render_plot(signal.real) + img_mag = _render_plot(np.abs(signal)) + + cv2.destroyAllWindows() + gc.collect() + + print("Подготовка данных завершена") + print() + return [img_real, img_mag] + + except Exception as exc: + print(str(exc)) + return None + + +def build_func_ensemble(file_model="", file_config="", num_classes=None): + try: + import matplotlib.pyplot as plt + + matplotlib.use("Agg") + plt.ioff() + torch.cuda.empty_cache() + + num_classes = 2 + model1 = models.resnet18(pretrained=False) + model2 = models.resnet50(pretrained=False) + + model1.fc = nn.Linear(model1.fc.in_features, num_classes) + model2.fc = nn.Linear(model2.fc.in_features, num_classes) + + class Ensemble(nn.Module): + def __init__(self, model1, model2): + super().__init__() + self.model1 = model1 + self.model2 = model2 + self.fc = nn.Linear(2 * num_classes, num_classes) + + def forward(self, x): + if isinstance(x, (list, tuple)): + x1 = x[0] + x2 = x[1] if len(x) > 1 else x[0] + else: + x1 = x + x2 = x + y1 = self.model1(x1) + y2 = self.model2(x2) + y = torch.cat((y1, y2), dim=1) + return self.fc(y) + + model = Ensemble(model1, model2) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device != "cpu": + model = model.to(device) + model.load_state_dict(torch.load(file_model, map_location=device)) + model.eval() + + cv2.destroyAllWindows() + gc.collect() + + print("Инициализация модели завершена") + print() + return model + + except Exception as exc: + print(str(exc)) + return None + + +def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""): + try: + cv2.destroyAllWindows() + gc.collect() + torch.cuda.empty_cache() + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(data, (list, tuple)) and len(data) >= 2: + inputs = [ + torch.unsqueeze(torch.tensor(data[0]).cpu(), 0).to(device).float(), + torch.unsqueeze(torch.tensor(data[1]).cpu(), 0).to(device).float(), + ] + else: + tensor = torch.unsqueeze(torch.tensor(data).cpu(), 0).to(device).float() + inputs = [tensor, tensor] + + with torch.no_grad(): + output = model(inputs) + _, predict = torch.max(output.data, 1) + + prediction = mapping[int(np.asarray(predict.cpu())[0])] + print("PREDICTION" + shablon + ": " + prediction) + + output = output.cpu() + label = np.asarray(np.argmax(output, axis=1))[0] + output = np.asarray(torch.squeeze(output, 0)) + expon = np.exp(output - np.max(output)) + probability = round((expon / expon.sum())[label], 2) + + cv2.destroyAllWindows() + gc.collect() + print("Уверенность" + shablon + " в предсказании: " + str(probability)) + print("Инференс завершен") + print() + return [prediction, probability] + + except Exception as exc: + print(str(exc)) + return None + + +def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None): + try: + import matplotlib.pyplot as plt + + matplotlib.use("Agg") + plt.ioff() + + if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2: + fig, ax = plt.subplots() + ax.imshow(np.moveaxis(data[0], 0, -1)) + plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") + plt.clf() + plt.cla() + plt.close(fig) + cv2.destroyAllWindows() + gc.collect() + + fig, ax = plt.subplots() + ax.imshow(np.moveaxis(data[1], 0, -1)) + plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") + plt.clf() + plt.cla() + plt.close(fig) + cv2.destroyAllWindows() + gc.collect() + + plt.clf() + plt.cla() + plt.close() + cv2.destroyAllWindows() + gc.collect() + + print("Постобработка завершена") + print() + + except Exception as exc: + print(str(exc)) + return None diff --git a/NN_server/Models/ensemble_915_2pic.py b/NN_server/Models/ensemble_915_2pic.py index d805a4f..7e54253 100644 --- a/NN_server/Models/ensemble_915_2pic.py +++ b/NN_server/Models/ensemble_915_2pic.py @@ -1,197 +1 @@ -from torchvision import models -import torch.nn as nn -import matplotlib -import numpy as np -import torch -import cv2 -import gc -import io - - -def _render_plot(values, figsize=(16, 16), dpi=16): - import matplotlib.pyplot as plt - - fig = plt.figure(figsize=figsize) - plt.axes(ylim=(-1, 1)) - plt.plot(values, color="black") - plt.gca().set_axis_off() - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - plt.margins(0, 0) - - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=dpi) - buf.seek(0) - img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) - buf.close() - - img = cv2.imdecode(img_arr, 1) - if img is None: - raise RuntimeError("failed to decode plot image") - - plt.clf() - plt.cla() - plt.close() - plt.close(fig) - - return np.asarray(cv2.split(img), dtype=np.float32) - - -def pre_func_ensemble(data=None, src="", ind_inference=0): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - real = np.asarray(data[0], dtype=np.float32) - imag = np.asarray(data[1], dtype=np.float32) - signal = real + 1j * imag - - img_real = _render_plot(signal.real) - img_mag = _render_plot(np.abs(signal)) - - cv2.destroyAllWindows() - gc.collect() - - print("Подготовка данных завершена") - print() - return [img_real, img_mag] - - except Exception as exc: - print(str(exc)) - return None - - -def build_func_ensemble(file_model="", file_config="", num_classes=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - torch.cuda.empty_cache() - - num_classes = 2 - model1 = models.resnet18(pretrained=False) - model2 = models.resnet50(pretrained=False) - - model1.fc = nn.Linear(model1.fc.in_features, num_classes) - model2.fc = nn.Linear(model2.fc.in_features, num_classes) - - class Ensemble(nn.Module): - def __init__(self, model1, model2): - super().__init__() - self.model1 = model1 - self.model2 = model2 - self.fc = nn.Linear(2 * num_classes, num_classes) - - def forward(self, x): - if isinstance(x, (list, tuple)): - x1 = x[0] - x2 = x[1] if len(x) > 1 else x[0] - else: - x1 = x - x2 = x - y1 = self.model1(x1) - y2 = self.model2(x2) - y = torch.cat((y1, y2), dim=1) - return self.fc(y) - - model = Ensemble(model1, model2) - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device != "cpu": - model = model.to(device) - model.load_state_dict(torch.load(file_model, map_location=device)) - model.eval() - - cv2.destroyAllWindows() - gc.collect() - - print("Инициализация модели завершена") - print() - return model - - except Exception as exc: - print(str(exc)) - return None - - -def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""): - try: - cv2.destroyAllWindows() - gc.collect() - torch.cuda.empty_cache() - - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(data, (list, tuple)) and len(data) >= 2: - inputs = [ - torch.unsqueeze(torch.tensor(data[0]).cpu(), 0).to(device).float(), - torch.unsqueeze(torch.tensor(data[1]).cpu(), 0).to(device).float(), - ] - else: - tensor = torch.unsqueeze(torch.tensor(data).cpu(), 0).to(device).float() - inputs = [tensor, tensor] - - with torch.no_grad(): - output = model(inputs) - _, predict = torch.max(output.data, 1) - - prediction = mapping[int(np.asarray(predict.cpu())[0])] - print("PREDICTION" + shablon + ": " + prediction) - - output = output.cpu() - label = np.asarray(np.argmax(output, axis=1))[0] - output = np.asarray(torch.squeeze(output, 0)) - expon = np.exp(output - np.max(output)) - probability = round((expon / expon.sum())[label], 2) - - cv2.destroyAllWindows() - gc.collect() - print("Уверенность" + shablon + " в предсказании: " + str(probability)) - print("Инференс завершен") - print() - return [prediction, probability] - - except Exception as exc: - print(str(exc)) - return None - - -def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None): - try: - import matplotlib.pyplot as plt - - matplotlib.use("Agg") - plt.ioff() - - if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2: - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) - plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") - plt.clf() - plt.cla() - plt.close(fig) - cv2.destroyAllWindows() - gc.collect() - - plt.clf() - plt.cla() - plt.close() - cv2.destroyAllWindows() - gc.collect() - - print("Постобработка завершена") - print() - - except Exception as exc: - print(str(exc)) - return None +from Models.ensemble_2_pic import * diff --git a/NN_server/server.py b/NN_server/server.py index ff0842e..1900052 100644 --- a/NN_server/server.py +++ b/NN_server/server.py @@ -3,7 +3,6 @@ from dotenv import dotenv_values from common.runtime import load_root_env, validate_env, as_int, as_str import os import sys -import re import matplotlib.pyplot as plt from Model import Model import numpy as np @@ -11,7 +10,6 @@ import matplotlib import importlib import threading import requests -import asyncio import shutil import json import gc @@ -19,17 +17,11 @@ import logging TORCHSIG_PATH = "/app/torchsig" if TORCHSIG_PATH not in sys.path: - # Ensure import torchsig resolves to /app/torchsig/torchsig package. sys.path.insert(0, TORCHSIG_PATH) logging.basicConfig(level=logging.INFO) app = Flask(__name__) -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -queue = asyncio.Queue() -semaphore = asyncio.Semaphore(3) - prediction_list = [] result_msg = {} results = [] @@ -44,17 +36,13 @@ validate_env("NN_server/server.py", { "GENERAL_SERVER_PORT": as_int, "SERVER_IP": as_str, "SERVER_PORT": as_int, - "PATH_TO_NN": as_str, "SRC_RESULT": as_str, "SRC_EXAMPLE": as_str, + "FREQS": as_str, }) config = dict(dotenv_values(ROOT_ENV)) -def is_model_config_key(key, value): - return bool(re.fullmatch(r"NN_\d+", key or "")) and isinstance(value, str) and " && " in value - - def get_required_drone_streak(freq): return config.get(f"DRONE_STREAK_{freq}", "1") @@ -88,52 +76,133 @@ def update_drone_streak(freq, prediction, drone_probability): return 8 if triggered else 0 +def parse_freqs(raw_value): + freqs = [] + for item in (raw_value or "").split(','): + item = item.strip() + if not item: + continue + freqs.append(int(item)) + if not freqs: + raise RuntimeError("[NN_server/server.py] no NN frequencies configured in FREQS") + return freqs + + +def parse_classes(raw_value): + if raw_value is None: + raise RuntimeError("[NN_server/server.py] model classes are missing") + value = raw_value.strip() + if value.startswith('[') and value.endswith(']'): + value = value[1:-1] + classes = {} + for class_name in value.split(','): + class_name = class_name.strip() + if class_name: + classes[len(classes)] = class_name + if not classes: + raise RuntimeError("[NN_server/server.py] no classes parsed from NN_CLASSES_*") + return classes + + +def get_required_config(key): + value = config.get(key) + if value is None: + raise RuntimeError(f"[NN_server/server.py] missing required env key: {key}") + value = str(value).strip() + if not value: + raise RuntimeError(f"[NN_server/server.py] empty required env key: {key}") + return value + + +def get_optional_config(key, default=''): + value = config.get(key) + if value is None: + return default + return str(value).strip() + + +def build_model_specs(): + build_func_name = get_optional_config('NN_BUILD_FUNC', 'build_func_ensemble') + pre_func_name = get_optional_config('NN_PRE_FUNC', 'pre_func_ensemble') + inference_func_name = get_optional_config('NN_INFERENCE_FUNC', 'inference_func_ensemble') + post_func_name = get_optional_config('NN_POST_FUNC', 'post_func_ensemble') + src_example = get_optional_config('NN_SRC_EXAMPLE', config['SRC_EXAMPLE']) + src_result = get_optional_config('NN_SRC_RESULT', config['SRC_RESULT']) + synthetic_examples = int(get_optional_config('NN_SYNTHETIC_EXAMPLES', '0')) + synthetic_mix_count = int(get_optional_config('NN_SYNTHETIC_MIX_COUNT', '1')) + src_dataset = get_optional_config('NN_SRC_DATASET', '') + + specs = [] + for freq in parse_freqs(config.get('NN_FREQS', config.get('FREQS', ''))): + module_name = get_required_config(f'NN_MODEL_{freq}') + weights = get_required_config(f'NN_WEIGHTS_{freq}') + classes = parse_classes(get_required_config(f'NN_CLASSES_{freq}')) + file_config = get_optional_config(f'NN_CONFIG_{freq}', get_optional_config('NN_CONFIG', '')) + specs.append({ + 'freq': freq, + 'module_name': module_name, + 'weights': weights, + 'config': file_config, + 'classes': classes, + 'src_example': src_example, + 'src_result': src_result, + 'build_func_name': build_func_name, + 'pre_func_name': pre_func_name, + 'inference_func_name': inference_func_name, + 'post_func_name': post_func_name, + 'synthetic_examples': synthetic_examples, + 'synthetic_mix_count': synthetic_mix_count, + 'src_dataset': src_dataset, + }) + return specs + + if not config: raise RuntimeError("[NN_server/server.py] .env was loaded but no keys were parsed") -if not any(is_model_config_key(key, value) for key, value in config.items()): - raise RuntimeError("[NN_server/server.py] no NN_* model entries configured") logging.info("NN config loaded from %s", ROOT_ENV) gen_server_ip = config['GENERAL_SERVER_IP'] gen_server_port = config['GENERAL_SERVER_PORT'] drone_streaks = {} +MODEL_SPECS = build_model_specs() + + +def recreate_directory(path): + if os.path.isdir(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + def init_data_for_inference(): try: - if os.path.isdir(config['SRC_RESULT']): - shutil.rmtree(config['SRC_RESULT']) - os.mkdir(config['SRC_RESULT']) - if os.path.isdir(config['SRC_EXAMPLE']): - shutil.rmtree(config['SRC_EXAMPLE']) - os.mkdir(config['SRC_EXAMPLE']) + if MODEL_SPECS: + recreate_directory(MODEL_SPECS[0]['src_result']) + recreate_directory(MODEL_SPECS[0]['src_example']) except Exception as exc: print(str(exc)) print() try: global model_list - for key, value in config.items(): - if is_model_config_key(key, value): - params = value.split(' && ') - module = importlib.import_module('Models.' + params[4]) - classes = {} - for value in params[9][1:-1].split(','): - classes[len(classes)] = value - model = Model(file_model=params[0], file_config=params[1], src_example=params[2], src_result=params[3], - type_model=params[4], build_model_func=getattr(module, params[5]), - pre_func=getattr(module, params[6]), inference_func=getattr(module, params[7]), - post_func=getattr(module, params[8]), classes=classes, number_synthetic_examples=int(params[10]), - number_src_data_for_one_synthetic_example=int(params[11]), path_to_src_dataset=params[12]) - model_list.append(model) - # if key.startswith('ALG_'): - # params = config[key].split(' && ') - # module = importlib.import_module('Algorithms.' + params[2]) - # classes = {} - # for value in params[6][1:-1].split(','): - # classes[len(classes)] = value - # alg = Algorithm(src_example=params[0], src_result=params[1], type_alg=params[2], pre_func=getattr(module, params[3]), - # inference_func=getattr(module, params[4]), post_func=getattr(module, params[5]), classes=classes, - # number_synthetic_examples=int(params[7]), number_src_data_for_one_synthetic_example=int(params[8]), path_to_src_dataset=params[9]) - # alg_list.append(alg) + model_list.clear() + for spec in MODEL_SPECS: + module = importlib.import_module('Models.' + spec['module_name']) + model = Model( + freq=spec['freq'], + file_model=spec['weights'], + file_config=spec['config'], + src_example=spec['src_example'], + src_result=spec['src_result'], + type_model=f"{spec['module_name']}@{spec['freq']}", + build_model_func=getattr(module, spec['build_func_name']), + pre_func=getattr(module, spec['pre_func_name']), + inference_func=getattr(module, spec['inference_func_name']), + post_func=getattr(module, spec['post_func_name']), + classes=spec['classes'], + number_synthetic_examples=spec['synthetic_examples'], + number_src_data_for_one_synthetic_example=spec['synthetic_mix_count'], + path_to_src_dataset=spec['src_dataset'], + ) + model_list.append(model) except Exception as exc: print(str(exc)) print() @@ -147,6 +216,13 @@ def run_example(): print(str(exc)) +def find_model_for_freq(freq): + for model in model_list: + if model.get_freq() == freq: + return model + return None + + @app.route('/receive_data', methods=['POST']) def receive_data(): try: @@ -156,52 +232,52 @@ def receive_data(): print('Получен пакет ' + str(Model.get_ind_inference())) freq = int(data['freq']) print('Частота: ' + str(freq)) - # print('Канал: ' + str(data['channel'])) result_msg = {} data_to_send = {} prediction_list = [] - #print(model_list) - for model in model_list: - #print(str(freq)) - #print(model.get_model_name()) - if str(freq) in model.get_model_name(): - print('-' * 100) - print(str(model)) - result_msg[str(model.get_model_name())] = {'freq': freq} - inference_result = model.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)]) - if inference_result is None: - raise RuntimeError(f"Inference failed for {model.get_model_name()}") - prediction, probability = inference_result[:2] - drone_probability = float(probability) if prediction == "drone" else 0.0 - result_msg[str(model.get_model_name())]['prediction'] = prediction - result_msg[str(model.get_model_name())]['probability'] = str(probability) - result_msg[str(model.get_model_name())]['drone_probability'] = str(drone_probability) - result_msg[str(model.get_model_name())]['drone_threshold'] = str(get_required_drone_prob(freq)) - prediction_list.append(prediction) - print('-' * 100) - print() - - try: - result = update_drone_streak(freq, prediction, drone_probability) - data_to_send={ - 'freq': str(freq), - #'channel': int(data['channel']), - 'amplitude': result - #'triggered': False if result < 7 else True, - #'light_len': result - } - response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send) - if response.status_code == 200: - print("Данные успешно отправлены!") - print("Частота: " + str(freq)) - print("Отправлено светодиодов: " + str(result)) - else: - print("Ошибка при отправке данных: ", response.status_code) - except Exception as exc: - print(str(exc)) - break - + model = find_model_for_freq(freq) + if model is None: + raise RuntimeError(f"No NN model configured for freq={freq}") + + print('-' * 100) + print(str(model)) + result_msg[str(model.get_model_name())] = {'freq': freq} + inference_result = model.get_inference([ + np.asarray(data['data_real'], dtype=np.float32), + np.asarray(data['data_imag'], dtype=np.float32), + ]) + if inference_result is None: + raise RuntimeError(f"Inference failed for {model.get_model_name()}") + prediction, probability = inference_result[:2] + drone_probability = float(probability) if prediction == "drone" else 0.0 + result_msg[str(model.get_model_name())]['prediction'] = prediction + result_msg[str(model.get_model_name())]['probability'] = str(probability) + result_msg[str(model.get_model_name())]['drone_probability'] = str(drone_probability) + result_msg[str(model.get_model_name())]['drone_threshold'] = str(get_required_drone_prob(freq)) + prediction_list.append(prediction) + print('-' * 100) + print() + + try: + result = update_drone_streak(freq, prediction, drone_probability) + data_to_send = { + 'freq': str(freq), + 'amplitude': result, + } + response = requests.post( + "http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), + json=data_to_send, + ) + if response.status_code == 200: + print("Данные успешно отправлены!") + print("Частота: " + str(freq)) + print("Отправлено светодиодов: " + str(result)) + else: + print("Ошибка при отправке данных: ", response.status_code) + except Exception as exc: + print(str(exc)) + Model.get_inc_ind_inference() print() print('#' * 100) @@ -209,14 +285,16 @@ def receive_data(): for alg in alg_list: print('-' * 100) print(str(alg)) - alg.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)]) + alg.get_inference([ + np.asarray(data['data_real'], dtype=np.float32), + np.asarray(data['data_imag'], dtype=np.float32), + ]) print('-' * 100) print() - #Algorithm.get_inc_ind_inference() print() print('#' * 100) - + del data gc.collect() @@ -226,128 +304,6 @@ def receive_data(): print(str(exc)) -''' -def run_flask(): - app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT'])) - - -async def process_tasks(): - workers = [asyncio.create_task(worker(queue=queue, semaphore=semaphore)) for _ in range(2)] - await asyncio.gather(*workers) - - -async def main(): - asyncio.create_task(process_tasks()) - - flask_thread = threading.Thread(target=run_flask) - flask_thread.start() - - while True: - if queue.qsize() <= 1: - asyncio.create_task(process_tasks()) - await asyncio.sleep(1) - - -@app.route('/receive_data', methods=['POST']) -def add_task(): - queue_size = queue.qsize() - if queue_size > 1: - return {} - - print() - data = json.loads(request.json) - print('#' * 100) - print('Получен пакет ' + str(Model.get_ind_inference())) - freq = int(data['freq']) - print('Частота ' + str(freq)) - - - result_msg = {} - for model in model_list: - if str(freq) in model.get_model_name(): - print('-' * 100) - print(str(model)) - result_msg[str(model.get_model_name())] = {'freq': freq} - asyncio.run_coroutine_threadsafe(queue.put({'freq': freq, 'model': model, 'data': data}), loop) - do_inference(model=model, data=data, freq=freq) - break - - del data - gc.collect() - return jsonify(result_msg) - - -async def worker(queue, semaphore): - while True: - task = await queue.get() - if task is None: - break - async with semaphore: - try: - await do_inference(model=task['model'], data=task['data'], freq=task['freq']) - except Exception as e: - print(str(e)) - print(results) - queue.task_done() - - -async def do_inference(model=None, data=None, freq=0): - prediction_list = [] - print("Длина очереди" + str(queue.qsize())) - inference(model=model, data=data, freq=freq) - - try: - results = [] - for pred in prediction_list: - if pred[1] == 'drone': - results.append([pred[0],8]) - else: - results.append([pred[0],0]) - for result in results: - try: - data_to_send={ - 'freq': result[0], - 'amplitude': result[1], - 'triggered': False if result[1] < 7 else True, - 'light_len': result[1] - } - response = requests.post("http://{0}:{1}/process_data".format(gen_server_ip, gen_server_port), json=data_to_send) - await response.text - if response.status_code == 200: - print("Данные успешно отправлены!") - print("Отправлено светодиодов: " + str(data_to_send['light_len'])) - else: - print("Ошибка при отправке данных: ", response.status_code) - except Exception as exc: - print(str(exc)) - except Exception as exc: - print(str(exc)) - - Model.get_inc_ind_inference() - print() - print('#' * 100) - - del data - gc.collect() - - -def inference(model=None, data=None, freq=0): - prediction, probability = model.get_inference([np.asarray(data['data_real'], dtype=np.float32), np.asarray(data['data_imag'], dtype=np.float32)]) - result_msg[str(model.get_model_name())]['prediction'] = prediction - result_msg[str(model.get_model_name())]['probability'] = str(probability) - queue_size = queue.qsize() - print(queue_size) - prediction_list.append([freq, prediction]) - print('-' * 100) - print() - - -if __name__ == '__main__': - init_data_for_inference() - #asyncio.run(main) - loop.run_until_complete(main()) -''' - def run_flask(): print(config['SERVER_IP']) app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT'])) @@ -358,5 +314,3 @@ if __name__ == '__main__': flask_thread = threading.Thread(target=run_flask) flask_thread.start() - - #app.run(host=config['SERVER_IP'], port=int(config['SERVER_PORT']))