"""Signal-to-image preprocessing, ResNet ensemble loading, inference and
visualisation of the inference inputs."""
from torchvision import models
import torch.nn as nn
import matplotlib
import numpy as np
import torch
import cv2
import gc
import io
import os
import re


def _pyplot():
    """Import pyplot, select the non-interactive Agg backend and return it."""
    import matplotlib.pyplot as plt

    matplotlib.use("Agg")
    plt.ioff()
    return plt


def _release_resources():
    """Best-effort cleanup run at the end of each pipeline stage.

    ``cv2.destroyAllWindows`` can raise ``cv2.error`` on OpenCV builds that
    lack GUI support. This module never opens an OpenCV window, so that
    failure must not abort an otherwise successful stage.
    """
    try:
        cv2.destroyAllWindows()
    except cv2.error:
        pass
    gc.collect()


def _figure_to_bgr(fig, error_message, **savefig_kwargs):
    """Save *fig* to an in-memory PNG and decode it with OpenCV (color, BGR).

    Extra keyword arguments are forwarded to ``fig.savefig``.

    Raises:
        RuntimeError: with *error_message* if OpenCV cannot decode the PNG.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", **savefig_kwargs)
        png = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    img = cv2.imdecode(png, cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError(error_message)
    return img


def _as_training_colormap_image(image):
    """Collapse *image* to a 2-D float32 intensity map for colormap display.

    Channel-first 3-D arrays with 1, 3 or 4 leading channels are moved to
    channel-last. Arrays with at least three channels are reduced with the
    0.299/0.587/0.114 weights applied to channels 0/1/2. Arrays with fewer
    channels use their first channel. NaN maps to 0, +inf to 255 and
    -inf to 0.
    """
    arr = np.asarray(image)
    if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
        arr = np.moveaxis(arr, 0, -1)
    if arr.ndim == 3:
        arr = arr.astype(np.float32)
        if arr.shape[-1] >= 3:
            # NOTE(review): these weights assume RGB channel order, but the
            # arrays built by _render_training_png come from cv2.split of a
            # cv2.imdecode result (BGR). Confirm against the training-time
            # visualisation before changing.
            arr = 0.299 * arr[..., 0] + 0.587 * arr[..., 1] + 0.114 * arr[..., 2]
        else:
            # Gray or gray+alpha input: channels 1/2 do not exist.
            arr = arr[..., 0]
    return np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=255.0, neginf=0.0)


def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)):
    """Plot a 1-D signal as a black line and return it as a grayscale image.

    The y-range is fixed to (-1, 1), so values outside it leave the visible
    area. Axes and margins are removed. The figure is rendered to PNG at
    *dpi* and converted to grayscale. Unless *resize* is None, the result
    is resized with ``cv2.resize(img, resize)``, which takes
    ``(width, height)``.

    Raises:
        RuntimeError: if the rendered PNG cannot be decoded.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        plt.axes(ylim=(-1, 1))
        plt.plot(values, color="black")
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        img = _figure_to_bgr(fig, "failed to decode plot image", dpi=dpi)
    finally:
        # pyplot keeps every open figure alive; close it even on failure.
        plt.close(fig)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if resize is not None:
        img = cv2.resize(img, resize)
    return img


def _render_training_png(image):
    """Render *image* with ``plt.imshow`` and return the PNG's channel planes.

    The figure uses pyplot defaults (size, colormap, visible axes). The name
    suggests this mirrors the training-time rendering, so keep those
    defaults as they are. Returns a float32 array stacking the planes of
    ``cv2.split`` on the decoded color image, in OpenCV's BGR order.

    Raises:
        RuntimeError: if the rendered PNG cannot be decoded.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure()
    try:
        plt.imshow(image)
        img = _figure_to_bgr(fig, "failed to decode training-style image")
    finally:
        plt.close(fig)
    return np.asarray(cv2.split(img), dtype=np.float32)


def _prune_old_inference_images(src, model_type, model_id, keep_last=200):
    """Delete saved inference images of all but the newest *keep_last* runs.

    post_func_ensemble writes
    ``src + "_inference_<id>_<prediction>_<real|mod>_<model_id>_<model_type>.png"``
    by plain string concatenation. Therefore *src* is either a directory
    ending in a path separator or a file-name prefix.

    Matching files are grouped by ``<id>``, and the groups with the smallest
    ids are removed until *keep_last* groups remain. The
    ``INFERENCE_IMAGE_KEEP_LAST`` environment variable overrides *keep_last*
    when it parses as an int. A value <= 0 or an empty *src* disables
    pruning. Failed removals are printed and skipped.
    """
    try:
        keep_last = int(os.getenv("INFERENCE_IMAGE_KEEP_LAST", str(keep_last)))
    except ValueError:
        pass  # malformed env value: keep the argument
    if keep_last <= 0 or not src:
        return
    # Split the path the same way post_func_ensemble builds it, so a src
    # without a trailing separator (a file-name prefix) is found as well.
    directory, name_prefix = os.path.split(src + "_inference_")
    directory = directory or os.curdir
    if not os.path.isdir(directory):
        return
    pattern = re.compile(
        re.escape(name_prefix)
        + r"(\d+)_.*_"
        + re.escape(str(model_id))
        + "_"
        + re.escape(str(model_type))
        + r"\.png$"
    )
    grouped = {}
    for name in os.listdir(directory):
        match = pattern.match(name)
        if match is not None:
            grouped.setdefault(int(match.group(1)), []).append(name)
    excess = len(grouped) - keep_last
    if excess <= 0:
        return
    for old_result_id in sorted(grouped)[:excess]:
        for name in grouped[old_result_id]:
            try:
                os.remove(os.path.join(directory, name))
            except FileNotFoundError:
                pass  # already gone
            except OSError as exc:
                print(f"failed to remove old inference image {name}: {exc}")


def pre_func_ensemble(data=None, src="", ind_inference=0):
    """Turn a (real, imag) signal pair into the two model input images.

    ``data[0]`` and ``data[1]`` are the real and imaginary parts. Returns
    ``[img_real, img_mag]``: the real part and the magnitude
    ``|real + 1j*imag|``, each passed through _render_signal_channel and
    then _render_training_png. *src* and *ind_inference* are accepted for
    interface compatibility and are unused. On any error, the message is
    printed and None is returned.
    """
    try:
        _pyplot()
        real = np.asarray(data[0], dtype=np.float32)
        imag = np.asarray(data[1], dtype=np.float32)
        magnitude = np.abs(real + 1j * imag)
        img_real = _render_training_png(_render_signal_channel(real))
        img_mag = _render_training_png(_render_signal_channel(magnitude))
        _release_resources()
        print("Подготовка данных завершена")
        print()
        return [img_real, img_mag]
    except Exception as exc:
        print(str(exc))
        return None


class _Ensemble(nn.Module):
    """Late-fusion ensemble: concatenates two classifiers' logits and mixes
    them with a final ``Linear(2 * num_classes, num_classes)`` layer."""

    def __init__(self, model1, model2, num_classes):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.model1 = model1
        self.model2 = model2
        self.fc = nn.Linear(2 * num_classes, num_classes)

    def forward(self, x):
        # A list/tuple feeds x[0] to model1 and x[1] (x[0] again if only one
        # item is given) to model2; a single tensor feeds both branches.
        if isinstance(x, (list, tuple)):
            x1 = x[0]
            x2 = x[1] if len(x) > 1 else x[0]
        else:
            x1 = x2 = x
        y = torch.cat((self.model1(x1), self.model2(x2)), dim=1)
        return self.fc(y)


def _untrained_resnet(factory, num_classes):
    """Build a randomly initialised torchvision ResNet with a *num_classes* head.

    torchvision >= 0.13 takes ``weights=None``. Older releases reject that
    keyword with TypeError and only understand ``pretrained=False``.
    """
    try:
        net = factory(weights=None)
    except TypeError:
        net = factory(pretrained=False)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def build_func_ensemble(file_model="", file_config="", num_classes=None):
    """Create the ResNet-18 + ResNet-50 ensemble and load *file_model* into it.

    *num_classes* defaults to 2 when None. *file_config* is accepted for
    interface compatibility and is unused. The model is placed on CUDA when
    available and switched to eval mode before being returned. On any error,
    the message is printed and None is returned.
    """
    try:
        _pyplot()
        torch.cuda.empty_cache()
        num_classes = 2 if num_classes is None else int(num_classes)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = _Ensemble(
            _untrained_resnet(models.resnet18, num_classes),
            _untrained_resnet(models.resnet50, num_classes),
            num_classes,
        ).to(device)
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from trusted locations, or pass weights_only=True where supported.
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()
        _release_resources()
        print("Инициализация модели завершена")
        print()
        return model
    except Exception as exc:
        print(str(exc))
        return None


def _to_batch(x, device):
    """Return *x* as a float32 tensor with a leading batch axis on *device*."""
    if isinstance(x, torch.Tensor):
        tensor = x.detach()
    else:
        # ascontiguousarray: torch cannot wrap arrays with negative strides.
        tensor = torch.tensor(np.ascontiguousarray(x))
    return tensor.unsqueeze(0).to(device=device, dtype=torch.float32)


def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""):
    """Classify one sample with a model built by build_func_ensemble.

    *data* is either a sequence whose first two items feed model1 and model2
    (e.g. the output of pre_func_ensemble), or a single input used for both
    branches. A batch axis is added. ``mapping[class_index]`` gives the
    returned prediction, and *shablon* is inserted into the log lines.

    Returns ``[prediction, probability]``, where *probability* is the softmax
    probability of the predicted class rounded to two decimals (a Python
    float). On any error, the message is printed and None is returned.
    """
    try:
        _release_resources()
        torch.cuda.empty_cache()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(data, (list, tuple)) and len(data) >= 2:
            inputs = [_to_batch(data[0], device), _to_batch(data[1], device)]
        else:
            tensor = _to_batch(data, device)
            inputs = [tensor, tensor]
        with torch.no_grad():
            output = model(inputs)
        logits = output.detach().cpu().numpy()[0].astype(np.float64)
        label = int(np.argmax(logits))
        prediction = mapping[label]
        print("PREDICTION" + shablon + ": " + str(prediction))
        # Numerically stable softmax of the single sample's logits.
        expon = np.exp(logits - logits.max())
        probability = round(float(expon[label] / expon.sum()), 2)
        _release_resources()
        print("Уверенность" + shablon + " в предсказании: " + str(probability))
        print("Инференс завершен")
        print()
        return [prediction, probability]
    except Exception as exc:
        print(str(exc))
        return None


def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None):
    """Save viridis previews of the two model inputs, then prune old previews.

    When *data* holds at least two images, ``data[0]`` and ``data[1]`` are
    saved to
    ``src + "_inference_<ind_inference>_<prediction>_<real|mod>_<model_id>_<model_type>.png"``.
    _prune_old_inference_images then applies the retention limit. Always
    returns None; errors are printed.
    """
    try:
        plt = _pyplot()
        if isinstance(data, (list, tuple)) and len(data) >= 2:
            for channel, image in (("real", data[0]), ("mod", data[1])):
                path = (
                    f"{src}_inference_{ind_inference}_{prediction}"
                    f"_{channel}_{model_id}_{model_type}.png"
                )
                fig, ax = plt.subplots()
                try:
                    ax.imshow(_as_training_colormap_image(image), cmap="viridis")
                    fig.savefig(path)
                finally:
                    plt.close(fig)
                _release_resources()
        _prune_old_inference_images(src, model_type, model_id)
        _release_resources()
        print("Постобработка завершена")
        print()
    except Exception as exc:
        print(str(exc))
    return None