Compare commits
10 Commits
8801da18c4
...
3bf93aab3f
| Author | SHA1 | Date |
|---|---|---|
|
|
3bf93aab3f | 2 days ago |
|
|
7ad17bb4c4 | 2 days ago |
|
|
0b65c2980d | 2 days ago |
|
|
c70a25cb8f | 2 days ago |
|
|
94856d0fb8 | 2 days ago |
|
|
783fb40eb0 | 2 days ago |
|
|
a1c99ebf9f | 2 days ago |
|
|
0fad5d6404 | 2 days ago |
|
|
c0ccecc270 | 2 days ago |
|
|
6a492a036b | 2 days ago |
@ -0,0 +1,197 @@
|
|||||||
|
from torchvision import models
|
||||||
|
import torch.nn as nn
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import gc
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
def _render_plot(values, figsize=(16, 16), dpi=16):
    """Rasterise a 1-D series as a borderless line plot, channel-first.

    The series is drawn in black with the y-axis fixed to ``[-1, 1]``, the
    axes hidden and all padding removed. The figure is saved to an in-memory
    PNG at ``dpi`` and decoded back with OpenCV.

    Args:
        values: Data accepted by ``Axes.plot``.
        figsize: Figure size in inches, ``(width, height)``.
        dpi: Resolution passed to ``Figure.savefig``.

    Returns:
        ``np.float32`` array built from ``cv2.split`` of the decoded colour
        image, i.e. one 2-D plane per channel stacked along axis 0.

    Raises:
        RuntimeError: If OpenCV cannot decode the rendered PNG.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        ax = plt.axes(ylim=(-1, 1))
        ax.plot(values, color="black")
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            # getvalue() returns an independent bytes copy, so the array
            # stays valid after the buffer is closed.
            encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)

        img = cv2.imdecode(encoded, 1)
        if img is None:
            raise RuntimeError("failed to decode plot image")

        return np.asarray(cv2.split(img), dtype=np.float32)
    finally:
        # Always release the figure, including when saving or decoding
        # fails; otherwise repeated failures accumulate open figures.
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_func_ensemble(data=None, src="", ind_inference=0):
    """Turn a two-component sample into a pair of rendered plot images.

    ``data[0]`` and ``data[1]`` are read as float32 arrays and combined into
    a complex signal (``data[0] + 1j * data[1]``). Two plots are rendered
    with ``_render_plot``: the real part of that signal and its magnitude.

    Args:
        data: Indexable whose first two items are the real and imaginary
            components.
        src: Not used by this function.
        ind_inference: Not used by this function.

    Returns:
        ``[real_image, magnitude_image]`` on success, or ``None`` if any
        exception was raised (its message is printed).
    """
    try:
        from matplotlib import pyplot

        # Force the non-interactive backend before any figure is created.
        matplotlib.use("Agg")
        pyplot.ioff()

        re_part = np.asarray(data[0], dtype=np.float32)
        im_part = np.asarray(data[1], dtype=np.float32)
        complex_signal = re_part + 1j * im_part

        rendered_real = _render_plot(complex_signal.real)
        rendered_magnitude = _render_plot(np.abs(complex_signal))

        cv2.destroyAllWindows()
        gc.collect()

        print("Подготовка данных завершена")
        print()
    except Exception as exc:
        print(str(exc))
        return None

    return [rendered_real, rendered_magnitude]
|
||||||
|
|
||||||
|
|
||||||
|
def build_func_ensemble(file_model="", file_config="", num_classes=None):
    """Build the ResNet-18 + ResNet-50 ensemble and load its weights.

    Both backbones are created without pretrained weights and their final
    ``fc`` layers are replaced with ``num_classes``-way linear heads. The
    ensemble concatenates the two heads' outputs and maps them through one
    more linear layer. The state dict is then loaded from ``file_model`` and
    the model is switched to eval mode.

    Args:
        file_model: Path passed to ``torch.load`` for the ensemble state dict.
        file_config: Not used by this function.
        num_classes: Currently ignored; see the note in the body.

    Returns:
        The loaded ensemble module, or ``None`` if any exception was raised
        (its message is printed).
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()
        torch.cuda.empty_cache()

        # NOTE(review): the caller-supplied ``num_classes`` is overridden, so
        # the heads are always 2-way. Kept as-is because the saved state dict
        # must match these layer shapes; confirm whether the argument should
        # be honoured.
        num_classes = 2

        # ``weights=None`` is the non-deprecated spelling of
        # ``pretrained=False`` (torchvision >= 0.13).
        model1 = models.resnet18(weights=None)
        model2 = models.resnet50(weights=None)

        model1.fc = nn.Linear(model1.fc.in_features, num_classes)
        model2.fc = nn.Linear(model2.fc.in_features, num_classes)

        class Ensemble(nn.Module):
            """Two backbones whose outputs are concatenated and fused by ``fc``."""

            def __init__(self, model1, model2):
                super().__init__()
                # Attribute names are part of the state-dict keys; do not rename.
                self.model1 = model1
                self.model2 = model2
                self.fc = nn.Linear(2 * num_classes, num_classes)

            def forward(self, x):
                # A list/tuple supplies one input per backbone (a single-item
                # sequence is reused for both); a bare tensor feeds both.
                if isinstance(x, (list, tuple)):
                    x1 = x[0]
                    x2 = x[1] if len(x) > 1 else x[0]
                else:
                    x1 = x
                    x2 = x
                y1 = self.model1(x1)
                y2 = self.model2(x2)
                y = torch.cat((y1, y2), dim=1)
                return self.fc(y)

        model = Ensemble(model1, model2)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device != "cpu":
            model = model.to(device)
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()

        cv2.destroyAllWindows()
        gc.collect()

        print("Инициализация модели завершена")
        print()
        return model

    except Exception as exc:
        print(str(exc))
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""):
    """Run the ensemble on one sample and report the class and confidence.

    Args:
        data: Either a list/tuple with at least two array-likes (one per
            model input) or a single array-like that is fed to both inputs.
        model: Callable taking a two-element list of batched tensors and
            returning per-class scores.
        mapping: Indexable from class index to the label that is printed and
            returned.
        shablon: Text inserted into the printed messages.

    Returns:
        ``[prediction, probability]``, where ``probability`` is the softmax
        score of the top class rounded to two decimals, or ``None`` if any
        exception was raised (its message is printed).
    """

    def _as_batch(sample, target):
        # tensor -> CPU -> add leading batch axis -> target device -> float32
        return torch.tensor(sample).cpu().unsqueeze(0).to(target).float()

    try:
        cv2.destroyAllWindows()
        gc.collect()
        torch.cuda.empty_cache()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        has_two_inputs = isinstance(data, (list, tuple)) and len(data) >= 2
        if has_two_inputs:
            inputs = [_as_batch(data[0], device), _as_batch(data[1], device)]
        else:
            shared = _as_batch(data, device)
            inputs = [shared, shared]

        with torch.no_grad():
            output = model(inputs)
            _, predict = torch.max(output.data, 1)

        class_index = int(np.asarray(predict.cpu())[0])
        prediction = mapping[class_index]
        print("PREDICTION" + shablon + ": " + prediction)

        # Numerically stable softmax over the single sample's scores.
        output = output.cpu()
        label = np.asarray(np.argmax(output, axis=1))[0]
        scores = np.asarray(torch.squeeze(output, 0))
        shifted = np.exp(scores - np.max(scores))
        probability = round((shifted / shifted.sum())[label], 2)

        cv2.destroyAllWindows()
        gc.collect()

        print("Уверенность" + shablon + " в предсказании: " + str(probability))
        print("Инференс завершен")
        print()
    except Exception as exc:
        print(str(exc))
        return None

    return [prediction, probability]
|
||||||
|
|
||||||
|
|
||||||
|
def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None):
    """Save preview PNGs of the two input images for early inferences.

    When ``ind_inference`` is at most 100 and ``data`` is a list/tuple with at
    least two items, ``data[0]`` and ``data[1]`` are each moved from
    channel-first to channel-last layout, shown with ``imshow`` and saved as
    ``<src>_inference_<ind>_<prediction>_<real|mod>_<model_id>_<model_type>.png``.

    Args:
        src: Path prefix for the output files.
        model_type: Suffix appended to the file name.
        prediction: Predicted label embedded in the file name.
        model_id: Model identifier embedded in the file name.
        ind_inference: Inference counter; previews are written only while it
            is <= 100.
        data: Pair of channel-first image arrays.

    Returns:
        ``None`` in all cases. Exceptions are caught and their message printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()

        if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2:
            for tag, image in (("real", data[0]), ("mod", data[1])):
                fig, ax = plt.subplots()
                try:
                    # NOTE(review): imshow clips float RGB input to [0, 1]; if
                    # these arrays are float32 in the 0-255 range the preview
                    # will be saturated. Confirm whether scaling is intended.
                    ax.imshow(np.moveaxis(image, 0, -1))
                    fig.savefig(
                        f"{src}_inference_{ind_inference}_{prediction}_{tag}_{model_id}_{model_type}.png"
                    )
                finally:
                    # Release the figure even if drawing or saving fails.
                    plt.close(fig)
                cv2.destroyAllWindows()
                gc.collect()

        # Close whatever figure is still current; this is a no-op when none
        # exists, so no throwaway figure is created just to be destroyed.
        plt.close()
        cv2.destroyAllWindows()
        gc.collect()

        print("Постобработка завершена")
        print()

    except Exception as exc:
        print(str(exc))

    return None
|
||||||
@ -0,0 +1,197 @@
|
|||||||
|
from torchvision import models
|
||||||
|
import torch.nn as nn
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import gc
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
def _render_plot(values, figsize=(16, 16), dpi=16):
    """Rasterise a 1-D series as a borderless line plot, channel-first.

    The series is drawn in black with the y-axis fixed to ``[-1, 1]``, the
    axes hidden and all padding removed. The figure is saved to an in-memory
    PNG at ``dpi`` and decoded back with OpenCV.

    Args:
        values: Data accepted by ``Axes.plot``.
        figsize: Figure size in inches, ``(width, height)``.
        dpi: Resolution passed to ``Figure.savefig``.

    Returns:
        ``np.float32`` array built from ``cv2.split`` of the decoded colour
        image, i.e. one 2-D plane per channel stacked along axis 0.

    Raises:
        RuntimeError: If OpenCV cannot decode the rendered PNG.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    try:
        ax = plt.axes(ylim=(-1, 1))
        ax.plot(values, color="black")
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            # getvalue() returns an independent bytes copy, so the array
            # stays valid after the buffer is closed.
            encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)

        img = cv2.imdecode(encoded, 1)
        if img is None:
            raise RuntimeError("failed to decode plot image")

        return np.asarray(cv2.split(img), dtype=np.float32)
    finally:
        # Always release the figure, including when saving or decoding
        # fails; otherwise repeated failures accumulate open figures.
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_func_ensemble(data=None, src="", ind_inference=0):
    """Turn a two-component sample into a pair of rendered plot images.

    ``data[0]`` and ``data[1]`` are read as float32 arrays and combined into
    a complex signal (``data[0] + 1j * data[1]``). Two plots are rendered
    with ``_render_plot``: the real part of that signal and its magnitude.

    Args:
        data: Indexable whose first two items are the real and imaginary
            components.
        src: Not used by this function.
        ind_inference: Not used by this function.

    Returns:
        ``[real_image, magnitude_image]`` on success, or ``None`` if any
        exception was raised (its message is printed).
    """
    try:
        from matplotlib import pyplot

        # Force the non-interactive backend before any figure is created.
        matplotlib.use("Agg")
        pyplot.ioff()

        re_part = np.asarray(data[0], dtype=np.float32)
        im_part = np.asarray(data[1], dtype=np.float32)
        complex_signal = re_part + 1j * im_part

        rendered_real = _render_plot(complex_signal.real)
        rendered_magnitude = _render_plot(np.abs(complex_signal))

        cv2.destroyAllWindows()
        gc.collect()

        print("Подготовка данных завершена")
        print()
    except Exception as exc:
        print(str(exc))
        return None

    return [rendered_real, rendered_magnitude]
|
||||||
|
|
||||||
|
|
||||||
|
def build_func_ensemble(file_model="", file_config="", num_classes=None):
    """Build the ResNet-18 + ResNet-50 ensemble and load its weights.

    Both backbones are created without pretrained weights and their final
    ``fc`` layers are replaced with ``num_classes``-way linear heads. The
    ensemble concatenates the two heads' outputs and maps them through one
    more linear layer. The state dict is then loaded from ``file_model`` and
    the model is switched to eval mode.

    Args:
        file_model: Path passed to ``torch.load`` for the ensemble state dict.
        file_config: Not used by this function.
        num_classes: Currently ignored; see the note in the body.

    Returns:
        The loaded ensemble module, or ``None`` if any exception was raised
        (its message is printed).
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()
        torch.cuda.empty_cache()

        # NOTE(review): the caller-supplied ``num_classes`` is overridden, so
        # the heads are always 2-way. Kept as-is because the saved state dict
        # must match these layer shapes; confirm whether the argument should
        # be honoured.
        num_classes = 2

        # ``weights=None`` is the non-deprecated spelling of
        # ``pretrained=False`` (torchvision >= 0.13).
        model1 = models.resnet18(weights=None)
        model2 = models.resnet50(weights=None)

        model1.fc = nn.Linear(model1.fc.in_features, num_classes)
        model2.fc = nn.Linear(model2.fc.in_features, num_classes)

        class Ensemble(nn.Module):
            """Two backbones whose outputs are concatenated and fused by ``fc``."""

            def __init__(self, model1, model2):
                super().__init__()
                # Attribute names are part of the state-dict keys; do not rename.
                self.model1 = model1
                self.model2 = model2
                self.fc = nn.Linear(2 * num_classes, num_classes)

            def forward(self, x):
                # A list/tuple supplies one input per backbone (a single-item
                # sequence is reused for both); a bare tensor feeds both.
                if isinstance(x, (list, tuple)):
                    x1 = x[0]
                    x2 = x[1] if len(x) > 1 else x[0]
                else:
                    x1 = x
                    x2 = x
                y1 = self.model1(x1)
                y2 = self.model2(x2)
                y = torch.cat((y1, y2), dim=1)
                return self.fc(y)

        model = Ensemble(model1, model2)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device != "cpu":
            model = model.to(device)
        model.load_state_dict(torch.load(file_model, map_location=device))
        model.eval()

        cv2.destroyAllWindows()
        gc.collect()

        print("Инициализация модели завершена")
        print()
        return model

    except Exception as exc:
        print(str(exc))
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def inference_func_ensemble(data=None, model=None, mapping=None, shablon=""):
    """Run the ensemble on one sample and report the class and confidence.

    Args:
        data: Either a list/tuple with at least two array-likes (one per
            model input) or a single array-like that is fed to both inputs.
        model: Callable taking a two-element list of batched tensors and
            returning per-class scores.
        mapping: Indexable from class index to the label that is printed and
            returned.
        shablon: Text inserted into the printed messages.

    Returns:
        ``[prediction, probability]``, where ``probability`` is the softmax
        score of the top class rounded to two decimals, or ``None`` if any
        exception was raised (its message is printed).
    """

    def _as_batch(sample, target):
        # tensor -> CPU -> add leading batch axis -> target device -> float32
        return torch.tensor(sample).cpu().unsqueeze(0).to(target).float()

    try:
        cv2.destroyAllWindows()
        gc.collect()
        torch.cuda.empty_cache()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        has_two_inputs = isinstance(data, (list, tuple)) and len(data) >= 2
        if has_two_inputs:
            inputs = [_as_batch(data[0], device), _as_batch(data[1], device)]
        else:
            shared = _as_batch(data, device)
            inputs = [shared, shared]

        with torch.no_grad():
            output = model(inputs)
            _, predict = torch.max(output.data, 1)

        class_index = int(np.asarray(predict.cpu())[0])
        prediction = mapping[class_index]
        print("PREDICTION" + shablon + ": " + prediction)

        # Numerically stable softmax over the single sample's scores.
        output = output.cpu()
        label = np.asarray(np.argmax(output, axis=1))[0]
        scores = np.asarray(torch.squeeze(output, 0))
        shifted = np.exp(scores - np.max(scores))
        probability = round((shifted / shifted.sum())[label], 2)

        cv2.destroyAllWindows()
        gc.collect()

        print("Уверенность" + shablon + " в предсказании: " + str(probability))
        print("Инференс завершен")
        print()
    except Exception as exc:
        print(str(exc))
        return None

    return [prediction, probability]
|
||||||
|
|
||||||
|
|
||||||
|
def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inference=0, data=None):
    """Save preview PNGs of the two input images for early inferences.

    When ``ind_inference`` is at most 100 and ``data`` is a list/tuple with at
    least two items, ``data[0]`` and ``data[1]`` are each moved from
    channel-first to channel-last layout, shown with ``imshow`` and saved as
    ``<src>_inference_<ind>_<prediction>_<real|mod>_<model_id>_<model_type>.png``.

    Args:
        src: Path prefix for the output files.
        model_type: Suffix appended to the file name.
        prediction: Predicted label embedded in the file name.
        model_id: Model identifier embedded in the file name.
        ind_inference: Inference counter; previews are written only while it
            is <= 100.
        data: Pair of channel-first image arrays.

    Returns:
        ``None`` in all cases. Exceptions are caught and their message printed.
    """
    try:
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")
        plt.ioff()

        if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2:
            for tag, image in (("real", data[0]), ("mod", data[1])):
                fig, ax = plt.subplots()
                try:
                    # NOTE(review): imshow clips float RGB input to [0, 1]; if
                    # these arrays are float32 in the 0-255 range the preview
                    # will be saturated. Confirm whether scaling is intended.
                    ax.imshow(np.moveaxis(image, 0, -1))
                    fig.savefig(
                        f"{src}_inference_{ind_inference}_{prediction}_{tag}_{model_id}_{model_type}.png"
                    )
                finally:
                    # Release the figure even if drawing or saving fails.
                    plt.close(fig)
                cv2.destroyAllWindows()
                gc.collect()

        # Close whatever figure is still current; this is a no-op when none
        # exists, so no throwaway figure is created just to be destroyed.
        plt.close()
        cv2.destroyAllWindows()
        gc.collect()

        print("Постобработка завершена")
        print()

    except Exception as exc:
        print(str(exc))

    return None
|
||||||
@ -0,0 +1,523 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "e1db882b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n",
|
||||||
|
" warnings.warn(\"Unable to import Axes3D. This may be due to multiple versions of \"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cuda\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"220"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"from torch import default_generator, randperm\n",
|
||||||
|
"from torch.utils.data.dataset import Subset\n",
|
||||||
|
"import torchvision.transforms as transforms\n",
|
||||||
|
"from torchvision.io import read_image\n",
|
||||||
|
"from importlib import import_module\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from torchvision import models\n",
|
||||||
|
"import torch, torchvision\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib\n",
|
||||||
|
"import os, shutil\n",
|
||||||
|
"import mlconfig\n",
|
||||||
|
"import random\n",
|
||||||
|
"import shutil\n",
|
||||||
|
"import timeit\n",
|
||||||
|
"import copy\n",
|
||||||
|
"import time\n",
|
||||||
|
"import cv2\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import io\n",
|
||||||
|
"import gc\n",
|
||||||
|
"\n",
|
||||||
|
"plt.rcParams[\"savefig.bbox\"] = 'tight'\n",
|
||||||
|
"torch.manual_seed(1)\n",
|
||||||
|
"#matplotlib.use('Agg')\n",
|
||||||
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||||
|
"print(device)\n",
|
||||||
|
"torch.cuda.empty_cache()\n",
|
||||||
|
"cv2.destroyAllWindows()\n",
|
||||||
|
"gc.collect()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "8e009995",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model,selected_freq):\n",
|
||||||
|
" num_samples_per_class = num_samples // num_classes\n",
|
||||||
|
"\n",
|
||||||
|
" #----------Создаём папку для сохранения результатов обучения--------------\n",
|
||||||
|
" os.makedirs(\"models\", exist_ok=True)\n",
|
||||||
|
" ind = 1\n",
|
||||||
|
" while True:\n",
|
||||||
|
" if os.path.exists(\"models/\" + model_name + str(ind)):\n",
|
||||||
|
" ind += 1\n",
|
||||||
|
" else:\n",
|
||||||
|
" os.mkdir(\"models/\" + model_name + str(ind))\n",
|
||||||
|
" path_res = \"models/\" + model_name + str(ind) + '/'\n",
|
||||||
|
" break\n",
|
||||||
|
" \n",
|
||||||
|
" #----------Создаём файл dataset.csv для обучения--------------\n",
|
||||||
|
" \n",
|
||||||
|
" pd_columns = ['file_name']\n",
|
||||||
|
" df = pd.DataFrame(columns=pd_columns)\n",
|
||||||
|
" \n",
|
||||||
|
" subdirs = os.listdir(path_dataset)\n",
|
||||||
|
" \n",
|
||||||
|
" for subdir in subdirs:\n",
|
||||||
|
" freq_dir = os.path.join(path_dataset, subdir, str(selected_freq)+\"_jpg\")\n",
|
||||||
|
" if not os.path.isdir(freq_dir):\n",
|
||||||
|
" print(\"Error1\")\n",
|
||||||
|
" continue\n",
|
||||||
|
" \n",
|
||||||
|
" files_k=[f for f in os.listdir(freq_dir)]\n",
|
||||||
|
" print(len(files_k))\n",
|
||||||
|
" \n",
|
||||||
|
" files = [\n",
|
||||||
|
" f for f in os.listdir(freq_dir)\n",
|
||||||
|
" if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')\n",
|
||||||
|
" ]\n",
|
||||||
|
" num_samples_per_class = min(num_samples_per_class, len(files))\n",
|
||||||
|
" print(f\"num_samples per class {subdir} is {num_samples_per_class}\")\n",
|
||||||
|
"\n",
|
||||||
|
" for subdir in subdirs:\n",
|
||||||
|
" freq_dir = os.path.join(path_dataset, subdir, str(selected_freq)+\"_jpg\")\n",
|
||||||
|
" if not os.path.isdir(freq_dir):\n",
|
||||||
|
" print(\"Error1\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" files = [\n",
|
||||||
|
" f for f in os.listdir(freq_dir)\n",
|
||||||
|
" if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')\n",
|
||||||
|
" ]\n",
|
||||||
|
" random.shuffle(files)\n",
|
||||||
|
" files_to_process = files[:num_samples_per_class]\n",
|
||||||
|
"\n",
|
||||||
|
" for file in files_to_process:\n",
|
||||||
|
" row = pd.DataFrame({\n",
|
||||||
|
" pd_columns[0]: [str(os.path.join(freq_dir, file))]\n",
|
||||||
|
" })\n",
|
||||||
|
" df = pd.concat([df, row], ignore_index=True)\n",
|
||||||
|
"\n",
|
||||||
|
" dataset_csv_path = os.path.join(path_res, 'dataset.csv')\n",
|
||||||
|
" df.to_csv(dataset_csv_path, index=False)\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(dataset_csv_path):\n",
|
||||||
|
" raise RuntimeError(f'dataset.csv was not created: {dataset_csv_path}')\n",
|
||||||
|
" #----------Импортируем параметры для обучения--------------\n",
|
||||||
|
" \n",
|
||||||
|
" def load_function(attr):\n",
|
||||||
|
" module_, func = attr.rsplit('.', maxsplit=1)\n",
|
||||||
|
" return getattr(import_module(module_), func)\n",
|
||||||
|
" \n",
|
||||||
|
" config = mlconfig.load('config_' + config_name + '.yaml')\n",
|
||||||
|
" \n",
|
||||||
|
" #----------Создаём класс датасета--------------\n",
|
||||||
|
" \n",
|
||||||
|
" class MyDataset(Dataset):\n",
|
||||||
|
" def __init__(self, path_dataset, csv_file):\n",
|
||||||
|
" data=[]\n",
|
||||||
|
" with open(os.path.join(path_dataset, csv_file), newline='') as csvfile:\n",
|
||||||
|
" reader = csv.reader(csvfile, delimiter=' ', quotechar='|')\n",
|
||||||
|
" for row in list(reader)[1:]:\n",
|
||||||
|
" row = str(row)\n",
|
||||||
|
" data.append(row[2: len(row)-2])\n",
|
||||||
|
" self.path_dataset = path_dataset\n",
|
||||||
|
" self.target_shape = None\n",
|
||||||
|
" self.target_hw = None\n",
|
||||||
|
" self.sig_filenames = self._validate_files(data)\n",
|
||||||
|
"\n",
|
||||||
|
" def _pair_paths(self, filename):\n",
|
||||||
|
" base = os.path.splitext(filename)[0]\n",
|
||||||
|
" if base.endswith(\"real\"):\n",
|
||||||
|
" return base + \".png\", base[:-4] + \"imag.png\"\n",
|
||||||
|
" if base.endswith(\"imag\"):\n",
|
||||||
|
" return base[:-4] + \"real.png\", base + \".png\"\n",
|
||||||
|
" return None, None\n",
|
||||||
|
"\n",
|
||||||
|
" @staticmethod\n",
|
||||||
|
" def _read_shape(path):\n",
|
||||||
|
" img = cv2.imread(path)\n",
|
||||||
|
" if img is None:\n",
|
||||||
|
" return None\n",
|
||||||
|
" return img.shape\n",
|
||||||
|
"\n",
|
||||||
|
" def _validate_files(self, filenames):\n",
|
||||||
|
" from collections import Counter\n",
|
||||||
|
"\n",
|
||||||
|
" candidates = []\n",
|
||||||
|
" dropped = []\n",
|
||||||
|
" shape_counter = Counter()\n",
|
||||||
|
"\n",
|
||||||
|
" for filename in filenames:\n",
|
||||||
|
" real_path, imag_path = self._pair_paths(filename)\n",
|
||||||
|
" if real_path is None or imag_path is None:\n",
|
||||||
|
" dropped.append((filename, \"bad file suffix\"))\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" real_shape = self._read_shape(real_path)\n",
|
||||||
|
" imag_shape = self._read_shape(imag_path)\n",
|
||||||
|
" if real_shape is None or imag_shape is None:\n",
|
||||||
|
" dropped.append((filename, f\"read failed real={real_shape} imag={imag_shape}\"))\n",
|
||||||
|
" continue\n",
|
||||||
|
" if real_shape != imag_shape:\n",
|
||||||
|
" dropped.append((filename, f\"pair shape mismatch real={real_shape} imag={imag_shape}\"))\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" shape_counter[real_shape] += 1\n",
|
||||||
|
" candidates.append((filename, real_shape))\n",
|
||||||
|
"\n",
|
||||||
|
" if not candidates:\n",
|
||||||
|
" raise RuntimeError(\"No valid image pairs left after shape validation\")\n",
|
||||||
|
"\n",
|
||||||
|
" preferred_shape = (1600, 1600, 3)\n",
|
||||||
|
" if shape_counter.get(preferred_shape, 0) > 0:\n",
|
||||||
|
" target_shape = preferred_shape\n",
|
||||||
|
" else:\n",
|
||||||
|
" target_shape = shape_counter.most_common(1)[0][0]\n",
|
||||||
|
"\n",
|
||||||
|
" self.target_shape = target_shape\n",
|
||||||
|
" self.target_hw = target_shape[:2]\n",
|
||||||
|
" valid = [filename for filename, _shape in candidates]\n",
|
||||||
|
" resized_count = sum(1 for _filename, shape in candidates if shape != target_shape)\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"[dataset-shape-filter] shape_distribution={dict(shape_counter)}\")\n",
|
||||||
|
" print(\n",
|
||||||
|
" f\"[dataset-shape-filter] target_shape={target_shape} \"\n",
|
||||||
|
" f\"kept={len(valid)} dropped={len(dropped)} will_resize={resized_count}\"\n",
|
||||||
|
" )\n",
|
||||||
|
" for filename, reason in dropped[:20]:\n",
|
||||||
|
" print(f\"[dataset-shape-filter] drop {filename}: {reason}\")\n",
|
||||||
|
" if len(dropped) > 20:\n",
|
||||||
|
" print(f\"[dataset-shape-filter] ... {len(dropped) - 20} more dropped\")\n",
|
||||||
|
"\n",
|
||||||
|
" return valid\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.sig_filenames)\n",
|
||||||
|
"\n",
|
||||||
|
" def _read_image_chw(self, path):\n",
|
||||||
|
" img = cv2.imread(path)\n",
|
||||||
|
" if img is None:\n",
|
||||||
|
" raise RuntimeError(f\"failed to read image: {path}\")\n",
|
||||||
|
" if img.shape[:2] != self.target_hw:\n",
|
||||||
|
" img = cv2.resize(\n",
|
||||||
|
" img,\n",
|
||||||
|
" (self.target_hw[1], self.target_hw[0]),\n",
|
||||||
|
" interpolation=cv2.INTER_AREA,\n",
|
||||||
|
" )\n",
|
||||||
|
" return np.asarray(cv2.split(img), dtype=np.float32)\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" real_path, imag_path = self._pair_paths(self.sig_filenames[idx])\n",
|
||||||
|
" image_real = self._read_image_chw(real_path)\n",
|
||||||
|
" image_imag = self._read_image_chw(imag_path)\n",
|
||||||
|
"\n",
|
||||||
|
" path_parts = set(self.sig_filenames[idx].split('/'))\n",
|
||||||
|
" if 'drone' in path_parts:\n",
|
||||||
|
" label = torch.tensor(0)\n",
|
||||||
|
" elif 'noise' in path_parts:\n",
|
||||||
|
" label = torch.tensor(1)\n",
|
||||||
|
" else:\n",
|
||||||
|
" raise RuntimeError(f\"cannot infer label from path: {self.sig_filenames[idx]}\")\n",
|
||||||
|
"\n",
|
||||||
|
" return image_real, image_imag, label\n",
|
||||||
|
"\n",
|
||||||
|
" #----------Создаём датасет--------------\n",
|
||||||
|
" \n",
|
||||||
|
" dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')\n",
|
||||||
|
" train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n",
|
||||||
|
" batch_size = config.batch_size\n",
|
||||||
|
" train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
|
||||||
|
" valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=False)\n",
|
||||||
|
" \n",
|
||||||
|
" dataloaders = {}\n",
|
||||||
|
" dataloaders['train'] = train_dataloader\n",
|
||||||
|
" dataloaders['val'] = valid_dataloader\n",
|
||||||
|
" dataset_sizes = {}\n",
|
||||||
|
" dataset_sizes['train'] = len(train_set)\n",
|
||||||
|
" dataset_sizes['val'] = len(valid_set)\n",
|
||||||
|
"\n",
|
||||||
|
" #----------Обучаем модель--------------\n",
|
||||||
|
"\n",
|
||||||
|
" val_loss = []\n",
|
||||||
|
" val_acc = []\n",
|
||||||
|
" train_loss = []\n",
|
||||||
|
" train_acc = []\n",
|
||||||
|
" epochs = config.epoch\n",
|
||||||
|
" min_delta = 1e-4\n",
|
||||||
|
" \n",
|
||||||
|
" best_val_loss = float('inf')\n",
|
||||||
|
" best_model = copy.deepcopy(model.state_dict())\n",
|
||||||
|
" limit = config.limit\n",
|
||||||
|
" ind_limit = 0\n",
|
||||||
|
" epoch_limit = epochs\n",
|
||||||
|
" \n",
|
||||||
|
" start = timeit.default_timer()\n",
|
||||||
|
" for epoch in range(1, epochs+1):\n",
|
||||||
|
" print(f\"Epoch : {epoch}\\n\")\n",
|
||||||
|
" \n",
|
||||||
|
" for phase in ['train', 'val']:\n",
|
||||||
|
" if phase == 'train':\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" else:\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
" running_loss = 0.0\n",
|
||||||
|
" running_corrects = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for (img1, img2, label) in tqdm(dataloaders[phase]):\n",
|
||||||
|
" img1, img2, label = img1.to(device), img2.to(device), label.to(device)\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.set_grad_enabled(phase == 'train'):\n",
|
||||||
|
" output = model([img1, img2])\n",
|
||||||
|
" _, pred = torch.max(output.data, 1)\n",
|
||||||
|
" loss = criterion(output, label)\n",
|
||||||
|
" if phase == 'train':\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" \n",
|
||||||
|
" running_loss += loss.item() * img1.size(0)\n",
|
||||||
|
" running_corrects += torch.sum(pred == label.data)\n",
|
||||||
|
" \n",
|
||||||
|
" epoch_loss = running_loss / dataset_sizes[phase]\n",
|
||||||
|
" epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
|
||||||
|
" \n",
|
||||||
|
" print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
|
||||||
|
" \n",
|
||||||
|
" if phase == 'train':\n",
|
||||||
|
" train_loss.append(epoch_loss)\n",
|
||||||
|
" train_acc.append(epoch_acc)\n",
|
||||||
|
" else:\n",
|
||||||
|
" val_loss.append(epoch_loss)\n",
|
||||||
|
" val_acc.append(epoch_acc)\n",
|
||||||
|
" scheduler.step(epoch_loss)\n",
|
||||||
|
"\n",
|
||||||
|
" current_lr = optimizer.param_groups[0]['lr']\n",
|
||||||
|
" print(f'val lr: {current_lr:.8f}')\n",
|
||||||
|
"\n",
|
||||||
|
" if epoch_loss < (best_val_loss - min_delta):\n",
|
||||||
|
" ind_limit = 0\n",
|
||||||
|
" best_val_loss = epoch_loss\n",
|
||||||
|
" best_model = copy.deepcopy(model.state_dict())\n",
|
||||||
|
" torch.save(best_model, path_res + model_name + '.pth')\n",
|
||||||
|
" print(f'saved best model with val_loss={best_val_loss:.4f}')\n",
|
||||||
|
" else:\n",
|
||||||
|
" ind_limit += 1\n",
|
||||||
|
" print(f'early stopping patience: {ind_limit}/{limit}')\n",
|
||||||
|
" \n",
|
||||||
|
" if ind_limit >= limit:\n",
|
||||||
|
" break\n",
|
||||||
|
" \n",
|
||||||
|
" if ind_limit >= limit:\n",
|
||||||
|
" epoch_limit = epoch\n",
|
||||||
|
" break\n",
|
||||||
|
" \n",
|
||||||
|
" print()\n",
|
||||||
|
" \n",
|
||||||
|
" end = timeit.default_timer()\n",
|
||||||
|
" print(f\"Total time elapsed = {end - start} seconds\")\n",
|
||||||
|
" epoch_limit += 1\n",
|
||||||
|
" \n",
|
||||||
|
" #----------Вывод графиков и сохранение результатов обучения--------------\n",
|
||||||
|
" \n",
|
||||||
|
" train_acc = np.asarray(list(map(lambda x: x.item(), train_acc)))\n",
|
||||||
|
" val_acc = np.asarray(list(map(lambda x: x.item(), val_acc)))\n",
|
||||||
|
" \n",
|
||||||
|
" np.save(path_res+'train_acc.npy', train_acc)\n",
|
||||||
|
" np.save(path_res+'val_acc.npy', val_acc)\n",
|
||||||
|
" np.save(path_res+'train_loss.npy', train_loss)\n",
|
||||||
|
" np.save(path_res+'val_loss.npy', val_loss)\n",
|
||||||
|
" \n",
|
||||||
|
" plt.figure()\n",
|
||||||
|
" plt.plot(range(1,epoch_limit), train_loss, color='blue')\n",
|
||||||
|
" plt.plot(range(1,epoch_limit), val_loss, color='red')\n",
|
||||||
|
" plt.xlabel('Epoch')\n",
|
||||||
|
" plt.ylabel('Loss') \n",
|
||||||
|
" \n",
|
||||||
|
" plt.title('Loss Curve')\n",
|
||||||
|
" plt.legend(['Train Loss', 'Validation Loss'])\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
" plt.clf()\n",
|
||||||
|
" plt.cla()\n",
|
||||||
|
" plt.close()\n",
|
||||||
|
" \n",
|
||||||
|
" plt.figure()\n",
|
||||||
|
" plt.plot(range(1,epoch_limit), train_acc, color='blue')\n",
|
||||||
|
" plt.plot(range(1,epoch_limit), val_acc, color='red')\n",
|
||||||
|
" plt.xlabel('Epoch')\n",
|
||||||
|
" plt.ylabel('Accuracy')\n",
|
||||||
|
" plt.title('Accuracy Curve')\n",
|
||||||
|
" plt.legend(['Train Accuracy', 'Validation Accuracy'])\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
" \n",
|
||||||
|
" plt.clf()\n",
|
||||||
|
" plt.cla()\n",
|
||||||
|
" plt.close()\n",
|
||||||
|
" torch.cuda.empty_cache()\n",
|
||||||
|
" cv2.destroyAllWindows()\n",
|
||||||
|
" del model\n",
|
||||||
|
" gc.collect()\n",
|
||||||
|
"\n",
|
||||||
|
" return path_res, model_name"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "bbbe7fea",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
|
||||||
|
" warnings.warn(msg)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "RuntimeError",
|
||||||
|
"evalue": "No valid image pairs left after shape validation",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||||
|
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 45\u001b[39m\n\u001b[32m 41\u001b[39m model = model.to(device)\n\u001b[32m 42\u001b[39m \n\u001b[32m 43\u001b[39m \u001b[38;5;66;03m#----------Создания датасета и обучение модели--------------\u001b[39;00m\n\u001b[32m 44\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 10000, path_dataset = \"/mnt/data/Dataset_overlay\", \n\u001b[32m 46\u001b[39m selected_freq=\u001b[32m1200\u001b[39m,model_name = config_name+\u001b[33m\"1200_\"\u001b[39m, config_name = config_name, model=model)\n\u001b[32m 47\u001b[39m \n\u001b[32m 48\u001b[39m \n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 183\u001b[39m, in \u001b[36mprepare_and_learning_detection\u001b[39m\u001b[34m(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq)\u001b[39m\n\u001b[32m 179\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m image_real, image_imag, label\n\u001b[32m 180\u001b[39m \n\u001b[32m 181\u001b[39m \u001b[38;5;66;03m#----------Создаём датасет--------------\u001b[39;00m\n\u001b[32m 182\u001b[39m \n\u001b[32m--> \u001b[39m\u001b[32m183\u001b[39m dataset = MyDataset(path_dataset=path_res, csv_file=\u001b[33m'dataset.csv'\u001b[39m)\n\u001b[32m 184\u001b[39m train_set, valid_set = torch.utils.data.random_split(dataset, [\u001b[32m0.7\u001b[39m, \u001b[32m0.3\u001b[39m], generator=torch.Generator().manual_seed(\u001b[32m42\u001b[39m))\n\u001b[32m 185\u001b[39m batch_size = config.batch_size\n\u001b[32m 186\u001b[39m train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 83\u001b[39m, in \u001b[36mprepare_and_learning_detection.<locals>.MyDataset.__init__\u001b[39m\u001b[34m(self, path_dataset, csv_file)\u001b[39m\n\u001b[32m 79\u001b[39m data.append(row[\u001b[32m2\u001b[39m: len(row)-\u001b[32m2\u001b[39m])\n\u001b[32m 80\u001b[39m self.path_dataset = path_dataset\n\u001b[32m 81\u001b[39m self.target_shape = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 82\u001b[39m self.target_hw = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m83\u001b[39m self.sig_filenames = self._validate_files(data)\n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 126\u001b[39m, in \u001b[36mprepare_and_learning_detection.<locals>.MyDataset._validate_files\u001b[39m\u001b[34m(self, filenames)\u001b[39m\n\u001b[32m 122\u001b[39m shape_counter[real_shape] += \u001b[32m1\u001b[39m\n\u001b[32m 123\u001b[39m candidates.append((filename, real_shape))\n\u001b[32m 124\u001b[39m \n\u001b[32m 125\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;01mnot\u001b[39;00m candidates:\n\u001b[32m--> \u001b[39m\u001b[32m126\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(\u001b[33m\"No valid image pairs left after shape validation\"\u001b[39m)\n\u001b[32m 127\u001b[39m \n\u001b[32m 128\u001b[39m preferred_shape = (\u001b[32m1600\u001b[39m, \u001b[32m1600\u001b[39m, \u001b[32m3\u001b[39m)\n\u001b[32m 129\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m shape_counter.get(preferred_shape, \u001b[32m0\u001b[39m) > \u001b[32m0\u001b[39m:\n",
|
||||||
|
"\u001b[31mRuntimeError\u001b[39m: No valid image pairs left after shape validation"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"torch.cuda.empty_cache()\n",
|
||||||
|
"cv2.destroyAllWindows()\n",
|
||||||
|
"gc.collect()\n",
|
||||||
|
"\n",
|
||||||
|
"config_name = \"ensemble\"\n",
|
||||||
|
" \n",
|
||||||
|
"def load_function(attr):\n",
|
||||||
|
" module_, func = attr.rsplit('.', maxsplit=1)\n",
|
||||||
|
" return getattr(import_module(module_), func)\n",
|
||||||
|
" \n",
|
||||||
|
"config = mlconfig.load('config_' + config_name + '.yaml')\n",
|
||||||
|
"\n",
|
||||||
|
"model1 = models.resnet18(pretrained=False)\n",
|
||||||
|
"model2 = models.resnet50(pretrained=False)\n",
|
||||||
|
"\n",
|
||||||
|
"num_classes = 2\n",
|
||||||
|
"\n",
|
||||||
|
"model1.fc = nn.Linear(model1.fc.in_features, num_classes)\n",
|
||||||
|
"model2.fc = nn.Linear(model2.fc.in_features, num_classes)\n",
|
||||||
|
"\n",
|
||||||
|
"class Ensemble(nn.Module):\n",
|
||||||
|
" def __init__(self, model1, model2):\n",
|
||||||
|
" super(Ensemble, self).__init__()\n",
|
||||||
|
" self.model1 = model1\n",
|
||||||
|
" self.model2 = model2\n",
|
||||||
|
" self.fc = nn.Linear(2 * num_classes, num_classes)\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x1 = self.model1(x[0])\n",
|
||||||
|
" x2 = self.model2(x[1])\n",
|
||||||
|
" x = torch.cat((x1, x2), dim=1)\n",
|
||||||
|
" x = self.fc(x)\n",
|
||||||
|
" return x\n",
|
||||||
|
"model = Ensemble(model1, model2)\n",
|
||||||
|
"\n",
|
||||||
|
"optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)\n",
|
||||||
|
"criterion = load_function(config.loss_function.name)()\n",
|
||||||
|
"scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)\n",
|
||||||
|
"\n",
|
||||||
|
"if device != 'cpu':\n",
|
||||||
|
" model = model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"#----------Создания датасета и обучение модели--------------\n",
|
||||||
|
"\n",
|
||||||
|
"path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 10000, path_dataset = \"/mnt/data/Dataset_overlay\", \n",
|
||||||
|
" selected_freq=2400,model_name = config_name+\"2400_\", config_name = config_name, model=model)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"torch.cuda.empty_cache()\n",
|
||||||
|
"cv2.destroyAllWindows()\n",
|
||||||
|
"del model\n",
|
||||||
|
"gc.collect()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "usr",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@ -0,0 +1,217 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
# Frequencies processed when --freqs is not given on the command line.
DEFAULT_FREQS = (1200, 2400)
|
||||||
|
# File-name suffixes of the PNG companions of a sample; save_sample() renders
# tensor plane i under the i-th suffix, copy_noise_family() looks them up.
PNG_SUFFIXES = ("_real.png", "_imag.png", "_spec.png")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command-line options of the overlay dataset builder.

    Returns:
        argparse.Namespace with the source and destination roots, the
        comma-separated frequency list, the overlay strength, the
        per-frequency sample limit, the random seed and the boolean switches
        (copy_noise, overwrite, dry_run, no_progress).
    """
    summary = (
        "Build a two-class image dataset with drone signatures overlaid on noise images. "
        "The output is ready for Training_models2pic_val_loss.ipynb."
    )
    cli = argparse.ArgumentParser(description=summary)
    add = cli.add_argument

    # Where the inputs live and where the dataset is written.
    add("--drone-root", default="/mnt/data/Dataset/drone")
    add("--noise-img-root", default="/mnt/data/Dataset_img/noise")
    add("--output-root", default="/mnt/data/Dataset_overlay")

    # Sampling and blending parameters.
    default_freqs = ",".join(map(str, DEFAULT_FREQS))
    add("--freqs", default=default_freqs)
    add(
        "--alpha",
        type=float,
        default=1.0,
        help="Overlay strength: 1.0 keeps the darkest drone/noise pixels.",
    )
    add(
        "--limit-per-freq",
        type=int,
        default=0,
        help="0 means use all available noise images per frequency.",
    )
    add("--seed", type=int, default=42)

    # Behaviour switches.
    add(
        "--copy-noise",
        action="store_true",
        help="Copy noise files instead of hardlinking them.",
    )
    add("--overwrite", action="store_true")
    add("--dry-run", action="store_true")
    add("--no-progress", action="store_true", help="Disable tqdm progress bars.")

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_freqs(value):
    """Turn a comma-separated string into a list of integer frequencies.

    Whitespace around each item is ignored and empty items are skipped.
    A non-numeric item makes ``int`` raise ``ValueError``.
    """
    freqs = []
    for piece in value.split(","):
        token = piece.strip()
        if token:
            freqs.append(int(token))
    return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_drone_npys(root, freq):
    """Return the sorted, de-duplicated drone ``.npy`` files for one frequency.

    Two locations below ``root`` are searched: ``<freq>/`` recursively and
    ``<freq>_jpg/`` at its top level only. Entries that are not regular files
    are dropped.
    """
    base = Path(root)
    found = set()
    for hit in (base / str(freq)).rglob("*.npy"):
        if hit.is_file():
            found.add(hit)
    for hit in (base / f"{freq}_jpg").glob("*.npy"):
        if hit.is_file():
            found.add(hit)
    return sorted(found)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_noise_npys(root, freq):
    """Return the sorted, de-duplicated noise ``.npy`` files for one frequency.

    ``<freq>_jpg/`` is scanned at its top level and ``<freq>/`` recursively,
    both relative to ``root``. Only regular files are kept.
    """
    base = Path(root)
    searches = (
        (base / f"{freq}_jpg").glob("*.npy"),
        (base / str(freq)).rglob("*.npy"),
    )
    unique = {entry for hits in searches for entry in hits if entry.is_file()}
    return sorted(unique)
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_tensor(path):
    """Load a ``.npy`` image and return its first three planes as float32.

    A 2-D array is replicated into three identical planes. A 3-D array whose
    last axis has length 1 or 3 while its first axis does not is treated as
    channels-last and transposed to channels-first. A single-plane array is
    repeated three times.

    Raises:
        ValueError: if the array is neither 2-D nor 3-D, or if it ends up
            with fewer than three planes.
    """
    data = np.load(path)

    rank = data.ndim
    if rank == 2:
        data = np.stack((data, data, data), axis=0)
    elif rank == 3:
        lead, trail = data.shape[0], data.shape[-1]
        if trail in (1, 3) and lead not in (1, 3):
            data = np.moveaxis(data, -1, 0)
    else:
        raise ValueError(f"expected 3D image tensor, got shape={data.shape} path={path}")

    planes = data.shape[0]
    if planes == 1:
        data = np.repeat(data, 3, axis=0)
    elif planes < 3:
        raise ValueError(f"expected at least 3 channels, got shape={data.shape} path={path}")

    return data[:3].astype(np.float32, copy=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_like(arr, shape):
    """Resize a channels-first array to ``shape`` = (channels, height, width).

    When ``arr`` already has the requested shape it is returned unchanged
    (the same object). Otherwise each of the first ``channels`` planes is
    resized with bilinear interpolation and the result is returned as a new
    float32 array.
    """
    if arr.shape == shape:
        return arr

    channels, height, width = shape
    target_size = (width, height)  # cv2.resize takes (width, height)
    planes = [
        cv2.resize(plane, target_size, interpolation=cv2.INTER_LINEAR)
        for plane in arr[:channels]
    ]
    return np.asarray(planes, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def overlay_tensors(noise, drone, alpha):
    """Blend a drone image into a noise image.

    The drone tensor is first resized to the shape of the noise tensor. The
    element-wise minimum of the two is then mixed with the noise using weight
    ``alpha`` (``alpha == 1.0`` yields the pure minimum, ``alpha == 0.0`` the
    unchanged noise). The result is clipped to [0, 255] and returned as
    float32.
    """
    fitted = resize_like(drone, noise.shape)
    darkest = np.minimum(noise, fitted)
    noise_part = (1.0 - alpha) * noise
    overlay_part = alpha * darkest
    blend = noise_part + overlay_part
    return np.clip(blend, 0, 255).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def to_uint8(channel):
    """Convert an image plane to ``uint8``.

    Args:
        channel: array-like of numbers; converted to float32 internally.

    Returns:
        ``uint8`` array of the same shape:

        * all zeros when the input holds no finite value;
        * the values truncated to integers when every finite value already
          lies in [0, 255];
        * all zeros when the finite values are constant but outside that
          range;
        * otherwise the values min-max normalised to [0, 255] using the
          finite minimum and maximum.

        Non-finite entries never take part in the range computation. In the
        result NaN becomes 0, ``+inf`` becomes 255 and ``-inf`` becomes 0.
    """
    arr = np.asarray(channel, dtype=np.float32)
    finite_mask = np.isfinite(arr)
    if not finite_mask.any():
        return np.zeros(arr.shape, dtype=np.uint8)

    finite = arr[finite_mask]
    min_v = float(finite.min())
    max_v = float(finite.max())

    if min_v >= 0.0 and max_v <= 255.0:
        scaled = arr
    elif max_v == min_v:
        return np.zeros(arr.shape, dtype=np.uint8)
    else:
        scaled = (arr - min_v) / (max_v - min_v) * 255.0

    # Casting NaN to uint8 is undefined and warns; clip alone only tames the
    # infinities, so NaN is mapped to 0 explicitly before the cast.
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def save_notebook_style_png(path, channel):
    """Render one image plane with ``plt.imshow`` and save it to ``path``.

    A fresh 16x16 inch figure is created for every call. It is closed in a
    ``finally`` clause so that a failing ``imshow``/``savefig`` cannot leave
    the figure registered in pyplot.

    Args:
        path: destination file name handed to ``Figure.savefig``.
        channel: 2-D array-like shown with matplotlib's default colormap.
    """
    fig = plt.figure(figsize=(16, 16))
    try:
        # The new figure is pyplot's current figure, so imshow draws into it.
        plt.imshow(channel)
        fig.savefig(path)
    finally:
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def save_sample(base_path, tensor):
    """Write one sample as ``<base_path>.npy`` plus one PNG per suffix.

    The tensor is stored as float32. Plane ``i`` of the tensor is rendered to
    ``<base_path><PNG_SUFFIXES[i]>``.
    """
    stem = str(base_path)
    np.save(stem + ".npy", tensor.astype(np.float32))
    for plane_index, png_suffix in enumerate(PNG_SUFFIXES):
        png_path = stem + png_suffix
        save_notebook_style_png(png_path, tensor[plane_index])
|
||||||
|
|
||||||
|
|
||||||
|
def link_or_copy(src, dst, copy_file):
    """Place ``src`` at ``dst`` by hard link or by copy.

    The parent directory of ``dst`` is created when missing. An existing
    ``dst`` is left untouched. With ``copy_file`` true the file is copied;
    otherwise a hard link is attempted and a copy is made only if linking
    fails with ``OSError``.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        return

    if not copy_file:
        try:
            os.link(src, dst)
        except OSError:
            pass  # linking failed: fall through to the copy below
        else:
            return

    shutil.copy2(src, dst)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_noise_family(noise_npy, out_dir, copy_file):
    """Link or copy a noise ``.npy`` file and its PNG companions to ``out_dir``.

    Companions are the files next to ``noise_npy`` that share its name
    without the ``.npy`` ending plus one of ``PNG_SUFFIXES``. Missing
    companions are skipped.
    """
    file_name = noise_npy.name
    ending = ".npy"
    stem = file_name[: -len(ending)] if file_name.endswith(ending) else file_name

    link_or_copy(noise_npy, out_dir / file_name, copy_file)

    for png_suffix in PNG_SUFFIXES:
        companion = noise_npy.with_name(stem + png_suffix)
        if not companion.exists():
            continue
        link_or_copy(companion, out_dir / companion.name, copy_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build the overlay dataset described by the command-line arguments.

    For every requested frequency the noise files are shuffled with the
    seeded RNG and truncated to the per-frequency limit. Each selected noise
    image is blended with a drone image, drone files being cycled in sorted
    order. The blend is written below ``<output>/drone/<freq>_jpg`` and the
    untouched noise files are linked or copied to ``<output>/noise/<freq>_jpg``.
    Finally ``overlay_manifest.csv`` listing every generated sample is written
    to the output root. With ``--dry-run`` only the per-frequency counts are
    printed and nothing is written.

    Raises:
        RuntimeError: if no frequency was requested, or if a frequency has no
            drone or no noise ``.npy`` files.
    """
    args = parse_args()
    random.seed(args.seed)
    freqs = parse_freqs(args.freqs)
    if not freqs:
        # Reject an empty list before --overwrite removes anything; otherwise
        # the run would delete the old dataset and then fail while opening
        # the manifest in a directory that no longer exists.
        raise RuntimeError("no frequencies given via --freqs")
    output_root = Path(args.output_root)
    manifest_rows = []

    if args.overwrite and output_root.exists() and not args.dry_run:
        shutil.rmtree(output_root)

    for freq in freqs:
        drone_files = collect_drone_npys(args.drone_root, freq)
        noise_files = collect_noise_npys(args.noise_img_root, freq)
        if not drone_files:
            raise RuntimeError(f"no drone npy files found for freq={freq} under {args.drone_root}")
        if not noise_files:
            raise RuntimeError(f"no noise image npy files found for freq={freq} under {args.noise_img_root}")

        # A non-positive limit means "use every noise image".
        count = len(noise_files) if args.limit_per_freq <= 0 else min(args.limit_per_freq, len(noise_files))
        selected_noise = noise_files[:]
        random.shuffle(selected_noise)
        selected_noise = selected_noise[:count]

        out_drone_dir = output_root / "drone" / f"{freq}_jpg"
        out_noise_dir = output_root / "noise" / f"{freq}_jpg"

        print(f"freq={freq}: drone_source={len(drone_files)} noise_source={len(noise_files)} output_per_class={count}")
        if args.dry_run:
            continue

        out_drone_dir.mkdir(parents=True, exist_ok=True)
        out_noise_dir.mkdir(parents=True, exist_ok=True)

        iterator = tqdm(
            enumerate(selected_noise),
            total=count,
            desc=f"overlay freq={freq}",
            unit="sample",
            disable=args.no_progress,
        )
        for idx, noise_path in iterator:
            # Cycle through the drone files when there are fewer of them
            # than selected noise images.
            drone_path = drone_files[idx % len(drone_files)]
            noise_tensor = load_image_tensor(noise_path)
            drone_tensor = load_image_tensor(drone_path)
            mixed = overlay_tensors(noise_tensor, drone_tensor, args.alpha)

            out_base = out_drone_dir / f"overlay_{freq}_{idx:06d}"
            save_sample(out_base, mixed)
            copy_noise_family(noise_path, out_noise_dir, args.copy_noise)

            manifest_rows.append({
                "freq": freq,
                "output": str(out_base) + ".npy",
                "noise_source": str(noise_path),
                "drone_source": str(drone_path),
                "alpha": args.alpha,
            })

    if not args.dry_run:
        manifest_path = output_root / "overlay_manifest.csv"
        with manifest_path.open("w", newline="", encoding="utf-8") as fh:
            writer = csv.DictWriter(fh, fieldnames=["freq", "output", "noise_source", "drone_source", "alpha"])
            writer.writeheader()
            writer.writerows(manifest_rows)
        print(f"wrote {len(manifest_rows)} overlay samples")
        print(f"manifest: {manifest_path}")
        print(f"dataset: {output_root}")


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue