You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DroneDetector/NN_server/Models/detection_250_2.py

155 lines
5.3 KiB
Python

import torchsig.torchsig.transforms as transform
from importlib import import_module
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import mlconfig
import torch
import cv2
import io
def _iq_to_complex(real_part, imag_part):
    """Combine separate real/imaginary sample arrays into one complex128 array.

    Equivalent to ``np.vectorize(complex)(real_part, imag_part)`` for real
    inputs, but done in native code instead of a per-element Python loop.
    """
    real_part = np.asarray(real_part, dtype=np.float64)
    imag_part = np.asarray(imag_part, dtype=np.float64)
    combined = np.empty(np.broadcast(real_part, imag_part).shape, dtype=np.complex128)
    # Assign components directly (rather than re + 1j * im) so that inf/nan
    # values do not leak into the other component through complex multiplication.
    combined.real = real_part
    combined.imag = imag_part
    return combined


def _render_trace_gray(values, figsize, dpi):
    """Plot a 1-D trace as a borderless black line and return it as a grayscale image.

    The figure is always closed, even if rendering or encoding fails, so a
    long-running process does not accumulate open pyplot figures.
    """
    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(1, 1, 1, ylim=(-1, 1))
        ax.plot(values, color='black')
        ax.set_axis_off()
        # Make the axes fill the whole canvas: no borders, no margins.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            # getvalue() returns an independent bytes copy, so it outlives buf.
            encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    finally:
        plt.close(fig)
    img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('failed to decode rendered PNG image')
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)


def pre_func_resnet18(data=None, src='', ind_inference=0):
    """Turn a pair of I/Q sample arrays into a 3-channel 256x256 image stack.

    Channels, in order:
      0. grayscale line plot of the real part of the signal,
      1. grayscale line plot of the magnitude ``abs(signal)``,
      2. the signal's spectrogram (first ``figsize[0] * dpi`` columns).

    Args:
        data: indexable whose items 0 and 1 are the real and imaginary sample
            arrays (assumed equal-length 1-D sequences — confirm with caller).
        src: unused; kept for interface compatibility.
        ind_inference: unused; kept for interface compatibility.

    Returns:
        ``np.ndarray`` of shape (3, 256, 256), or ``None`` if any step fails
        (the error message is printed).
    """
    figsize = (16, 16)
    dpi = 64
    resize = (256, 256)
    try:
        if data is None:
            raise ValueError('pre_func_resnet18: no input data')
        signal = _iq_to_complex(data[0], data[1])
        spec = transform.Spectrogram(nperseg=1024)
        # figsize[0] * dpi == 1024: presumably chosen so the spectrogram width
        # matches the pixel width of the rendered trace plots.
        spectr = np.array(spec(signal)[:, :figsize[0] * dpi])
        img_real = _render_trace_gray(signal.real, figsize, dpi)
        img_mag = _render_trace_gray(np.abs(signal), figsize, dpi)
        img = np.array([
            cv2.resize(img_real, resize),
            cv2.resize(img_mag, resize),
            cv2.resize(spectr, resize),
        ])
        print('Подготовка данных завершена')
        print()
        return img
    except Exception as exc:
        # Best-effort contract: callers check for None instead of catching.
        print(str(exc))
        return None
def build_func_resnet18(file_model='', file_config='', num_classes=None):
    """Build the network named in ``file_config``, adapt it, and load weights.

    The architecture is given in the config as a dotted path
    (``config.model.architecture``, e.g. ``"package.module.ClassName"``) and is
    instantiated with no arguments. Two changes are then made:
      * a 2->3 channel 7x7 convolution is prepended to ``conv1``, so the network
        accepts 2-channel input (presumably the original ``conv1`` expects 3
        channels, as in torchvision ResNets — confirm for other architectures);
      * ``fc`` is replaced by a fresh ``Linear`` head with ``num_classes`` outputs.

    Args:
        file_model: path of a state-dict checkpoint readable by ``torch.load``.
        file_config: path of an mlconfig file.
        num_classes: number of output classes of the new classification head.

    Returns:
        The model in eval mode (on CUDA when available), or ``None`` if anything
        fails (the error message is printed).
    """
    try:
        config = mlconfig.load(file_config)
        module_name, _, class_name = config.model.architecture.rpartition('.')
        model = getattr(import_module(module_name), class_name)()
        model.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(2, 3, kernel_size=(7, 7), stride=(2, 2),
                            padding=(3, 3), bias=False),
            model.conv1,
        )
        # Take the head width from the backbone instead of hard-coding 512, so
        # wider ResNets (e.g. resnet50, 2048) also work; resnet18 still gets 512.
        in_features = getattr(getattr(model, 'fc', None), 'in_features', 512)
        model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device != 'cpu':
            model = model.to(device)
        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources. On torch
        # versions that support it, consider weights_only=True.
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()
        print('Инициализация модели завершена')
        print()
        return model
    except Exception as exc:
        print(str(exc))
        return None
def _model_device_dtype(model):
    """Return (device, dtype) of the model's first parameter.

    Falls back to CUDA-if-available with no dtype (``None``) when the model
    exposes no parameters (e.g. a plain callable).
    """
    params = getattr(model, 'parameters', None)
    if callable(params):
        first = next(iter(params()), None)
        if first is not None:
            return first.device, first.dtype
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu'), None


def inference_func_resnet18(data=None, model=None, mapping=None, shablon=''):
    """Classify one sample and report the top class with its softmax confidence.

    Args:
        data: one un-batched sample; a batch axis of size 1 is added before the
            forward pass (presumably a (C, H, W) array — confirm with caller).
        model: callable returning logits of shape (1, num_classes).
        mapping: indexable from class index to class label.
        shablon: text inserted into the printed messages.

    Returns:
        ``[prediction, probability]``, where ``prediction`` is
        ``mapping[argmax]`` and ``probability`` is the softmax probability of
        that class rounded to 2 decimals (a Python ``float``). Returns ``None``
        on any failure; the error message is printed.
    """
    try:
        device, dtype = _model_device_dtype(model)
        img = torch.unsqueeze(torch.tensor(data), 0)
        # Match the model's device and dtype: e.g. float64 numpy input would
        # otherwise hit "expected Float but found Double" in float32 layers.
        img = img.to(device=device, dtype=dtype) if dtype is not None else img.to(device)
        with torch.no_grad():
            output = model(img)
        # Batch of one: take row 0 and move it to the CPU so the Python/NumPy
        # conversions below also work when the model runs on CUDA.
        logits = output.detach().cpu()[0]
        label = int(torch.argmax(logits))
        prediction = mapping[label]
        print('PREDICTION' + shablon + ': ' + str(prediction))
        probability = round(float(torch.softmax(logits.double(), dim=0)[label]), 2)
        print('Уверенность' + shablon + ' в предсказании: ' + str(probability))
        print('Инференс завершен')
        print()
        return [prediction, probability]
    except Exception as exc:
        print(str(exc))
        return None
def post_func_resnet18(src='', model_type='', prediction='', model_id=0, ind_inference=0, data=None):
    """Save channels 0, 1 and 2 of ``data`` as grayscale PNG files.

    Each file is named
    ``<src>_inference_<ind_inference>_<prediction><tag><model_id>_<model_type>.png``
    with ``tag`` being ``_real_``, ``_imag_`` and ``_spec_`` for channels 0, 1
    and 2 respectively.

    Always returns ``None``; failures are printed, not raised.
    """
    # NOTE(review): channel 1 is tagged '_imag_', but pre_func_resnet18 in this
    # file places np.abs(signal) (the magnitude) in that slot. The tag is kept
    # unchanged for compatibility with existing output consumers — confirm.
    channel_tags = ((0, '_real_'), (1, '_imag_'), (2, '_spec_'))
    try:
        for channel, tag in channel_tags:
            path = (src + '_inference_' + str(ind_inference) + '_' + prediction
                    + tag + str(model_id) + '_' + model_type + '.png')
            fig, ax = plt.subplots()
            try:
                ax.imshow(data[channel], cmap='gray')
                fig.savefig(path)
            finally:
                # Close even when imshow/savefig fails, so figures do not pile up.
                plt.close(fig)
        print('Постобработка завершена')
        print()
    except Exception as exc:
        print(str(exc))
    return None