You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
7.5 KiB
Python
266 lines
7.5 KiB
Python
from torchvision import models
|
|
import torch.nn as nn
|
|
import matplotlib
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
import gc
|
|
import io
|
|
import os
|
|
import re
|
|
|
|
|
|
def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)):
    """Render a 1-D signal as a grayscale line-plot image.

    The samples are drawn as a black line on an axis-free Matplotlib figure
    whose y-range is fixed to [-1, 1] (parts of the signal outside that range
    fall outside the plot). The figure is rasterised to PNG in memory, decoded
    with OpenCV and converted to a single grayscale channel.

    Args:
        values: 1-D sequence of samples to plot.
        figsize: Matplotlib figure size in inches.
        dpi: Resolution used when saving; with the defaults (16 in x 16 dpi)
            the rasterised plot is 256x256 pixels.
        resize: ``(width, height)`` passed to ``cv2.resize``; ``None`` keeps
            the rasterised size.

    Returns:
        A 2-D ``uint8`` grayscale image.

    Raises:
        RuntimeError: If OpenCV cannot decode the rendered PNG.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        plt.axes(ylim=(-1, 1))
        plt.plot(values, color="black")
        plt.gca().set_axis_off()
        # Stretch the axes over the whole canvas so the image holds only the curve.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            png_bytes = buf.getvalue()
    finally:
        # Release the figure even when plotting/saving fails so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    img = cv2.imdecode(np.frombuffer(png_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError("failed to decode plot image")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if resize is not None:
        img = cv2.resize(img, resize)
    return img
|
|
|
|
|
|
def _render_training_png(image):
    """Re-render an image through a default Matplotlib ``imshow`` figure.

    The image is shown on a figure with Matplotlib's default size, dpi and
    decorations (axes and ticks are not hidden). A 2-D input is therefore
    colour-mapped with the default colormap. The figure is saved to PNG in
    memory and decoded with OpenCV. Presumably this mirrors how the training
    images were produced; verify against the training pipeline.

    Args:
        image: Array accepted by ``plt.imshow`` (e.g. the 2-D grayscale output
            of ``_render_signal_channel``).

    Returns:
        ``float32`` array of shape ``(3, H, W)`` holding the B, G, R planes
        (OpenCV channel order) of the decoded figure.

    Raises:
        RuntimeError: If OpenCV cannot decode the rendered PNG.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure()
    try:
        plt.imshow(image)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            png_bytes = buf.getvalue()
    finally:
        # Always free the figure, including when imshow/savefig raise.
        plt.close(fig)

    img = cv2.imdecode(np.frombuffer(png_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError("failed to decode training-style image")

    # cv2.split yields the B, G, R planes; stacking gives (channels, H, W).
    return np.asarray(cv2.split(img), dtype=np.float32)
|
|
|
|
|
|
def _prune_old_inference_images(src, model_type, model_id, keep_last=200):
|
|
try:
|
|
keep_last = int(os.getenv("INFERENCE_IMAGE_KEEP_LAST", str(keep_last)))
|
|
except ValueError:
|
|
keep_last = keep_last
|
|
|
|
if keep_last <= 0 or not src or not os.path.isdir(src):
|
|
return
|
|
|
|
pattern = re.compile(
|
|
r"_inference_(\d+)_.*_"
|
|
+ re.escape(str(model_id))
|
|
+ "_"
|
|
+ re.escape(str(model_type))
|
|
+ r"\.png$"
|
|
)
|
|
grouped = {}
|
|
for name in os.listdir(src):
|
|
match = pattern.match(name)
|
|
if match is None:
|
|
continue
|
|
grouped.setdefault(int(match.group(1)), []).append(name)
|
|
|
|
if len(grouped) <= keep_last:
|
|
return
|
|
|
|
for old_result_id in sorted(grouped)[: len(grouped) - keep_last]:
|
|
for name in grouped[old_result_id]:
|
|
try:
|
|
os.remove(os.path.join(src, name))
|
|
except FileNotFoundError:
|
|
pass
|
|
except OSError as exc:
|
|
print(f"failed to remove old inference image {name}: {exc}")
|
|
|
|
|
|
def pre_func_ensemble(data=None, src="", ind_inference=0):
    """Turn a complex signal into the two image inputs of the ensemble.

    The real part and the magnitude of ``data[0] + 1j * data[1]`` are each
    drawn as a line plot (``_render_signal_channel``) and re-rendered into a
    3-channel stack (``_render_training_png``).

    Args:
        data: Pair-like container; ``data[0]`` is the real part and
            ``data[1]`` the imaginary part of the signal.
        src: Unused here; kept for interface compatibility.
        ind_inference: Unused here; kept for interface compatibility.

    Returns:
        ``[img_real, img_mag]`` (``float32`` arrays of shape ``(3, H, W)``),
        or ``None`` if any step fails; the error message is printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()

        real = np.asarray(data[0], dtype=np.float32)
        imag = np.asarray(data[1], dtype=np.float32)
        signal = real + 1j * imag

        img_real = _render_training_png(_render_signal_channel(signal.real))
        img_mag = _render_training_png(_render_signal_channel(np.abs(signal)))

        # No OpenCV windows are ever opened here, so cv2.destroyAllWindows()
        # is not called: on headless OpenCV builds it raises and would turn a
        # successful preparation into a None result.
        gc.collect()

        print("Подготовка данных завершена")
        print()
        return [img_real, img_mag]

    except Exception as exc:
        print(str(exc))
        return None
|
|
|
|
|
|
def build_func_ensemble(file_model="", file_config="", num_classes=None):
    """Build the ResNet-18 + ResNet-50 ensemble and load its weights.

    Both backbones are created without pretrained weights and get a new
    ``num_classes``-way ``fc`` head. Their logits are concatenated and mixed
    by a final ``Linear(2 * num_classes, num_classes)`` layer.

    Args:
        file_model: Path to a ``state_dict`` checkpoint for the whole ensemble.
        file_config: Unused; kept for interface compatibility.
        num_classes: Number of output classes. ``None`` (the default) means
            2, the value that used to be hard-coded. Explicit values were
            previously ignored.

    Returns:
        The model in eval mode (moved to CUDA when available), or ``None`` if
        any step fails; the error message is printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()
        torch.cuda.empty_cache()

        if num_classes is None:
            num_classes = 2
        num_classes = int(num_classes)

        model1 = models.resnet18(pretrained=False)
        model2 = models.resnet50(pretrained=False)
        model1.fc = nn.Linear(model1.fc.in_features, num_classes)
        model2.fc = nn.Linear(model2.fc.in_features, num_classes)

        class Ensemble(nn.Module):
            """Concatenate two backbones' logits and mix them with a linear layer."""

            # Attribute names (model1, model2, fc) define the state_dict keys
            # and must match the checkpoint.
            def __init__(self, model1, model2, num_classes):
                super().__init__()
                self.model1 = model1
                self.model2 = model2
                self.fc = nn.Linear(2 * num_classes, num_classes)

            def forward(self, x):
                # A list/tuple feeds x[0] to model1 and x[1] (or x[0] again if
                # only one item) to model2; a single tensor feeds both.
                if isinstance(x, (list, tuple)):
                    x1 = x[0]
                    x2 = x[1] if len(x) > 1 else x[0]
                else:
                    x1 = x2 = x
                y = torch.cat((self.model1(x1), self.model2(x2)), dim=1)
                return self.fc(y)

        model = Ensemble(model1, model2, num_classes)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()

        # cv2.destroyAllWindows() is not called: no windows exist here and
        # headless OpenCV builds raise on it.
        gc.collect()

        print("Инициализация модели завершена")
        print()
        return model

    except Exception as exc:
        print(str(exc))
        return None
|
|
|
|
|
|
def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""):
    """Run the ensemble on one sample and report the top class and its confidence.

    Args:
        data: Either a pair ``[input_for_model1, input_for_model2]`` (the list
            returned by ``pre_func_ensemble``) or a single array-like that is
            fed to both backbones. A batch dimension of size 1 is prepended
            and inputs are cast to ``float32`` on CUDA when available,
            otherwise CPU.
        model: Callable taking a list of two tensors and returning logits;
            only the first row of the output is used.
        mapping: Indexable from class index (int) to label, e.g. list or dict.
        shablon: Text inserted into the printed messages.

    Returns:
        ``[prediction, probability]``, where ``probability`` is the softmax
        score of the predicted class rounded to 2 decimals. Returns ``None``
        if any step fails; the error message is printed.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        def to_batch(array):
            # torch.tensor copies, so the caller's array is never shared.
            return torch.tensor(array).unsqueeze(0).to(device).float()

        if isinstance(data, (list, tuple)) and len(data) >= 2:
            inputs = [to_batch(data[0]), to_batch(data[1])]
        else:
            tensor = to_batch(data)
            inputs = [tensor, tensor]

        with torch.no_grad():
            output = model(inputs)

        scores = output.detach().cpu().numpy()[0]

        # One argmax drives both the label and its confidence, so they can
        # never disagree (e.g. on ties).
        label = int(np.argmax(scores))
        prediction = mapping[label]
        print("PREDICTION" + shablon + ": " + str(prediction))

        # Numerically stable softmax.
        expon = np.exp(scores - np.max(scores))
        probability = round((expon / expon.sum())[label], 2)

        # cv2.destroyAllWindows() is not called: no windows exist here and
        # headless OpenCV builds raise on it.
        gc.collect()

        print("Уверенность" + shablon + " в предсказании: " + str(probability))
        print("Инференс завершен")
        print()
        return [prediction, probability]

    except Exception as exc:
        print(str(exc))
        return None
|
|
|
|
|
|
def _channels_to_display_image(channels):
    """Convert a ``(C, H, W)`` channel stack into an ``imshow``-ready ``(H, W, C)`` image.

    Float data whose maximum exceeds 1.0 is treated as 0-255 and converted to
    ``uint8``. This matches ``np.asarray(cv2.split(img), dtype=np.float32)``
    from ``_render_training_png``. Without the conversion Matplotlib clips
    float RGB data to [0, 1] and the saved picture is almost blank.
    Three-channel stacks are assumed to be in OpenCV BGR order (as produced by
    ``cv2.split`` of a decoded image) and are reordered to RGB.
    """
    img = np.moveaxis(np.asarray(channels), 0, -1)
    if np.issubdtype(img.dtype, np.floating) and img.size and img.max() > 1.0:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 3 and img.shape[-1] == 3:
        img = img[..., ::-1]
    return img


def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None):
    """Save the two ensemble inputs as PNG files and prune older ones.

    When ``data`` has at least two entries, ``data[0]`` is written to
    ``src + "_inference_<ind_inference>_<prediction>_real_<model_id>_<model_type>.png"``
    and ``data[1]`` to the same name with ``_mod_`` instead of ``_real_``.
    Older inference images for this model are then pruned with
    ``_prune_old_inference_images``.

    Args:
        src: Path prefix for the saved files (concatenated, not joined).
        model_type: Model type suffix used in the file names.
        prediction: Predicted label embedded in the file names.
        model_id: Model id used in the file names.
        ind_inference: Inference index used in the file names.
        data: Pair-like ``[real_stack, mag_stack]`` of ``(C, H, W)`` arrays;
            anything else is skipped.

    Returns:
        ``None``; errors are printed rather than raised.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()

        if isinstance(data, (list, tuple)) and len(data) >= 2:
            stem = src + "_inference_" + str(ind_inference) + "_" + str(prediction)
            tail = "_" + str(model_id) + "_" + str(model_type) + ".png"
            for channels, kind in ((data[0], "real"), (data[1], "mod")):
                fig, ax = plt.subplots()
                try:
                    ax.imshow(_channels_to_display_image(channels))
                    fig.savefig(stem + "_" + kind + tail)
                finally:
                    # Free the figure even if imshow/savefig fails.
                    plt.close(fig)
                gc.collect()

            _prune_old_inference_images(src, model_type, model_id)

        # cv2.destroyAllWindows() is not called: no windows exist here and
        # headless OpenCV builds raise on it.
        gc.collect()

        print("Постобработка завершена")
        print()

    except Exception as exc:
        print(str(exc))
        return None
|