"""Pre-processing, model construction, inference and post-processing helpers
for a ResNet-18-based classifier of two-component (real/imaginary) signals.

Every public ``*_func_resnet18`` function catches any exception, prints its
message and returns ``None`` instead of raising, so failures are signalled by
a ``None`` result (``post_func_resnet18`` returns ``None`` in every case).
"""
import io
from importlib import import_module

# Third-party imports are kept in their original relative order so the
# native-library load order does not change.
import torchsig.transforms.transforms as transform
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import mlconfig
import torch
import cv2


def _render_trace_gray(trace, figsize, dpi):
    """Plot ``trace`` as a black line on an axis-free canvas and return the
    rendered PNG decoded as a single-channel (grayscale) image.

    The y-range is fixed to ``(-1, 1)``, so values outside it are not visible.
    The figure is always closed, even when rendering fails, so repeated calls
    do not accumulate open pyplot figures.
    """
    fig = plt.figure(figsize=figsize)
    try:
        plt.axes(ylim=(-1, 1))
        plt.plot(trace, color='black')
        plt.gca().set_axis_off()
        # Remove all padding so the trace fills the whole canvas.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            # getvalue() returns a copy, so the array outlives the buffer.
            png_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        bgr = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    finally:
        plt.close(fig)


def pre_func_resnet18(data=None, src='', ind_inference=0):
    """Turn a two-component signal into a 3-channel image stack.

    Args:
        data: Indexable pair ``(real_part, imag_part)``; the two components are
            combined element-wise into a complex signal.
        src: Unused; kept for interface compatibility with the other hooks.
        ind_inference: Unused; kept for interface compatibility.

    Returns:
        ``np.ndarray`` stacking, in order, the rendered real-part trace, the
        rendered magnitude trace and the spectrogram, each resized to
        256 x 256 (shape ``(3, 256, 256)`` assuming the spectrogram is 2-D);
        or ``None`` if any step fails (the error is printed).
    """
    try:
        figsize = (16, 16)
        dpi = 64  # 16 in * 64 dpi -> 1024 x 1024 px canvas

        # Vectorised, value-identical replacement for
        # np.vectorize(complex)(data[0], data[1]), which loops in Python.
        real_part, imag_part = np.broadcast_arrays(
            np.asarray(data[0], dtype=np.float64),
            np.asarray(data[1], dtype=np.float64),
        )
        signal = np.empty(real_part.shape, dtype=np.complex128)
        signal.real = real_part
        signal.imag = imag_part

        spectrogram = transform.Spectrogram(nperseg=1024)
        # Keep at most the first figsize[0] * dpi (= 1024) columns.
        spectr = np.array(spectrogram(signal)[:, :figsize[0] * dpi])

        img_real = _render_trace_gray(signal.real, figsize, dpi)
        img_mag = _render_trace_gray(np.abs(signal), figsize, dpi)

        resize = (256, 256)  # cv2 dsize is (width, height)
        img = np.array([
            cv2.resize(img_real, resize),
            cv2.resize(img_mag, resize),
            cv2.resize(spectr, resize),
        ])
        print('Подготовка данных завершена')
        print()
        return img
    except Exception as exc:
        # Best-effort boundary: report the error and signal failure with None.
        print(str(exc))
        return None


def build_func_resnet18(file_model='', file_config='', num_classes=None):
    """Instantiate the configured architecture, adapt it and load its weights.

    ``config.model.architecture`` in the mlconfig file is a dotted path
    ``"package.module.ClassName"``; that class is constructed with no
    arguments. Its ``conv1`` is wrapped so a 2 -> 3 channel convolution runs
    first, and ``fc`` is replaced by a 512 -> ``num_classes`` linear head.
    These changes must match the layout of the saved state dict, because
    ``load_state_dict`` is strict by default.

    Args:
        file_model: Path to a state-dict file readable by ``torch.load``.
        file_config: Path to the mlconfig configuration file.
        num_classes: Number of outputs of the new ``fc`` layer.

    Returns:
        The model in eval mode, on CUDA when available (otherwise CPU), or
        ``None`` if any step fails (the error is printed).
    """
    try:
        config = mlconfig.load(file_config)
        module_name, class_name = config.model.architecture.rsplit('.', maxsplit=1)
        model = getattr(import_module(module_name), class_name)()
        # NOTE(review): this stem expects 2-channel input, whereas
        # pre_func_resnet18 returns 3 channels - confirm which pre-processing
        # feeds this model.
        model.conv1 = nn.Sequential(
            nn.Conv2d(2, 3, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            model.conv1,
        )
        # 512 presumably matches ResNet-18's final feature width - confirm for
        # other architectures.
        model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device != 'cpu':
            model = model.to(device)
        # NOTE(security): torch.load unpickles the file; load trusted
        # checkpoints only.
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()
        print('Инициализация модели завершена')
        print()
        return model
    except Exception as exc:
        print(str(exc))
        return None


def inference_func_resnet18(data=None, model=None, mapping=None, shablon=''):
    """Classify one sample and report the top class and its confidence.

    Args:
        data: A single sample without a batch dimension; a batch dimension of
            size 1 is added here.
        model: A ``torch.nn.Module``; its output is treated as raw logits of
            shape ``(1, num_classes)``.
        mapping: Indexable mapping a class index to its label string.
        shablon: Text inserted into the printed messages.

    Returns:
        ``[prediction, probability]``, where ``prediction`` is
        ``mapping[class_index]`` and ``probability`` is the softmax
        probability of that class rounded to 2 decimals; or ``None`` on
        failure (the error is printed).
    """
    try:
        # Match the model's device and dtype: float32 weights reject
        # uint8/float64 input, and a CPU tensor cannot feed a CUDA model.
        weights = next(model.parameters())
        batch = torch.as_tensor(data).to(device=weights.device, dtype=weights.dtype).unsqueeze(0)
        with torch.no_grad():
            output = model(batch)
        _, predict = torch.max(output, 1)
        class_index = int(predict[0])
        prediction = mapping[class_index]
        print('PREDICTION' + shablon + ': ' + prediction)
        # NumPy cannot read CUDA tensors, so copy the logits to the CPU first.
        logits = output[0].cpu().numpy()
        # Numerically stable softmax: shift by the max before exponentiating.
        expon = np.exp(logits - np.max(logits))
        probability = round((expon / expon.sum())[class_index], 2)
        print('Уверенность' + shablon + ' в предсказании: ' + str(probability))
        print('Инференс завершен')
        print()
        return [prediction, probability]
    except Exception as exc:
        print(str(exc))
        return None


def post_func_resnet18(src='', model_type='', prediction='', model_id=0, ind_inference=0, data=None):
    """Save the first three channels of ``data`` as grayscale PNG files.

    Each file is named
    ``{src}_inference_{ind_inference}_{prediction}_{tag}_{model_id}_{model_type}.png``,
    where ``tag`` is ``real``, ``imag`` or ``spec`` for ``data[0]``,
    ``data[1]`` and ``data[2]`` respectively.

    Returns:
        Always ``None``; errors are printed rather than raised.
    """
    # NOTE(review): pre_func_resnet18 stores the magnitude trace in data[1],
    # yet it is saved under the "imag" tag - confirm the intended label.
    try:
        for channel, tag in enumerate(('real', 'imag', 'spec')):
            fig, ax = plt.subplots()
            try:
                ax.imshow(data[channel], cmap='gray')
                fig.savefig(
                    f'{src}_inference_{ind_inference}_{prediction}_{tag}_{model_id}_{model_type}.png'
                )
            finally:
                # Always release the figure, even if imshow/savefig fails.
                plt.close(fig)
        print('Постобработка завершена')
        print()
    except Exception as exc:
        print(str(exc))
        return None