Новая версия ноутбука для обучения
parent
94856d0fb8
commit
c70a25cb8f
@ -0,0 +1,523 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "e1db882b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n",
|
||||||
|
" warnings.warn(\"Unable to import Axes3D. This may be due to multiple versions of \"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cuda\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"220"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"from torch import default_generator, randperm\n",
|
||||||
|
"from torch.utils.data.dataset import Subset\n",
|
||||||
|
"import torchvision.transforms as transforms\n",
|
||||||
|
"from torchvision.io import read_image\n",
|
||||||
|
"from importlib import import_module\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from torchvision import models\n",
|
||||||
|
"import torch, torchvision\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib\n",
|
||||||
|
"import os, shutil\n",
|
||||||
|
"import mlconfig\n",
|
||||||
|
"import random\n",
|
||||||
|
"import shutil\n",
|
||||||
|
"import timeit\n",
|
||||||
|
"import copy\n",
|
||||||
|
"import time\n",
|
||||||
|
"import cv2\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import io\n",
|
||||||
|
"import gc\n",
|
||||||
|
"\n",
|
||||||
|
# ---- Global plotting / torch runtime setup ----
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

# Pick the compute device once; all tensors below are moved to it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Start from a clean slate: release cached GPU memory, close any stray
# OpenCV windows and force a garbage-collection pass.
torch.cuda.empty_cache()
cv2.destroyAllWindows()
gc.collect()
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "8e009995",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq):
    """Build an image-pair dataset for drone detection and train ``model`` on it.

    Samples up to ``num_samples // num_classes`` real/imag spectrogram-image
    pairs per class subdirectory of ``path_dataset``, writes the selected file
    list to ``dataset.csv``, trains with early stopping, saves the best
    checkpoint plus loss/accuracy curves, and returns the results directory.

    Parameters
    ----------
    num_classes : int
        Number of target classes; fixes the per-class sample budget.
    num_samples : int
        Total sample budget across all classes.
    path_dataset : str
        Dataset root; each class subdir must contain "<selected_freq>_jpg".
    model_name : str
        Base name for the results directory and the saved ``.pth`` file.
    config_name : str
        Hyperparameters are read from ``config_<config_name>.yaml`` via mlconfig.
    model : torch.nn.Module
        Model taking a ``[real_batch, imag_batch]`` pair and returning logits.
    selected_freq : int
        Frequency selecting which "<freq>_jpg" image folder to use.

    Returns
    -------
    tuple[str, str]
        ``(path_res, model_name)`` — results dir (with trailing slash) and name.

    NOTE(review): ``optimizer``, ``criterion`` and ``scheduler`` are read from
    the notebook's global scope — they must be created before calling this.
    """
    num_samples_per_class = num_samples // num_classes

    # ---------- Create a fresh directory for this run's results ----------
    os.makedirs("models", exist_ok=True)
    ind = 1
    while True:
        if os.path.exists("models/" + model_name + str(ind)):
            ind += 1
        else:
            os.mkdir("models/" + model_name + str(ind))
            path_res = "models/" + model_name + str(ind) + '/'
            break

    # ---------- Build dataset.csv for training ----------
    pd_columns = ['file_name']
    df = pd.DataFrame(columns=pd_columns)

    subdirs = os.listdir(path_dataset)

    # First pass: shrink the per-class budget to the smallest class so the
    # resulting dataset stays balanced.
    for subdir in subdirs:
        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq) + "_jpg")
        if not os.path.isdir(freq_dir):
            # FIX: was a bare "Error1" — say which directory is missing.
            print(f"Error1: missing frequency dir {freq_dir}")
            continue

        print(len(os.listdir(freq_dir)))

        files = [
            f for f in os.listdir(freq_dir)
            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')
        ]
        num_samples_per_class = min(num_samples_per_class, len(files))
        print(f"num_samples per class {subdir} is {num_samples_per_class}")

    # Second pass: sample the same number of pairs from every class.
    for subdir in subdirs:
        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq) + "_jpg")
        if not os.path.isdir(freq_dir):
            print(f"Error1: missing frequency dir {freq_dir}")
            continue

        files = [
            f for f in os.listdir(freq_dir)
            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')
        ]
        random.shuffle(files)

        for file in files[:num_samples_per_class]:
            row = pd.DataFrame({pd_columns[0]: [str(os.path.join(freq_dir, file))]})
            df = pd.concat([df, row], ignore_index=True)

    dataset_csv_path = os.path.join(path_res, 'dataset.csv')
    df.to_csv(dataset_csv_path, index=False)

    if not os.path.exists(dataset_csv_path):
        raise RuntimeError(f'dataset.csv was not created: {dataset_csv_path}')

    # ---------- Import training parameters ----------
    config = mlconfig.load('config_' + config_name + '.yaml')

    # ---------- Dataset class ----------
    class MyDataset(Dataset):
        """Real/imag spectrogram-image pairs with labels inferred from paths."""

        def __init__(self, path_dataset, csv_file):
            # FIX: parse dataset.csv with the default csv dialect and read
            # column 0 directly; the old str(row)[2:-2] slicing hack (with
            # delimiter=' ') corrupted any path containing a space or comma.
            data = []
            with open(os.path.join(path_dataset, csv_file), newline='') as csvfile:
                reader = csv.reader(csvfile)
                next(reader, None)  # skip the 'file_name' header row
                for row in reader:
                    if row:
                        data.append(row[0])
            self.path_dataset = path_dataset
            self.target_shape = None  # (H, W, C) every image is normalized to
            self.target_hw = None     # (H, W) slice of target_shape
            self.sig_filenames = self._validate_files(data)

        def _pair_paths(self, filename):
            """Return (real_path, imag_path) for either member of a pair."""
            base = os.path.splitext(filename)[0]
            if base.endswith("real"):
                return base + ".png", base[:-4] + "imag.png"
            if base.endswith("imag"):
                return base[:-4] + "real.png", base + ".png"
            return None, None

        @staticmethod
        def _read_shape(path):
            """Shape of the image at ``path``, or None if it cannot be read."""
            img = cv2.imread(path)
            if img is None:
                return None
            return img.shape

        def _validate_files(self, filenames):
            """Drop unreadable/mismatched pairs and pick the dominant shape."""
            from collections import Counter

            candidates = []
            dropped = []
            shape_counter = Counter()

            for filename in filenames:
                real_path, imag_path = self._pair_paths(filename)
                if real_path is None or imag_path is None:
                    dropped.append((filename, "bad file suffix"))
                    continue

                real_shape = self._read_shape(real_path)
                imag_shape = self._read_shape(imag_path)
                if real_shape is None or imag_shape is None:
                    dropped.append((filename, f"read failed real={real_shape} imag={imag_shape}"))
                    continue
                if real_shape != imag_shape:
                    dropped.append((filename, f"pair shape mismatch real={real_shape} imag={imag_shape}"))
                    continue

                shape_counter[real_shape] += 1
                candidates.append((filename, real_shape))

            if not candidates:
                raise RuntimeError("No valid image pairs left after shape validation")

            # Prefer the canonical 1600x1600 RGB shape when present; otherwise
            # fall back to the most common shape in the dataset.
            preferred_shape = (1600, 1600, 3)
            if shape_counter.get(preferred_shape, 0) > 0:
                target_shape = preferred_shape
            else:
                target_shape = shape_counter.most_common(1)[0][0]

            self.target_shape = target_shape
            self.target_hw = target_shape[:2]
            valid = [filename for filename, _shape in candidates]
            resized_count = sum(1 for _filename, shape in candidates if shape != target_shape)

            print(f"[dataset-shape-filter] shape_distribution={dict(shape_counter)}")
            print(
                f"[dataset-shape-filter] target_shape={target_shape} "
                f"kept={len(valid)} dropped={len(dropped)} will_resize={resized_count}"
            )
            # FIX: report the offending filename; the old message printed the
            # literal "(unknown)" while discarding the name it already held.
            for filename, reason in dropped[:20]:
                print(f"[dataset-shape-filter] drop {filename}: {reason}")
            if len(dropped) > 20:
                print(f"[dataset-shape-filter] ... {len(dropped) - 20} more dropped")

            return valid

        def __len__(self):
            return len(self.sig_filenames)

        def _read_image_chw(self, path):
            """Read ``path``, resize to the target shape, return float32 CHW array."""
            img = cv2.imread(path)
            if img is None:
                raise RuntimeError(f"failed to read image: {path}")
            if img.shape[:2] != self.target_hw:
                # cv2.resize expects (width, height) — hence the swapped order.
                img = cv2.resize(
                    img,
                    (self.target_hw[1], self.target_hw[0]),
                    interpolation=cv2.INTER_AREA,
                )
            return np.asarray(cv2.split(img), dtype=np.float32)

        def __getitem__(self, idx):
            real_path, imag_path = self._pair_paths(self.sig_filenames[idx])
            image_real = self._read_image_chw(real_path)
            image_imag = self._read_image_chw(imag_path)

            # Infer the label from directory names: drone -> 0, noise -> 1.
            # FIX: split on os.sep instead of a hard-coded '/'.
            path_parts = set(os.path.normpath(self.sig_filenames[idx]).split(os.sep))
            if 'drone' in path_parts:
                label = torch.tensor(0)
            elif 'noise' in path_parts:
                label = torch.tensor(1)
            else:
                raise RuntimeError(f"cannot infer label from path: {self.sig_filenames[idx]}")

            return image_real, image_imag, label

    # ---------- Build dataset and loaders ----------
    dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')
    train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
    batch_size = config.batch_size
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=False)

    dataloaders = {'train': train_dataloader, 'val': valid_dataloader}
    dataset_sizes = {'train': len(train_set), 'val': len(valid_set)}

    # ---------- Training loop with early stopping ----------
    val_loss = []
    val_acc = []
    train_loss = []
    train_acc = []
    epochs = config.epoch
    min_delta = 1e-4  # minimum val-loss improvement that resets patience

    best_val_loss = float('inf')
    best_model = copy.deepcopy(model.state_dict())
    limit = config.limit  # early-stopping patience, in epochs
    ind_limit = 0

    start = timeit.default_timer()
    for epoch in range(1, epochs + 1):
        print(f"Epoch : {epoch}\n")

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for (img1, img2, label) in tqdm(dataloaders[phase]):
                img1, img2, label = img1.to(device), img2.to(device), label.to(device)
                optimizer.zero_grad()

                # Gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    output = model([img1, img2])
                    _, pred = torch.max(output.data, 1)
                    loss = criterion(output, label)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * img1.size(0)
                running_corrects += torch.sum(pred == label.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)
                # FIX: only ReduceLROnPlateau takes a metric; epoch-based
                # schedulers such as StepLR must be stepped without arguments
                # (passing the loss silently corrupted their epoch counter).
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(epoch_loss)
                else:
                    scheduler.step()

                current_lr = optimizer.param_groups[0]['lr']
                print(f'val lr: {current_lr:.8f}')

                if epoch_loss < (best_val_loss - min_delta):
                    ind_limit = 0
                    best_val_loss = epoch_loss
                    best_model = copy.deepcopy(model.state_dict())
                    torch.save(best_model, path_res + model_name + '.pth')
                    print(f'saved best model with val_loss={best_val_loss:.4f}')
                else:
                    ind_limit += 1
                    print(f'early stopping patience: {ind_limit}/{limit}')

                if ind_limit >= limit:
                    break

        if ind_limit >= limit:
            break

        print()

    end = timeit.default_timer()
    print(f"Total time elapsed = {end - start} seconds")

    # ---------- Plot curves and save results ----------
    train_acc = np.asarray([a.item() for a in train_acc])
    val_acc = np.asarray([a.item() for a in val_acc])

    np.save(path_res + 'train_acc.npy', train_acc)
    np.save(path_res + 'val_acc.npy', val_acc)
    np.save(path_res + 'train_loss.npy', train_loss)
    np.save(path_res + 'val_loss.npy', val_loss)

    # FIX: size the x-axes from the recorded histories; the old
    # range(1, epoch_limit) could mismatch the series lengths when early
    # stopping ended training mid-epoch.
    plt.figure()
    plt.plot(range(1, len(train_loss) + 1), train_loss, color='blue')
    plt.plot(range(1, len(val_loss) + 1), val_loss, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend(['Train Loss', 'Validation Loss'])
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()

    plt.figure()
    plt.plot(range(1, len(train_acc) + 1), train_acc, color='blue')
    plt.plot(range(1, len(val_acc) + 1), val_acc, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend(['Train Accuracy', 'Validation Accuracy'])
    plt.show()

    plt.clf()
    plt.cla()
    plt.close()

    # Release GPU memory and references before returning to the notebook.
    torch.cuda.empty_cache()
    cv2.destroyAllWindows()
    del model
    gc.collect()

    return path_res, model_name
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "bbbe7fea",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
|
||||||
|
" warnings.warn(msg)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n",
|
||||||
|
"Error1\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "RuntimeError",
|
||||||
|
"evalue": "No valid image pairs left after shape validation",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||||
|
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 45\u001b[39m\n\u001b[32m 41\u001b[39m model = model.to(device)\n\u001b[32m 42\u001b[39m \n\u001b[32m 43\u001b[39m \u001b[38;5;66;03m#----------Создания датасета и обучение модели--------------\u001b[39;00m\n\u001b[32m 44\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 10000, path_dataset = \"/mnt/data/Dataset_overlay\", \n\u001b[32m 46\u001b[39m selected_freq=\u001b[32m1200\u001b[39m,model_name = config_name+\u001b[33m\"1200_\"\u001b[39m, config_name = config_name, model=model)\n\u001b[32m 47\u001b[39m \n\u001b[32m 48\u001b[39m \n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 183\u001b[39m, in \u001b[36mprepare_and_learning_detection\u001b[39m\u001b[34m(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq)\u001b[39m\n\u001b[32m 179\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m image_real, image_imag, label\n\u001b[32m 180\u001b[39m \n\u001b[32m 181\u001b[39m \u001b[38;5;66;03m#----------Создаём датасет--------------\u001b[39;00m\n\u001b[32m 182\u001b[39m \n\u001b[32m--> \u001b[39m\u001b[32m183\u001b[39m dataset = MyDataset(path_dataset=path_res, csv_file=\u001b[33m'dataset.csv'\u001b[39m)\n\u001b[32m 184\u001b[39m train_set, valid_set = torch.utils.data.random_split(dataset, [\u001b[32m0.7\u001b[39m, \u001b[32m0.3\u001b[39m], generator=torch.Generator().manual_seed(\u001b[32m42\u001b[39m))\n\u001b[32m 185\u001b[39m batch_size = config.batch_size\n\u001b[32m 186\u001b[39m train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 83\u001b[39m, in \u001b[36mprepare_and_learning_detection.<locals>.MyDataset.__init__\u001b[39m\u001b[34m(self, path_dataset, csv_file)\u001b[39m\n\u001b[32m 79\u001b[39m data.append(row[\u001b[32m2\u001b[39m: len(row)-\u001b[32m2\u001b[39m])\n\u001b[32m 80\u001b[39m self.path_dataset = path_dataset\n\u001b[32m 81\u001b[39m self.target_shape = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 82\u001b[39m self.target_hw = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m83\u001b[39m self.sig_filenames = self._validate_files(data)\n",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 126\u001b[39m, in \u001b[36mprepare_and_learning_detection.<locals>.MyDataset._validate_files\u001b[39m\u001b[34m(self, filenames)\u001b[39m\n\u001b[32m 122\u001b[39m shape_counter[real_shape] += \u001b[32m1\u001b[39m\n\u001b[32m 123\u001b[39m candidates.append((filename, real_shape))\n\u001b[32m 124\u001b[39m \n\u001b[32m 125\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;01mnot\u001b[39;00m candidates:\n\u001b[32m--> \u001b[39m\u001b[32m126\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(\u001b[33m\"No valid image pairs left after shape validation\"\u001b[39m)\n\u001b[32m 127\u001b[39m \n\u001b[32m 128\u001b[39m preferred_shape = (\u001b[32m1600\u001b[39m, \u001b[32m1600\u001b[39m, \u001b[32m3\u001b[39m)\n\u001b[32m 129\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m shape_counter.get(preferred_shape, \u001b[32m0\u001b[39m) > \u001b[32m0\u001b[39m:\n",
|
||||||
|
"\u001b[31mRuntimeError\u001b[39m: No valid image pairs left after shape validation"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
# Free any memory held over from previous runs.
torch.cuda.empty_cache()
cv2.destroyAllWindows()
gc.collect()

config_name = "ensemble"

def load_function(attr):
    """Resolve a dotted path like 'torch.optim.Adam' to the named attribute."""
    module_, func = attr.rsplit('.', maxsplit=1)
    return getattr(import_module(module_), func)

config = mlconfig.load('config_' + config_name + '.yaml')

# FIX: 'pretrained' is deprecated since torchvision 0.13 (it was emitting the
# UserWarnings seen in this cell's output); weights=None is the documented
# exact equivalent of pretrained=False — random initialization, no download.
model1 = models.resnet18(weights=None)
model2 = models.resnet50(weights=None)

num_classes = 2

# Replace the ImageNet classification heads with num_classes-way classifiers.
model1.fc = nn.Linear(model1.fc.in_features, num_classes)
model2.fc = nn.Linear(model2.fc.in_features, num_classes)
|
||||||
|
class Ensemble(nn.Module):
    """Late-fusion ensemble: concatenate two backbones' logits, then classify.

    FIX: ``num_classes`` is now an explicit constructor argument (defaulting
    to 2, matching the notebook's global) instead of a hidden dependency on a
    global variable that had to exist at construction time.
    """

    def __init__(self, model1, model2, num_classes=2):
        super(Ensemble, self).__init__()
        self.model1 = model1  # branch fed with the real-part images
        self.model2 = model2  # branch fed with the imaginary-part images
        # Each branch emits num_classes logits; fuse 2*num_classes -> num_classes.
        self.fc = nn.Linear(2 * num_classes, num_classes)

    def forward(self, x):
        """``x`` is a pair ``[real_batch, imag_batch]``; returns fused logits."""
        x1 = self.model1(x[0])
        x2 = self.model2(x[1])
        x = torch.cat((x1, x2), dim=1)
        x = self.fc(x)
        return x
|
||||||
|
model = Ensemble(model1, model2)

# Optimizer / loss / LR scheduler are resolved from the YAML config by their
# dotted names (e.g. 'torch.optim.Adam') via load_function.
optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)
criterion = load_function(config.loss_function.name)()
scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)

# FIX: .to('cpu') is a harmless no-op, so the old `if device != 'cpu'` guard
# was unnecessary — move unconditionally.
model = model.to(device)

# ---------- Build the dataset and train the model ----------

path_res, model_name = prepare_and_learning_detection(
    num_classes=num_classes,
    num_samples=10000,
    path_dataset="/mnt/data/Dataset_overlay",
    selected_freq=2400,
    model_name=config_name + "2400_",
    config_name=config_name,
    model=model,
)

# Release GPU memory and references now that training is done.
torch.cuda.empty_cache()
cv2.destroyAllWindows()
del model
gc.collect()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "usr",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue