You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
DroneDetector/NN_server/Models/ensemble_915_v44.py

250 lines
7.3 KiB
Python

from torchvision import models
import torch.nn as nn
import matplotlib
import numpy as np
import torch
import cv2
import gc
import io
import os
import re
def _as_training_colormap_image(image):
    """Collapse ``image`` to a single 2-D float32 plane for colormap display.

    A 3-D array whose first axis has length 1, 3 or 4 is treated as
    channel-first and moved to channel-last.

    * With three or more channels, the 0.299 / 0.587 / 0.114 luma weights
      are applied to channels 0, 1 and 2, in that order. Any extra channel,
      e.g. alpha, is ignored.
    * With fewer channels, the available channels are averaged; a single
      channel is returned unchanged.

    NaN maps to 0, +inf to 255 and -inf to 0.

    NOTE(review): images produced with OpenCV (``cv2.split``) are in BGR
    order, so channel 0 is blue here. Confirm this matches how the training
    images were made.

    Args:
        image: array-like, 2-D ``(H, W)`` or 3-D channel-first/-last.

    Returns:
        ``np.float32`` array with the channel axis removed.
    """
    arr = np.asarray(image)
    if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
        arr = np.moveaxis(arr, 0, -1)
    if arr.ndim == 3:
        channels = arr.astype(np.float32)
        if channels.shape[-1] >= 3:
            arr = (
                0.299 * channels[..., 0]
                + 0.587 * channels[..., 1]
                + 0.114 * channels[..., 2]
            )
        else:
            # 1- or 2-channel input: the luma formula would index a channel
            # that does not exist (IndexError), so average what is there.
            arr = channels.mean(axis=-1)
    return np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=255.0, neginf=0.0)
def _prune_old_inference_images(src, model_type, model_id, keep_last=200):
    """Delete all but the newest ``keep_last`` inference-image groups of one model.

    Inference images are named by string concatenation onto ``src``, i.e.
    ``src + "_inference_<n>_<...>_<model_id>_<model_type>.png``. ``src`` is
    therefore a path *prefix*:

    * the files live in ``os.path.dirname(src)``;
    * their names start with ``os.path.basename(src)``.

    With a trailing separator (``"dir/"``), the files are inside ``dir`` and
    have no extra name prefix. Files are grouped by the integer ``<n>``, and
    the groups with the smallest ``<n>`` are deleted first.

    Args:
        src: the same prefix the images were saved with. Empty means do
            nothing.
        model_type: only files ending in ``_<model_id>_<model_type>.png`` are
            considered.
        model_id: see ``model_type``.
        keep_last: number of ``<n>`` groups to keep. The environment variable
            ``INFERENCE_IMAGE_KEEP_LAST`` overrides it when it parses as an
            int. A value ``<= 0`` disables pruning.
    """
    try:
        keep_last = int(os.getenv("INFERENCE_IMAGE_KEEP_LAST", str(keep_last)))
    except ValueError:
        pass  # malformed override: keep the argument's value
    if keep_last <= 0 or not src:
        return
    # Where the files actually land, given src + "_inference_..." naming.
    directory, name_prefix = os.path.split(src)
    directory = directory or os.curdir
    if not os.path.isdir(directory):
        return
    pattern = re.compile(
        re.escape(name_prefix)
        + r"_inference_(\d+)_.*_"
        + re.escape(str(model_id))
        + "_"
        + re.escape(str(model_type))
        + r"\.png"
    )
    grouped = {}
    for name in os.listdir(directory):
        match = pattern.fullmatch(name)
        if match is not None:
            grouped.setdefault(int(match.group(1)), []).append(name)
    excess = len(grouped) - keep_last
    if excess <= 0:
        return
    for old_result_id in sorted(grouped)[:excess]:
        for name in grouped[old_result_id]:
            try:
                os.remove(os.path.join(directory, name))
            except FileNotFoundError:
                pass  # removed concurrently; nothing left to do
            except OSError as exc:
                print(f"failed to remove old inference image {name}: {exc}")
def _render_plot(values, figsize=(16, 16), dpi=16):
    """Rasterise ``values`` as a black line plot and return it channel-first.

    The plot has a fixed y-range of [-1, 1], hidden axes and zero margins.
    It is encoded to an in-memory PNG and decoded back with OpenCV.

    Args:
        values: sequence plotted against its index.
        figsize: figure size in inches.
        dpi: rendering resolution. The image is ``figsize * dpi`` pixels,
            i.e. 256x256 with the defaults.

    Returns:
        ``np.float32`` array of shape ``(channels, height, width)`` built
        from ``cv2.split``, so the channels are in OpenCV's BGR order.

    Raises:
        RuntimeError: if the rendered PNG cannot be decoded.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        plt.axes(ylim=(-1, 1))
        plt.plot(values, color="black")
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            png_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        img = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
    finally:
        # pyplot keeps every figure alive until it is closed. Closing here
        # means a failing savefig/imdecode does not leak one figure per call
        # in a long-running server.
        plt.close(fig)
    if img is None:
        raise RuntimeError("failed to decode plot image")
    return np.asarray(cv2.split(img), dtype=np.float32)
def pre_func_ensemble(data=None, src="", ind_inference=0):
    """Turn one complex sample into the pair of plot images the ensemble consumes.

    ``data[0]`` and ``data[1]`` are combined into a complex signal as its
    real and imaginary parts. Two line plots are rendered via
    ``_render_plot``: the real part and the magnitude.

    Args:
        data: indexable pair ``(real_part, imag_part)`` of array-likes.
        src: unused here; kept in the signature.
        ind_inference: unused here; kept in the signature.

    Returns:
        ``[real_plot, magnitude_plot]``, or ``None`` if anything fails. On
        failure the exception message is printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()

        re_part = np.asarray(data[0], dtype=np.float32)
        im_part = np.asarray(data[1], dtype=np.float32)
        complex_signal = re_part + 1j * im_part

        rendered = [
            _render_plot(component)
            for component in (complex_signal.real, np.abs(complex_signal))
        ]

        cv2.destroyAllWindows()
        gc.collect()
        print("Подготовка данных завершена")
        print()
        return rendered
    except Exception as exc:
        print(str(exc))
        return None
def build_func_ensemble(file_model="", file_config="", num_classes=None):
    """Construct the ResNet-18 + ResNet-50 ensemble and load its weights.

    Both backbones are created with ``pretrained=False``. Their ``fc``
    layers are replaced so each emits ``num_classes`` logits, and
    ``nn.Linear(2 * num_classes, num_classes)`` fuses the concatenation.

    The attribute names ``model1``, ``model2`` and ``fc`` form the state-dict
    keys that ``load_state_dict`` matches against the checkpoint, so they
    must not be renamed.

    Args:
        file_model: path to a ``state_dict`` readable by ``torch.load``.
        file_config: unused by this builder; kept in the signature.
        num_classes: number of output classes. ``None`` (the default) keeps
            the previous fixed value of 2. Any other value is used as given
            and must match the checkpoint.

    Returns:
        The model in eval mode, on CUDA when available and otherwise on CPU.
        Returns ``None`` if any step fails; the exception message is printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()
        torch.cuda.empty_cache()

        # Honour the caller's value; previously it was always forced to 2.
        n_classes = 2 if num_classes is None else int(num_classes)

        backbone_small = models.resnet18(pretrained=False)
        backbone_large = models.resnet50(pretrained=False)
        backbone_small.fc = nn.Linear(backbone_small.fc.in_features, n_classes)
        backbone_large.fc = nn.Linear(backbone_large.fc.in_features, n_classes)

        class Ensemble(nn.Module):
            """Run two backbones and fuse their concatenated logits linearly."""

            def __init__(self, model1, model2):
                super().__init__()
                self.model1 = model1
                self.model2 = model2
                self.fc = nn.Linear(2 * n_classes, n_classes)

            def forward(self, x):
                # A list/tuple gives one input per backbone; the second
                # defaults to the first. A bare tensor is fed to both.
                if isinstance(x, (list, tuple)):
                    x1 = x[0]
                    x2 = x[1] if len(x) > 1 else x[0]
                else:
                    x1 = x2 = x
                y = torch.cat((self.model1(x1), self.model2(x2)), dim=1)
                return self.fc(y)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Ensemble(backbone_small, backbone_large).to(device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True where the installed torch
        # supports it.
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()
        cv2.destroyAllWindows()
        gc.collect()
        print("Инициализация модели завершена")
        print()
        return model
    except Exception as exc:
        print(str(exc))
        return None
def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""):
    """Classify one preprocessed sample with the ensemble.

    Args:
        data: either a pair ``[first_input, second_input]``, one per backbone,
            or a single array-like that is fed to both. A leading batch
            dimension of 1 is added to each input.
        model: callable module accepting a list of two tensors, such as the
            one returned by ``build_func_ensemble``.
        mapping: indexable mapping from class index to label.
        shablon: text inserted into the printed log lines.

    Returns:
        ``[prediction, probability]``. ``probability`` is the softmax score of
        the predicted class as a Python float rounded to 2 decimals. Returns
        ``None`` if anything fails; the exception message is printed.
    """
    try:
        cv2.destroyAllWindows()
        gc.collect()
        torch.cuda.empty_cache()

        # Send inputs to wherever the model's weights live. Guessing from
        # torch.cuda.is_available() breaks for a model kept on CPU.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def to_batch(array):
            # float32 tensor with a leading batch dimension of 1.
            return torch.as_tensor(array, dtype=torch.float32).unsqueeze(0).to(device)

        if isinstance(data, (list, tuple)) and len(data) >= 2:
            inputs = [to_batch(data[0]), to_batch(data[1])]
        else:
            shared = to_batch(data)
            inputs = [shared, shared]

        with torch.no_grad():
            logits = model(inputs).cpu()

        # A single argmax drives both the label and its confidence, so the
        # two can never disagree (e.g. on ties).
        scores = torch.softmax(logits[0].float(), dim=0)
        index = int(torch.argmax(scores))
        prediction = mapping[index]
        print("PREDICTION" + shablon + ": " + str(prediction))
        probability = round(float(scores[index]), 2)

        cv2.destroyAllWindows()
        gc.collect()
        print("Уверенность" + shablon + " в предсказании: " + str(probability))
        print("Инференс завершен")
        print()
        return [prediction, probability]
    except Exception as exc:
        print(str(exc))
        return None
def _save_colormap_png(image, path):
    """Save ``image``, collapsed to one plane, as a plasma-colormapped PNG at ``path``."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    try:
        ax.imshow(_as_training_colormap_image(image), cmap="plasma")
        fig.savefig(path)
    finally:
        # Close even when imshow/savefig raises; otherwise pyplot keeps the
        # figure alive for the lifetime of the process.
        plt.close(fig)


def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None):
    """Save colormapped images of both ensemble inputs, then prune old ones.

    When ``data`` is a list/tuple with at least two items, this writes

    * ``<src>_inference_<ind_inference>_<prediction>_real_<model_id>_<model_type>.png``
      for ``data[0]``;
    * the same name with ``_mod_`` instead of ``_real_`` for ``data[1]``.

    It then keeps only the newest inference-image groups for this
    ``model_id``/``model_type`` (see ``_prune_old_inference_images``).
    ``src`` is used as a plain string prefix, not joined as a directory.

    Returns:
        Always ``None``. Exceptions are printed, not raised.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()
        if isinstance(data, (list, tuple)) and len(data) >= 2:
            for image, tag in ((data[0], "real"), (data[1], "mod")):
                path = f"{src}_inference_{ind_inference}_{prediction}_{tag}_{model_id}_{model_type}.png"
                _save_colormap_png(image, path)
                cv2.destroyAllWindows()
                gc.collect()
            _prune_old_inference_images(src, model_type, model_id)
        # Close whatever pyplot still treats as current. This is a no-op when
        # nothing is open; clf()/cla() would first create a throwaway figure.
        plt.close()
        cv2.destroyAllWindows()
        gc.collect()
        print("Постобработка завершена")
        print()
    except Exception as exc:
        print(str(exc))
        return None