You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DroneDetector/train_scripts/Training_models2pic_val_los...

524 lines
25 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e1db882b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n",
" warnings.warn(\"Unable to import Axes3D. This may be due to multiple versions of \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
},
{
"data": {
"text/plain": [
"220"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch import default_generator, randperm\n",
"from torch.utils.data.dataset import Subset\n",
"import torchvision.transforms as transforms\n",
"from torchvision.io import read_image\n",
"from importlib import import_module\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import models\n",
"import torch, torchvision\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import torch.nn as nn\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib\n",
"import os, shutil\n",
"import mlconfig\n",
"import random\n",
"import shutil\n",
"import timeit\n",
"import copy\n",
"import time\n",
"import cv2\n",
"import csv\n",
"import sys\n",
"import io\n",
"import gc\n",
"\n",
"# Single seed shared by every random number generator the notebook uses.\n",
"SEED = 1\n",
"\n",
"plt.rcParams[\"savefig.bbox\"] = 'tight'\n",
"# The training files are picked with `random.shuffle`, so seeding torch alone\n",
"# does not make the dataset selection repeatable: seed `random` and numpy too.\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"#matplotlib.use('Agg')\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(device)\n",
"# Start from a clean slate: release cached GPU memory, OpenCV windows and garbage.\n",
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e009995",
"metadata": {},
"outputs": [],
"source": [
"def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq,\n",
"                                   optimizer=None, criterion=None, scheduler=None):\n",
"    \"\"\"Build a class-balanced file list for one frequency band and train `model` on it.\n",
"\n",
"    The function scans `<path_dataset>/<subdir>/<selected_freq>_jpg` for files ending in\n",
"    'imag.png', samples the same number of files from every folder found, writes the\n",
"    list to `models/<model_name><N>/dataset.csv`, then runs a train/validation loop with\n",
"    early stopping on the validation loss. The best weights are saved as\n",
"    `<model_name>.pth` and the loss/accuracy curves as `.npy` files in the same folder.\n",
"\n",
"    Args:\n",
"        num_classes: divisor used to derive the per-class budget from `num_samples`.\n",
"        num_samples: upper bound on the total number of sampled files.\n",
"        path_dataset: dataset root directory; each subdirectory is scanned.\n",
"        model_name: prefix of the results directory and name of the weights file.\n",
"        config_name: the configuration is loaded from `config_<config_name>.yaml`;\n",
"            `batch_size`, `epoch` and `limit` are read from it.\n",
"        model: module called as `model([real_batch, imag_batch])`; trained in place.\n",
"        selected_freq: frequency band, selects the `<selected_freq>_jpg` folders.\n",
"        optimizer, criterion, scheduler: training objects. When left as None, the\n",
"            notebook-level variables with the same names are used.\n",
"\n",
"    Returns:\n",
"        Tuple `(path_res, model_name)`, where `path_res` is the results directory\n",
"        (with a trailing '/').\n",
"\n",
"    Raises:\n",
"        RuntimeError: if no usable files are found, if no valid real/imag pair is left\n",
"            after validation, or if a data loader yields no batches.\n",
"        NameError: if a training object is neither passed nor defined in the notebook.\n",
"    \"\"\"\n",
"    from collections import Counter\n",
"\n",
"    def _from_notebook(name, value):\n",
"        # Existing calls do not pass these objects: fall back to the notebook globals.\n",
"        if value is not None:\n",
"            return value\n",
"        if name not in globals():\n",
"            raise NameError(f\"'{name}' was not passed and is not defined in the notebook\")\n",
"        return globals()[name]\n",
"\n",
"    optimizer = _from_notebook('optimizer', optimizer)\n",
"    criterion = _from_notebook('criterion', criterion)\n",
"    scheduler = _from_notebook('scheduler', scheduler)\n",
"    # Send the batches to wherever the model weights already live.\n",
"    device = next(model.parameters()).device\n",
"\n",
"    num_samples_per_class = num_samples // num_classes\n",
"\n",
"    #---------- Collect the candidate files of every class ----------\n",
"\n",
"    freq_dir_name = str(selected_freq) + \"_jpg\"\n",
"    files_by_dir = {}\n",
"    # Sorted listings keep the seeded shuffle below repeatable between runs.\n",
"    for subdir in sorted(os.listdir(path_dataset)):\n",
"        freq_dir = os.path.join(path_dataset, subdir, freq_dir_name)\n",
"        if not os.path.isdir(freq_dir):\n",
"            print(f\"Error1: directory not found, skipping {freq_dir}\")\n",
"            continue\n",
"\n",
"        entries = sorted(os.listdir(freq_dir))\n",
"        print(len(entries))\n",
"\n",
"        files = [\n",
"            f for f in entries\n",
"            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')\n",
"        ]\n",
"        files_by_dir[freq_dir] = files\n",
"        # Every class contributes as many files as the smallest class can provide.\n",
"        num_samples_per_class = min(num_samples_per_class, len(files))\n",
"        print(f\"num_samples per class {subdir} is {num_samples_per_class}\")\n",
"\n",
"    # Fail here, with the actual reason, instead of later on an empty dataset.\n",
"    if not files_by_dir:\n",
"        raise RuntimeError(\n",
"            f\"no subdirectory of {path_dataset} contains a '{freq_dir_name}' folder\"\n",
"        )\n",
"    if num_samples_per_class <= 0:\n",
"        raise RuntimeError(\n",
"            f\"nothing to sample: a '{freq_dir_name}' folder has no '*imag.png' files \"\n",
"            f\"or num_samples={num_samples} is smaller than num_classes={num_classes}\"\n",
"        )\n",
"\n",
"    selected_paths = []\n",
"    for freq_dir, files in files_by_dir.items():\n",
"        random.shuffle(files)\n",
"        selected_paths.extend(os.path.join(freq_dir, f) for f in files[:num_samples_per_class])\n",
"\n",
"    # Build the frame once; growing it with pd.concat per row is quadratic.\n",
"    df = pd.DataFrame({'file_name': selected_paths})\n",
"\n",
"    #---------- Create the results directory: first free models/<model_name><N> ----------\n",
"\n",
"    os.makedirs(\"models\", exist_ok=True)\n",
"    ind = 1\n",
"    while os.path.exists(\"models/\" + model_name + str(ind)):\n",
"        ind += 1\n",
"    os.mkdir(\"models/\" + model_name + str(ind))\n",
"    path_res = \"models/\" + model_name + str(ind) + '/'\n",
"\n",
"    #---------- Write dataset.csv ----------\n",
"\n",
"    dataset_csv_path = os.path.join(path_res, 'dataset.csv')\n",
"    df.to_csv(dataset_csv_path, index=False)\n",
"\n",
"    #---------- Load the training parameters ----------\n",
"\n",
"    config = mlconfig.load('config_' + config_name + '.yaml')\n",
"\n",
"    #---------- Dataset class ----------\n",
"\n",
"    class MyDataset(Dataset):\n",
"        \"\"\"Real/imag image pairs listed in a CSV file; the label comes from the file path.\"\"\"\n",
"\n",
"        def __init__(self, path_dataset, csv_file):\n",
"            # Read the 'file_name' column with the standard CSV dialect (the one pandas\n",
"            # writes), so paths containing spaces or commas come back intact.\n",
"            with open(os.path.join(path_dataset, csv_file), newline='', encoding='utf-8') as csvfile:\n",
"                data = [row['file_name'] for row in csv.DictReader(csvfile)]\n",
"            self.path_dataset = path_dataset\n",
"            self.target_shape = None\n",
"            self.target_hw = None\n",
"            self.sig_filenames = self._validate_files(data)\n",
"\n",
"        def _pair_paths(self, filename):\n",
"            \"\"\"Return `(real_path, imag_path)` for a '...real' or '...imag' file, else `(None, None)`.\"\"\"\n",
"            base = os.path.splitext(filename)[0]\n",
"            if base.endswith(\"real\"):\n",
"                return base + \".png\", base[:-4] + \"imag.png\"\n",
"            if base.endswith(\"imag\"):\n",
"                return base[:-4] + \"real.png\", base + \".png\"\n",
"            return None, None\n",
"\n",
"        @staticmethod\n",
"        def _read_shape(path):\n",
"            \"\"\"Return the shape of the image at `path`, or None when it cannot be read.\"\"\"\n",
"            img = cv2.imread(path)\n",
"            if img is None:\n",
"                return None\n",
"            return img.shape\n",
"\n",
"        def _validate_files(self, filenames):\n",
"            \"\"\"Keep the files whose real/imag pair is readable and equally shaped.\n",
"\n",
"            Also chooses `self.target_shape`: (1600, 1600, 3) when at least one pair has\n",
"            that shape, otherwise the most frequent shape. Pairs of another shape are\n",
"            kept and resized when they are loaded.\n",
"            \"\"\"\n",
"            candidates = []\n",
"            dropped = []\n",
"            shape_counter = Counter()\n",
"\n",
"            for filename in filenames:\n",
"                real_path, imag_path = self._pair_paths(filename)\n",
"                if real_path is None or imag_path is None:\n",
"                    dropped.append((filename, \"bad file suffix\"))\n",
"                    continue\n",
"\n",
"                real_shape = self._read_shape(real_path)\n",
"                imag_shape = self._read_shape(imag_path)\n",
"                if real_shape is None or imag_shape is None:\n",
"                    dropped.append((filename, f\"read failed real={real_shape} imag={imag_shape}\"))\n",
"                    continue\n",
"                if real_shape != imag_shape:\n",
"                    dropped.append((filename, f\"pair shape mismatch real={real_shape} imag={imag_shape}\"))\n",
"                    continue\n",
"\n",
"                shape_counter[real_shape] += 1\n",
"                candidates.append((filename, real_shape))\n",
"\n",
"            if not candidates:\n",
"                raise RuntimeError(\"No valid image pairs left after shape validation\")\n",
"\n",
"            preferred_shape = (1600, 1600, 3)\n",
"            if shape_counter.get(preferred_shape, 0) > 0:\n",
"                target_shape = preferred_shape\n",
"            else:\n",
"                target_shape = shape_counter.most_common(1)[0][0]\n",
"\n",
"            self.target_shape = target_shape\n",
"            self.target_hw = target_shape[:2]\n",
"            valid = [filename for filename, _shape in candidates]\n",
"            resized_count = sum(1 for _filename, shape in candidates if shape != target_shape)\n",
"\n",
"            print(f\"[dataset-shape-filter] shape_distribution={dict(shape_counter)}\")\n",
"            print(\n",
"                f\"[dataset-shape-filter] target_shape={target_shape} \"\n",
"                f\"kept={len(valid)} dropped={len(dropped)} will_resize={resized_count}\"\n",
"            )\n",
"            # Only the first 20 dropped files are listed to keep the output short.\n",
"            for filename, reason in dropped[:20]:\n",
"                print(f\"[dataset-shape-filter] drop {filename}: {reason}\")\n",
"            if len(dropped) > 20:\n",
"                print(f\"[dataset-shape-filter] ... {len(dropped) - 20} more dropped\")\n",
"\n",
"            return valid\n",
"\n",
"        def __len__(self):\n",
"            return len(self.sig_filenames)\n",
"\n",
"        def _read_image_chw(self, path):\n",
"            \"\"\"Load an image as a float32 array in channel-first layout, resized to the target size.\"\"\"\n",
"            img = cv2.imread(path)\n",
"            if img is None:\n",
"                raise RuntimeError(f\"failed to read image: {path}\")\n",
"            if img.shape[:2] != self.target_hw:\n",
"                # cv2.resize expects (width, height).\n",
"                img = cv2.resize(\n",
"                    img,\n",
"                    (self.target_hw[1], self.target_hw[0]),\n",
"                    interpolation=cv2.INTER_AREA,\n",
"                )\n",
"            return np.asarray(cv2.split(img), dtype=np.float32)\n",
"\n",
"        def __getitem__(self, idx):\n",
"            filename = self.sig_filenames[idx]\n",
"            real_path, imag_path = self._pair_paths(filename)\n",
"            image_real = self._read_image_chw(real_path)\n",
"            image_imag = self._read_image_chw(imag_path)\n",
"\n",
"            # The class is a directory name somewhere in the path: drone -> 0, noise -> 1.\n",
"            path_parts = set(Path(filename).parts)\n",
"            if 'drone' in path_parts:\n",
"                label = torch.tensor(0)\n",
"            elif 'noise' in path_parts:\n",
"                label = torch.tensor(1)\n",
"            else:\n",
"                raise RuntimeError(f\"cannot infer label from path: {filename}\")\n",
"\n",
"            return image_real, image_imag, label\n",
"\n",
"    #---------- Create the dataset and the loaders ----------\n",
"\n",
"    dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')\n",
"    train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n",
"    batch_size = config.batch_size\n",
"    dataloaders = {\n",
"        'train': torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True),\n",
"        'val': torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=False),\n",
"    }\n",
"\n",
"    #---------- Train the model ----------\n",
"\n",
"    val_loss = []\n",
"    val_acc = []\n",
"    train_loss = []\n",
"    train_acc = []\n",
"    epochs = config.epoch\n",
"    min_delta = 1e-4  # smallest decrease of the validation loss that counts as progress\n",
"\n",
"    best_val_loss = float('inf')\n",
"    limit = config.limit  # early-stopping patience, in epochs\n",
"    ind_limit = 0\n",
"\n",
"    # Only ReduceLROnPlateau takes the monitored metric in step(). For the other\n",
"    # schedulers the positional argument of step() is an epoch index, so passing\n",
"    # the loss there would corrupt the schedule.\n",
"    metric_scheduler = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)\n",
"\n",
"    start = timeit.default_timer()\n",
"    for epoch in range(1, epochs + 1):\n",
"        print(f\"Epoch : {epoch}\\n\")\n",
"\n",
"        for phase in ['train', 'val']:\n",
"            model.train(phase == 'train')\n",
"\n",
"            running_loss = 0.0\n",
"            running_corrects = 0\n",
"            seen = 0\n",
"\n",
"            for (img1, img2, label) in tqdm(dataloaders[phase]):\n",
"                img1, img2, label = img1.to(device), img2.to(device), label.to(device)\n",
"                optimizer.zero_grad()\n",
"\n",
"                with torch.set_grad_enabled(phase == 'train'):\n",
"                    output = model([img1, img2])\n",
"                    pred = output.detach().argmax(dim=1)\n",
"                    loss = criterion(output, label)\n",
"                    if phase == 'train':\n",
"                        loss.backward()\n",
"                        optimizer.step()\n",
"\n",
"                batch_len = img1.size(0)\n",
"                running_loss += loss.item() * batch_len\n",
"                running_corrects += (pred == label).sum().item()\n",
"                seen += batch_len\n",
"\n",
"            if seen == 0:\n",
"                raise RuntimeError(\n",
"                    f\"the {phase} loader produced no batches: \"\n",
"                    f\"the dataset is too small for batch_size={batch_size}\"\n",
"                )\n",
"\n",
"            # Average over the samples actually processed: the train loader drops the\n",
"            # last incomplete batch, so this can be fewer than the size of the split.\n",
"            epoch_loss = running_loss / seen\n",
"            epoch_acc = running_corrects / seen\n",
"\n",
"            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
"\n",
"            if phase == 'train':\n",
"                train_loss.append(epoch_loss)\n",
"                train_acc.append(epoch_acc)\n",
"                continue\n",
"\n",
"            val_loss.append(epoch_loss)\n",
"            val_acc.append(epoch_acc)\n",
"            if metric_scheduler:\n",
"                scheduler.step(epoch_loss)\n",
"            else:\n",
"                scheduler.step()\n",
"\n",
"            current_lr = optimizer.param_groups[0]['lr']\n",
"            print(f'val lr: {current_lr:.8f}')\n",
"\n",
"            if epoch_loss < (best_val_loss - min_delta):\n",
"                ind_limit = 0\n",
"                best_val_loss = epoch_loss\n",
"                torch.save(model.state_dict(), path_res + model_name + '.pth')\n",
"                print(f'saved best model with val_loss={best_val_loss:.4f}')\n",
"            else:\n",
"                ind_limit += 1\n",
"                print(f'early stopping patience: {ind_limit}/{limit}')\n",
"\n",
"        if ind_limit >= limit:\n",
"            break\n",
"\n",
"        print()\n",
"\n",
"    end = timeit.default_timer()\n",
"    print(f\"Total time elapsed = {end - start} seconds\")\n",
"\n",
"    #---------- Save and plot the training curves ----------\n",
"\n",
"    train_acc = np.asarray(train_acc, dtype=np.float64)\n",
"    val_acc = np.asarray(val_acc, dtype=np.float64)\n",
"\n",
"    np.save(path_res + 'train_acc.npy', train_acc)\n",
"    np.save(path_res + 'val_acc.npy', val_acc)\n",
"    np.save(path_res + 'train_loss.npy', train_loss)\n",
"    np.save(path_res + 'val_loss.npy', val_loss)\n",
"\n",
"    # One point per completed epoch, whether or not training stopped early.\n",
"    epoch_axis = range(1, len(train_loss) + 1)\n",
"    for quantity, train_values, val_values in (\n",
"        ('Loss', train_loss, val_loss),\n",
"        ('Accuracy', train_acc, val_acc),\n",
"    ):\n",
"        plt.figure()\n",
"        plt.plot(epoch_axis, train_values, color='blue')\n",
"        plt.plot(epoch_axis, val_values, color='red')\n",
"        plt.xlabel('Epoch')\n",
"        plt.ylabel(quantity)\n",
"        plt.title(quantity + ' Curve')\n",
"        plt.legend(['Train ' + quantity, 'Validation ' + quantity])\n",
"        plt.show()\n",
"        plt.close()\n",
"\n",
"    torch.cuda.empty_cache()\n",
"    cv2.destroyAllWindows()\n",
"    # Drops only the local reference; the caller still owns the model.\n",
"    del model\n",
"    gc.collect()\n",
"\n",
"    return path_res, model_name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbbe7fea",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"gc.collect()\n",
"\n",
"#---------- Run parameters ----------\n",
"\n",
"config_name = \"ensemble\"\n",
"num_classes = 2\n",
"DATASET_DIR = \"/mnt/data/Dataset_overlay\"\n",
"SELECTED_FREQ = 2400\n",
"NUM_SAMPLES = 10000\n",
"\n",
"\n",
"def load_function(attr):\n",
"    \"\"\"Resolve a dotted path ('package.module.name') to the object it names.\"\"\"\n",
"    module_, func = attr.rsplit('.', maxsplit=1)\n",
"    return getattr(import_module(module_), func)\n",
"\n",
"\n",
"class Ensemble(nn.Module):\n",
"    \"\"\"Two-branch classifier whose branch outputs are fused by a linear layer.\n",
"\n",
"    `forward` takes a pair of inputs: the first goes to `model1`, the second to\n",
"    `model2`. Both branches must output `num_classes` values per sample.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model1, model2, num_classes=2):\n",
"        super().__init__()\n",
"        self.model1 = model1\n",
"        self.model2 = model2\n",
"        # Maps the concatenated branch outputs to the final class scores.\n",
"        self.fc = nn.Linear(2 * num_classes, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        x1 = self.model1(x[0])\n",
"        x2 = self.model2(x[1])\n",
"        return self.fc(torch.cat((x1, x2), dim=1))\n",
"\n",
"\n",
"config = mlconfig.load('config_' + config_name + '.yaml')\n",
"\n",
"#---------- Build the model ----------\n",
"\n",
"# `weights=None` is the current spelling of the deprecated `pretrained=False`.\n",
"model1 = models.resnet18(weights=None)\n",
"model2 = models.resnet50(weights=None)\n",
"\n",
"model1.fc = nn.Linear(model1.fc.in_features, num_classes)\n",
"model2.fc = nn.Linear(model2.fc.in_features, num_classes)\n",
"\n",
"# Move the model to its device before the optimizer is created from its parameters.\n",
"model = Ensemble(model1, model2, num_classes=num_classes).to(device)\n",
"\n",
"optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)\n",
"criterion = load_function(config.loss_function.name)()\n",
"scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)\n",
"\n",
"#---------- Create the dataset and train the model ----------\n",
"\n",
"path_res, model_name = prepare_and_learning_detection(\n",
"    num_classes=num_classes,\n",
"    num_samples=NUM_SAMPLES,\n",
"    path_dataset=DATASET_DIR,\n",
"    selected_freq=SELECTED_FREQ,\n",
"    # The frequency is part of the results name, e.g. 'ensemble2400_'.\n",
"    model_name=config_name + str(SELECTED_FREQ) + \"_\",\n",
"    config_name=config_name,\n",
"    model=model,\n",
")\n",
"\n",
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"del model\n",
"gc.collect()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "usr",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}