You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
6.6 KiB
Python
217 lines
6.6 KiB
Python
import torchsig.transforms.transforms as transform
|
|
import torchsig.transforms.functional as F
|
|
from importlib import import_module
|
|
import matplotlib.pyplot as plt
|
|
import torch.nn as nn
|
|
import matplotlib
|
|
import numpy as np
|
|
import mlconfig
|
|
import torch
|
|
import cv2
|
|
import gc
|
|
import io
|
|
|
|
|
|
def _render_trace_gray(samples, figsize, dpi):
    """Plot a 1-D trace without axes and return it as a grayscale image array.

    The figure is saved as PNG into an in-memory buffer and decoded with
    OpenCV, so the result has ``figsize * dpi`` pixels per side.
    The figure is always closed, even when plotting or decoding fails.
    """
    fig = plt.figure(figsize=figsize)
    try:
        # Fixing ylim before plotting keeps the amplitude scale at [-1, 1].
        ax = plt.axes(ylim=(-1, 1))
        ax.plot(samples, color='black')
        ax.set_axis_off()
        # Let the axes fill the whole canvas: no padding, no margins.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            # getvalue() returns a copy, so the array outlives the buffer.
            encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)

        bgr = cv2.imdecode(encoded, 1)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    finally:
        plt.close(fig)


def pre_func_resnet18(data=None, src='', ind_inference=0):
    """Turn an I/Q recording into the 3-channel array fed to the network.

    Channel 0 is a rendered plot of the real part, channel 1 a rendered plot
    of the imaginary part, channel 2 the first ``figsize[0] * dpi`` columns of
    the torchsig spectrogram (fft_size=256, fft_stride=256).

    Args:
        data: indexable with ``data[0]`` (real samples) and ``data[1]``
            (imaginary samples).
        src: currently unused; kept for interface compatibility.
        ind_inference: currently unused; kept for interface compatibility.

    Returns:
        ``np.array([real_img, imag_img, spectrogram_slice])`` on success,
        ``None`` if any step raises (the error message is printed).
    """
    try:
        matplotlib.use('Agg')
        plt.ioff()

        figsize = (16, 16)
        dpi = 16  # 16 in * 16 dpi -> 256 x 256 px renders

        signal = np.vectorize(complex)(data[0], data[1])
        spectr = np.asarray(
            F.spectrogram(signal, fft_size=256, fft_stride=256),
            dtype=np.float32,
        )

        img_real = _render_trace_gray(signal.real, figsize, dpi)
        img_imag = _render_trace_gray(signal.imag, figsize, dpi)

        # NOTE(review): stacking assumes the spectrogram has as many rows as
        # the rendered images (256) and at least 256 columns - confirm.
        img = np.array([img_real, img_imag, spectr[:, :figsize[0] * dpi]])

        gc.collect()

        print('Подготовка данных завершена')
        print()
        return img

    except Exception as e:
        print(str(e))
        return None
|
|
|
|
|
|
def build_func_resnet18(file_model='', file_config='', num_classes=None):
    """Build the modified network, load its weights and switch it to eval mode.

    The architecture class is named by ``config.model.architecture`` (a dotted
    ``package.module.ClassName`` path) in the mlconfig file. Its ``conv1`` is
    preceded by an extra 3->3 7x7 stride-2 convolution, and ``fc`` is replaced
    by a 512->128->32->16->3 MLP head with ReLU and dropout.

    Args:
        file_model: path to the state-dict checkpoint passed to ``torch.load``.
        file_config: path to the mlconfig file.
        num_classes: accepted but not used; the head always has 3 outputs.
            NOTE(review): confirm whether this should size the last layer.

    Returns:
        The model in eval mode (on CUDA when available), or ``None`` if any
        step raises (the error message is printed).
    """
    try:
        matplotlib.use('Agg')
        plt.ioff()
        torch.cuda.empty_cache()

        config = mlconfig.load(file_config)
        module_path, class_name = config.model.architecture.rsplit('.', maxsplit=1)
        model = getattr(import_module(module_path), class_name)()

        # Extra stem convolution in front of the network's own conv1.
        model.conv1 = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=(7, 7), stride=(2, 2),
                      padding=(3, 3), bias=False),
            model.conv1,
        )
        model.fc = nn.Sequential(
            nn.Linear(in_features=512, out_features=128, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=128, out_features=32, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=32, out_features=16, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=16, out_features=3, bias=True),
        )

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device != 'cpu':
            model = model.to(device)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()

        gc.collect()

        print('Инициализация модели завершена')
        print()
        return model

    except Exception as exc:
        print(str(exc))
        return None
|
|
|
|
|
|
def inference_func_resnet18(data=None, model=None, mapping=None, shablon=''):
    """Classify one preprocessed sample and report the softmax confidence.

    Args:
        data: array-like sample; a leading batch axis of size 1 is added.
        model: callable taking the float batch and returning logits of shape
            ``(1, n_classes)``.
        mapping: indexable from class index to class name.
        shablon: text inserted into the printed messages.

    Returns:
        ``[prediction, probability]`` where ``probability`` is the softmax
        value of the predicted class rounded to 2 decimals, or ``None`` if any
        step raises (the error message is printed).
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch = torch.unsqueeze(torch.tensor(data), 0).to(device)

        with torch.no_grad():
            output = model(batch.float())
            _, predict = torch.max(output, 1)

        label = int(predict[0])
        prediction = mapping[label]
        print('PREDICTION' + shablon + ': ' + prediction)

        # Move logits to host memory before handing them to NumPy: converting
        # a CUDA tensor directly raises and used to abort the whole call.
        logits = torch.squeeze(output, 0).detach().cpu().numpy()
        # Subtracting the max keeps exp() from overflowing.
        expon = np.exp(logits - np.max(logits))
        probability = round((expon / expon.sum())[label], 2)

        gc.collect()
        print('Уверенность' + shablon + ' в предсказании: ' + str(probability))

        print('Инференс завершен')
        print()
        return [prediction, probability]

    except Exception as exc:
        print(str(exc))
        return None
|
|
|
|
|
|
def post_func_resnet18(src='', model_type='', prediction='', model_id=0, ind_inference=0, data=None):
    """Save the three input channels of a sample as grayscale PNG files.

    For each channel the file is named
    ``<src>_inference_<ind_inference>_<prediction>_<tag>_<model_id>_<model_type>.png``
    with ``tag`` being ``real``, ``imag`` and ``spec`` for ``data[0]``,
    ``data[1]`` and ``data[2]`` respectively.

    Args:
        src: path prefix of the output files.
        model_type: text appended to the file name.
        prediction: predicted class name embedded in the file name.
        model_id: model identifier embedded in the file name.
        ind_inference: inference index embedded in the file name.
        data: indexable holding three 2-D images.

    Returns:
        ``None`` in every case; on failure the error message is printed.
    """
    try:
        matplotlib.use('Agg')
        plt.ioff()

        for channel, tag in enumerate(('real', 'imag', 'spec')):
            fig, ax = plt.subplots()
            try:
                ax.imshow(data[channel], cmap='gray')
                fig.savefig(
                    f'{src}_inference_{ind_inference}_{prediction}'
                    f'_{tag}_{model_id}_{model_type}.png'
                )
            finally:
                # Close even when drawing or saving fails, so figures do not
                # pile up across calls.
                plt.close(fig)

        gc.collect()

        print('Постобработка завершена')
        print()

    except Exception as exc:
        print(str(exc))
        return None
|