You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
DroneDetector/NN_server/Models/ensemble_1200.py

253 lines
7.3 KiB
Python

import torchsig.transforms.dataset_transforms as transform
import torchsig.transforms.functional as F
from importlib import import_module
from torchvision import models
import torch.nn as nn
import matplotlib
import numpy as np
import torch
import cv2
import gc
import io
def _render_trace_gray(trace, figsize, dpi):
    """Draw ``trace`` as a borderless black line plot and return it as a grayscale image.

    The y-axis is fixed to [-1, 1], the axes are hidden and all subplot
    margins are removed, so the PNG covers the whole ``figsize * dpi`` canvas
    (with default savefig settings). The figure is always closed, even when
    rendering fails, so repeated calls do not accumulate open pyplot figures.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        ax = plt.axes(ylim=(-1, 1))
        ax.plot(trace, color='black')
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        # Round-trip through an in-memory PNG to obtain the rasterised pixels.
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            png_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        bgr = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    finally:
        plt.close(fig)


def pre_func_ensemble(data=None, src='', ind_inference=0):
    """Turn one I/Q capture into the 3-channel image fed to the ensemble.

    Channels of the returned array:
        0 -- the real part, drawn as a black line (y-limits [-1, 1]) on a
             16 x 16 inch figure at 16 dpi and converted to grayscale;
        1 -- the imaginary part, rendered the same way;
        2 -- the torchsig spectrogram (fft_size=256, fft_stride=256) as
             float32, truncated to its first ``figsize[0] * dpi`` (= 256)
             columns.

    Args:
        data: pair-like object; ``data[0]`` holds the real (I) samples and
            ``data[1]`` the imaginary (Q) samples.
        src, ind_inference: unused; kept for interface compatibility.

    Returns:
        ``np.ndarray`` stacking the three channels, or ``None`` if any step
        fails (the error message is printed instead of raised).
    """
    try:
        # Select the non-interactive backend before pyplot is first used.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        plt.ioff()
        figsize = (16, 16)
        dpi = 16

        # Vectorised equivalent of np.vectorize(complex)(I, Q): complex128 I + jQ
        # without a Python-level call per sample.
        signal = (np.asarray(data[0], dtype=np.float64)
                  + 1j * np.asarray(data[1], dtype=np.float64))

        spectr = np.asarray(F.spectrogram(signal, fft_size=256, fft_stride=256),
                            dtype=np.float32)
        img_real = _render_trace_gray(signal.real, figsize, dpi)
        img_imag = _render_trace_gray(signal.imag, figsize, dpi)

        # NOTE(review): assumes the spectrogram has as many rows as the
        # renders and at least figsize[0] * dpi columns; otherwise the three
        # channels cannot be stacked into one regular array -- TODO confirm
        # for short captures.
        img = np.array([img_real, img_imag, spectr[:, :figsize[0] * dpi]])

        gc.collect()
        print('Подготовка данных завершена')
        print()
        return img
    except Exception as e:
        print(str(e))
        return None
def build_func_ensemble(file_model='', file_config='', num_classes=None):
    """Construct the ResNet-18/50/101 ensemble and load trained weights into it.

    Args:
        file_model: path handed to ``torch.load``; it must hold a state dict
            whose keys match this architecture (``model1.*``, ``model2.*``,
            ``model3.*``, ``fc.*``).
        file_config: unused; kept for interface compatibility.
        num_classes: ignored -- every head is built with 2 outputs (see the
            NOTE in the body).

    Returns:
        The ensemble in eval mode, on CUDA when available and on CPU
        otherwise, or ``None`` if any step fails (the error is printed).
    """
    try:
        # Nothing below plots; the Agg / interactive-off setup is kept so the
        # process-wide pyplot state is the same as before this refactor.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        plt.ioff()
        torch.cuda.empty_cache()

        # NOTE(review): the num_classes argument is overridden, so callers
        # always get a 2-class model. The checkpoint is presumably 2-class;
        # confirm before honouring the argument (a mismatch would make
        # load_state_dict fail).
        num_classes = 2
        backbones = [
            models.resnet18(pretrained=False),
            models.resnet50(pretrained=False),
            models.resnet101(pretrained=False),
        ]
        for backbone in backbones:
            # Swap each backbone's final classifier for a num_classes-wide one.
            backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        class Ensemble(nn.Module):
            """Concatenate the three backbones' logits and mix them with one linear layer."""

            # The attribute names model1/model2/model3/fc become the state-dict
            # keys, so they must stay exactly as the checkpoint expects.
            def __init__(self, model1, model2, model3):
                super().__init__()
                self.model1 = model1
                self.model2 = model2
                self.model3 = model3
                self.fc = nn.Linear(3 * num_classes, num_classes)

            def forward(self, x):
                fused = torch.cat((self.model1(x), self.model2(x), self.model3(x)), dim=1)
                return self.fc(fused)

        model = Ensemble(*backbones)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)  # no-op when device is 'cpu'
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()
        gc.collect()
        print('Инициализация модели завершена')
        print()
        return model
    except Exception as exc:
        print(str(exc))
        return None
def inference_func_ensemble(data=None, model=None, mapping=None, shablon=''):
    """Run the model on one sample and return its predicted label and confidence.

    Args:
        data: one sample without a batch axis (anything ``torch.tensor``
            accepts); a leading batch axis of size 1 is added here. Presumably
            the array produced by the matching pre-processing step -- confirm.
        model: callable torch module; its output is treated as logits shaped
            ``(1, num_classes)``.
        mapping: list or int-keyed dict from class index to a label ``str``
            (the label is concatenated into a log line).
        shablon: text inserted into the log lines to identify the caller.

    Returns:
        ``[prediction, probability]`` where ``probability`` is the softmax
        score of the predicted class rounded to 2 decimals, or ``None`` if any
        step fails (the error message is printed instead of raised).
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch = torch.unsqueeze(torch.tensor(data), 0).to(device).float()
        with torch.no_grad():
            logits = model(batch).cpu()

        # Index of the highest logit in the single batch row (first one on ties).
        label = int(torch.argmax(logits, dim=1)[0])
        prediction = mapping[label]
        print('PREDICTION' + shablon + ': ' + prediction)

        scores = np.asarray(logits[0])
        # Softmax shifted by the max logit for numerical stability.
        expon = np.exp(scores - np.max(scores))
        probability = round((expon / expon.sum())[label], 2)

        print('Уверенность' + shablon + ' в предсказании: ' + str(probability))
        print('Инференс завершен')
        print()
        return [prediction, probability]
    except Exception as exc:
        print(str(exc))
        return None
def post_func_ensemble(src='', model_type='', prediction='', model_id=0, ind_inference=0, data=None):
    """Save the first three channels of an inference sample as grayscale PNGs.

    Images are only written when ``int(ind_inference) <= 100``; otherwise the
    function just logs completion. Channel ``k`` of ``data`` is written to
    ``<src>_inference_<ind_inference>_<prediction>_<tag>_<model_id>_<model_type>.png``
    with tag ``real``, ``mod`` and ``spec`` for ``k`` = 0, 1, 2.

    Args:
        src: path prefix for the output files.
        model_type: text appended to the file name.
        prediction: predicted label; must be a ``str`` (it is concatenated).
        model_id: identifier appended to the file name via ``str()``.
        ind_inference: running inference index; anything ``int()`` accepts.
        data: indexable holding at least three images ``imshow`` can draw;
            presumably the array from the matching pre-processing step -- confirm.

    Returns:
        ``None`` in all cases; errors are printed instead of raised.
    """
    try:
        # Select the non-interactive backend before pyplot is first used.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        plt.ioff()
        if int(ind_inference) <= 100:
            # The tags are part of the on-disk naming scheme.
            for channel, tag in enumerate(('_real_', '_mod_', '_spec_')):
                fig, ax = plt.subplots()
                try:
                    ax.imshow(data[channel], cmap='gray')
                    fig.savefig(src + '_inference_' + str(ind_inference) + '_' + prediction
                                + tag + str(model_id) + '_' + model_type + '.png')
                finally:
                    # Close even on failure so repeated calls do not leak figures.
                    plt.close(fig)
        gc.collect()
        print('Постобработка завершена')
        print()
    except Exception as exc:
        print(str(exc))
    return None