You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DroneDetector/train_scripts/Training_models_1.2.ipynb

484 lines
22 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "5a13ad6b-56c9-4381-b376-1765f6dd7553",
"metadata": {
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Импортирование библиотек"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7311cb4a-5bf3-4268-b431-43eea10e9ed6",
"metadata": {
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
},
{
"data": {
"text/plain": [
"88"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# --- Standard library ---\n",
"import copy\n",
"import csv\n",
"import gc\n",
"import io\n",
"import os, shutil\n",
"import random\n",
"import sys\n",
"import time\n",
"import timeit\n",
"from importlib import import_module\n",
"from pathlib import Path\n",
"\n",
"# --- Third-party ---\n",
"import cv2\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import mlconfig\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch, torchvision\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"from PIL import Image\n",
"from sklearn.model_selection import train_test_split\n",
"from torch import default_generator, randperm\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.utils.data.dataset import Subset\n",
"from torchvision import models\n",
"from torchvision.io import read_image\n",
"from tqdm import tqdm\n",
"\n",
"plt.rcParams[\"savefig.bbox\"] = 'tight'\n",
"torch.manual_seed(1)  # fixed seed for reproducible weight init / shuffling\n",
"#matplotlib.use('Agg')\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(device)\n",
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"id": "384de097-82c6-41f5-bda9-b2f54bc99593",
"metadata": {
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Подготовка и обучение детектирование"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46e4dc99-6994-4fee-a32e-f3983bd991bd",
"metadata": {},
"outputs": [],
"source": [
"def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq, file_ext='.npy'):\n",
"    \"\"\"Build a balanced dataset.csv, train `model` on it and save the results.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    num_classes : int -- number of classes (used to balance samples per class).\n",
"    num_samples : int -- total number of samples to draw across all classes.\n",
"    path_dataset : str -- root folder containing one sub-folder per class.\n",
"    model_name, config_name : str -- names for the results folder / yaml config.\n",
"    model : nn.Module -- model to train (already moved to `device` by the caller).\n",
"    selected_freq : int -- frequency; samples are read from '<freq>_jpg' folders.\n",
"    file_ext : str -- extension of the sample files listed in dataset.csv.\n",
"        Default '.npy' keeps the previous behavior; pass e.g. '.png' if the\n",
"        '<freq>_jpg' folders contain only images.\n",
"\n",
"    Returns (path_res, model_name) where path_res is the results folder.\n",
"\n",
"    NOTE(review): relies on the globals `optimizer`, `criterion` and `device`\n",
"    defined in the calling cell -- confirm they match `model`.  The global\n",
"    `scheduler` is never stepped here -- confirm whether it should be.\n",
"    \"\"\"\n",
"    num_samples_per_class = num_samples // num_classes\n",
"\n",
"    #---------- Create a folder for saving the training results ----------\n",
"    os.makedirs(\"models\", exist_ok=True)\n",
"    ind = 1\n",
"    while os.path.exists(\"models/\" + model_name + str(ind)):\n",
"        ind += 1\n",
"    os.mkdir(\"models/\" + model_name + str(ind))\n",
"    path_res = \"models/\" + model_name + str(ind) + '/'\n",
"\n",
"    #---------- Create dataset.csv for training ----------\n",
"    pd_columns = ['file_name']\n",
"    subdirs = os.listdir(path_dataset)\n",
"    print(subdirs)\n",
"\n",
"    def _class_files(subdir):\n",
"        # List the sample files (matching `file_ext`) of one class folder.\n",
"        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq) + \"_jpg\")\n",
"        if not os.path.isdir(freq_dir):\n",
"            return freq_dir, None\n",
"        files = [\n",
"            f for f in os.listdir(freq_dir)\n",
"            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith(file_ext)\n",
"        ]\n",
"        return freq_dir, files\n",
"\n",
"    # First pass: cap num_samples_per_class at the smallest class size so\n",
"    # the resulting dataset is balanced across classes.\n",
"    for subdir in subdirs:\n",
"        freq_dir, files = _class_files(subdir)\n",
"        print(freq_dir)\n",
"        if files is None:\n",
"            continue\n",
"        num_samples_per_class = min(num_samples_per_class, len(files))\n",
"\n",
"    # Second pass: draw the samples. Collect rows in a list and build the\n",
"    # DataFrame once (pd.concat inside a loop is quadratic).\n",
"    records = []\n",
"    for subdir in subdirs:\n",
"        freq_dir, files = _class_files(subdir)\n",
"        print(freq_dir)\n",
"        if files is None:\n",
"            continue\n",
"        random.shuffle(files)\n",
"        for file in files[:num_samples_per_class]:\n",
"            records.append({pd_columns[0]: str(os.path.join(freq_dir, file))})\n",
"\n",
"    df = pd.DataFrame(records, columns=pd_columns)\n",
"    dataset_csv_path = os.path.join(path_res, 'dataset.csv')\n",
"    df.to_csv(dataset_csv_path, index=False)\n",
"\n",
"    if not os.path.exists(dataset_csv_path):\n",
"        raise RuntimeError(f'dataset.csv was not created: {dataset_csv_path}')\n",
"    if df.empty:\n",
"        # Fail early with a clear message instead of the cryptic downstream\n",
"        # DataLoader error 'num_samples should be a positive integer value'.\n",
"        raise RuntimeError(f'No files matching \"{file_ext}\" were found under {path_dataset}')\n",
"\n",
"    #---------- Import the training parameters ----------\n",
"    def load_function(attr):\n",
"        # Resolve a dotted path such as 'torch.optim.Adam' to the attribute.\n",
"        module_, func = attr.rsplit('.', maxsplit=1)\n",
"        return getattr(import_module(module_), func)\n",
"\n",
"    config = mlconfig.load('config_' + config_name + '.yaml')\n",
"\n",
"    #---------- Dataset class ----------\n",
"    class MyDataset(Dataset):\n",
"        \"\"\"Items are the (real, imag, spec) image triple of one signal plus\n",
"        its class label, resolved from the file names in dataset.csv.\"\"\"\n",
"\n",
"        def __init__(self, path_dataset, csv_file):\n",
"            # Plain csv parsing, skipping the header row. The previous\n",
"            # str(row)-slicing hack with delimiter=' ' broke on any path\n",
"            # containing a space.\n",
"            with open(os.path.join(path_dataset, csv_file), newline='') as csvfile:\n",
"                reader = csv.reader(csvfile)\n",
"                self.sig_filenames = [row[0] for row in list(reader)[1:] if row]\n",
"            self.path_dataset = path_dataset\n",
"\n",
"        def __len__(self):\n",
"            return len(self.sig_filenames)\n",
"\n",
"        def __getitem__(self, idx):\n",
"            base = os.path.splitext(self.sig_filenames[idx])[0]\n",
"\n",
"            # Three companion images are expected next to each sample file.\n",
"            image_real = np.asarray(cv2.split(cv2.imread(base + '_real.png')), dtype=np.float32)\n",
"            image_imag = np.asarray(cv2.split(cv2.imread(base + '_imag.png')), dtype=np.float32)\n",
"            image_spec = np.asarray(cv2.split(cv2.imread(base + '_spec.png')), dtype=np.float32)\n",
"\n",
"            # Label comes from the class folder in the path; splitting on\n",
"            # os.sep keeps this portable (the old '/'-split assumed POSIX).\n",
"            parts = os.path.normpath(self.sig_filenames[idx]).split(os.sep)\n",
"            if 'drone' in parts:\n",
"                label = torch.tensor(0)\n",
"            elif 'noise' in parts:\n",
"                label = torch.tensor(1)\n",
"            else:\n",
"                # Previously an unknown class crashed with a bare NameError.\n",
"                raise ValueError(f'Unknown class for {self.sig_filenames[idx]}')\n",
"            return image_real, image_imag, image_spec, label\n",
"\n",
"    #---------- Build the dataset ----------\n",
"    dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')\n",
"    train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n",
"    batch_size = config.batch_size\n",
"    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
"    valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
"\n",
"    dataloaders = {'train': train_dataloader, 'val': valid_dataloader}\n",
"    dataset_sizes = {'train': len(train_set), 'val': len(valid_set)}\n",
"\n",
"    #---------- Train the model ----------\n",
"    val_loss, val_acc = [], []\n",
"    train_loss, train_acc = [], []\n",
"    epochs = config.epoch\n",
"\n",
"    best_acc = 0.0\n",
"    best_model = copy.deepcopy(model.state_dict())\n",
"    limit = config.limit  # early-stopping patience (epochs without improvement)\n",
"    epoch_limit = epochs\n",
"    ind_limit = 0  # BUGFIX: was never initialized -> NameError on the first non-improving epoch\n",
"\n",
"    start = timeit.default_timer()\n",
"    for epoch in range(1, epochs + 1):\n",
"        print(f\"Epoch : {epoch}\\n\")\n",
"\n",
"        for phase in ['train', 'val']:\n",
"            running_loss = 0.0\n",
"            running_corrects = 0\n",
"\n",
"            for (img1, img2, img3, label) in tqdm(dataloaders[phase]):\n",
"                img1, img2, img3, label = img1.to(device), img2.to(device), img3.to(device), label.to(device)\n",
"                optimizer.zero_grad()\n",
"\n",
"                # Gradients only in the training phase.\n",
"                with torch.set_grad_enabled(phase == 'train'):\n",
"                    output = model([img1, img2, img3])\n",
"                    _, pred = torch.max(output.data, 1)\n",
"                    loss = criterion(output, label)\n",
"                    if phase == 'train':\n",
"                        loss.backward()\n",
"                        optimizer.step()\n",
"\n",
"                running_loss += loss.item() * 3 * img1.size(0)\n",
"                running_corrects += torch.sum(pred == label.data)\n",
"\n",
"            epoch_loss = running_loss / dataset_sizes[phase]\n",
"            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
"\n",
"            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
"\n",
"            if phase == 'train':\n",
"                train_loss.append(epoch_loss)\n",
"                train_acc.append(epoch_acc)\n",
"            else:\n",
"                val_loss.append(epoch_loss)\n",
"                val_acc.append(epoch_acc)\n",
"                # Save the best model; reset / advance the patience counter.\n",
"                if val_acc[-1] > best_acc:\n",
"                    ind_limit = 0\n",
"                    best_acc = val_acc[-1]\n",
"                    best_model = copy.deepcopy(model.state_dict())\n",
"                    torch.save(best_model, path_res + model_name + '.pth')\n",
"                else:\n",
"                    ind_limit += 1\n",
"\n",
"        if ind_limit >= limit:\n",
"            # Early stopping: no val-accuracy improvement for `limit` epochs.\n",
"            epoch_limit = epoch\n",
"            break\n",
"\n",
"        print()\n",
"\n",
"    end = timeit.default_timer()\n",
"    print(f\"Total time elapsed = {end - start} seconds\")\n",
"    epoch_limit += 1  # exclusive upper bound for the 1-based epoch axis below\n",
"\n",
"    #---------- Plot curves and save the training results ----------\n",
"    train_acc = np.asarray([x.item() for x in train_acc])\n",
"    val_acc = np.asarray([x.item() for x in val_acc])\n",
"\n",
"    np.save(path_res + 'train_acc.npy', train_acc)\n",
"    np.save(path_res + 'val_acc.npy', val_acc)\n",
"    np.save(path_res + 'train_loss.npy', train_loss)\n",
"    np.save(path_res + 'val_loss.npy', val_loss)\n",
"\n",
"    # One loop instead of two copy-pasted plotting sections.\n",
"    for (train_curve, val_curve), ylabel, title, legend in (\n",
"        ((train_loss, val_loss), 'Loss', 'Loss Curve', ['Train Loss', 'Validation Loss']),\n",
"        ((train_acc, val_acc), 'Accuracy', 'Accuracy Curve', ['Train Accuracy', 'Validation Accuracy']),\n",
"    ):\n",
"        plt.figure()\n",
"        plt.plot(range(1, epoch_limit), train_curve, color='blue')\n",
"        plt.plot(range(1, epoch_limit), val_curve, color='red')\n",
"        plt.xlabel('Epoch')\n",
"        plt.ylabel(ylabel)\n",
"        plt.title(title)\n",
"        plt.legend(legend)\n",
"        plt.show()\n",
"        plt.clf()\n",
"        plt.cla()\n",
"        plt.close()\n",
"\n",
"    torch.cuda.empty_cache()\n",
"    cv2.destroyAllWindows()\n",
"    del model\n",
"    gc.collect()\n",
"\n",
"    return path_res, model_name"
]
},
{
"cell_type": "markdown",
"id": "93c136ee",
"metadata": {},
"source": [
"### Ensemble"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "52e8d4c5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sibsci/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/home/sibsci/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['noise', 'drone']\n",
"huy\n",
"/home/sibsci/dataset_img/noise/5800_jpg\n",
"/home/sibsci/dataset_img/drone/5800_jpg\n",
"/home/sibsci/dataset_img/noise/5800/_jpg\n",
"/home/sibsci/dataset_img/drone/5800/_jpg\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_84402/382832652.py:100: UserWarning: Length of split at index 0 is 0. This might result in an empty dataset.\n",
" train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n",
"/tmp/ipykernel_84402/382832652.py:100: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.\n",
" train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n"
]
},
{
"ename": "ValueError",
"evalue": "num_samples should be a positive integer value, but got num_samples=0",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mValueError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 53\u001b[39m\n\u001b[32m 49\u001b[39m model = model.to(device)\n\u001b[32m 50\u001b[39m \n\u001b[32m 51\u001b[39m \u001b[38;5;66;03m#----------Создания датасета и обучение модели--------------\u001b[39;00m\n\u001b[32m 52\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m53\u001b[39m path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 1000, path_dataset = \"/home/sibsci/dataset_img\", \n\u001b[32m 54\u001b[39m selected_freq=\u001b[32m5800\u001b[39m,model_name = config_name+\u001b[33m\"_5.8_jpg_\"\u001b[39m, config_name = config_name, model=model)\n\u001b[32m 55\u001b[39m \n\u001b[32m 56\u001b[39m \n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 102\u001b[39m, in \u001b[36mprepare_and_learning_detection\u001b[39m\u001b[34m(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq)\u001b[39m\n\u001b[32m 98\u001b[39m \n\u001b[32m 99\u001b[39m dataset = MyDataset(path_dataset=path_res, csv_file=\u001b[33m'dataset.csv'\u001b[39m)\n\u001b[32m 100\u001b[39m train_set, valid_set = torch.utils.data.random_split(dataset, [\u001b[32m0.7\u001b[39m, \u001b[32m0.3\u001b[39m], generator=torch.Generator().manual_seed(\u001b[32m42\u001b[39m))\n\u001b[32m 101\u001b[39m batch_size = config.batch_size\n\u001b[32m--> \u001b[39m\u001b[32m102\u001b[39m train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m 103\u001b[39m valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m 104\u001b[39m \n\u001b[32m 105\u001b[39m dataloaders = {}\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/DroneDetector/.venv-train/lib/python3.12/site-packages/torch/utils/data/dataloader.py:394\u001b[39m, in \u001b[36mDataLoader.__init__\u001b[39m\u001b[34m(self, dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context, generator, prefetch_factor, persistent_workers, pin_memory_device, in_order)\u001b[39m\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m: \u001b[38;5;66;03m# map-style\u001b[39;00m\n\u001b[32m 393\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m shuffle:\n\u001b[32m--> \u001b[39m\u001b[32m394\u001b[39m sampler = \u001b[43mRandomSampler\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m=\u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 395\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 396\u001b[39m sampler = SequentialSampler(dataset) \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/DroneDetector/.venv-train/lib/python3.12/site-packages/torch/utils/data/sampler.py:149\u001b[39m, in \u001b[36mRandomSampler.__init__\u001b[39m\u001b[34m(self, data_source, replacement, num_samples, generator)\u001b[39m\n\u001b[32m 144\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 145\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mreplacement should be a boolean value, but got replacement=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.replacement\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 146\u001b[39m )\n\u001b[32m 148\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m.num_samples, \u001b[38;5;28mint\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.num_samples <= \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 150\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mnum_samples should be a positive integer value, but got num_samples=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.num_samples\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 151\u001b[39m )\n",
"\u001b[31mValueError\u001b[39m: num_samples should be a positive integer value, but got num_samples=0"
]
}
],
"source": [
"#---------- Initialize the model and the training parameters ----------\n",
"\n",
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"gc.collect()\n",
"\n",
"config_name = \"ensemble\"\n",
"num_classes = 2  # drone / noise (the earlier dead `num_classes = 3` was never used)\n",
"\n",
"def load_function(attr):\n",
"    # Resolve a dotted path such as 'torch.optim.Adam' to the attribute.\n",
"    module_, func = attr.rsplit('.', maxsplit=1)\n",
"    return getattr(import_module(module_), func)\n",
"\n",
"config = mlconfig.load('config_' + config_name + '.yaml')\n",
"\n",
"# weights=None replaces the deprecated `pretrained=False` (torchvision >= 0.13;\n",
"# see the UserWarning in the previous run's stderr output).\n",
"model1 = models.resnet18(weights=None)\n",
"model2 = models.resnet50(weights=None)\n",
"model3 = models.resnet101(weights=None)\n",
"\n",
"# Replace each classification head with a `num_classes`-way linear layer.\n",
"model1.fc = nn.Linear(model1.fc.in_features, num_classes)\n",
"model2.fc = nn.Linear(model2.fc.in_features, num_classes)\n",
"model3.fc = nn.Linear(model3.fc.in_features, num_classes)\n",
"\n",
"class Ensemble(nn.Module):\n",
"    \"\"\"Late-fusion ensemble: one ResNet per input image (real / imag / spec);\n",
"    their logits are concatenated and fused by a final linear layer.\"\"\"\n",
"\n",
"    def __init__(self, model1, model2, model3):\n",
"        super(Ensemble, self).__init__()\n",
"        self.model1 = model1\n",
"        self.model2 = model2\n",
"        self.model3 = model3\n",
"        self.fc = nn.Linear(3 * num_classes, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        # x is a list/tuple of three image batches, one per backbone.\n",
"        x1 = self.model1(x[0])\n",
"        x2 = self.model2(x[1])\n",
"        x3 = self.model3(x[2])\n",
"        x = torch.cat((x1, x2, x3), dim=1)\n",
"        x = self.fc(x)\n",
"        return x\n",
"\n",
"model = Ensemble(model1, model2, model3)\n",
"\n",
"optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)\n",
"criterion = load_function(config.loss_function.name)()\n",
"scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)\n",
"# NOTE(review): `scheduler` is created but never stepped during training --\n",
"# confirm whether scheduler.step() should be called once per epoch.\n",
"\n",
"model = model.to(device)  # .to('cpu') is a no-op, so no device guard is needed\n",
"\n",
"#---------- Create the dataset and train the model ----------\n",
"\n",
"path_res, model_name = prepare_and_learning_detection(num_classes=num_classes, num_samples=1000, path_dataset=\"/home/sibsci/dataset_img\",\n",
"                                                      selected_freq=5800, model_name=config_name + \"_5.8_jpg_\", config_name=config_name, model=model)\n",
"\n",
"torch.cuda.empty_cache()\n",
"cv2.destroyAllWindows()\n",
"del model\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4234ee26",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Отсутствует",
"kernelspec": {
"display_name": ".venv-train (3.12.3)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}