From c70a25cb8f9cbd4651241b93247e6f8aa78588d1 Mon Sep 17 00:00:00 2001 From: Sergey Revyakin Date: Tue, 5 May 2026 14:23:53 +0700 Subject: [PATCH] =?UTF-8?q?=D0=9D=D0=BE=D0=B2=D0=B0=D1=8F=20=20=D0=B2?= =?UTF-8?q?=D0=B5=D1=80=D1=81=D0=B8=D1=8F=20=D0=BD=D0=BE=D1=83=D1=82=D0=B1?= =?UTF-8?q?=D1=83=D0=BA=D0=B0=20=D0=B4=D0=BB=D1=8F=20=D0=BE=D0=B1=D1=83?= =?UTF-8?q?=D1=87=D0=B5=D0=BD=D0=B8=D1=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Training_models2pic_val_loss.ipynb | 523 ++++++++++++++++++ 1 file changed, 523 insertions(+) create mode 100644 train_scripts/Training_models2pic_val_loss.ipynb diff --git a/train_scripts/Training_models2pic_val_loss.ipynb b/train_scripts/Training_models2pic_val_loss.ipynb new file mode 100644 index 0000000..5109b31 --- /dev/null +++ b/train_scripts/Training_models2pic_val_loss.ipynb @@ -0,0 +1,523 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e1db882b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n", + " warnings.warn(\"Unable to import Axes3D. 
def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model,selected_freq):
    """Build a paired-image (real/imag PNG) dataset for one frequency and train `model` on it.

    Pipeline: create a fresh ``models/<model_name><N>/`` results folder, write a
    ``dataset.csv`` listing a balanced random sample of image files, wrap them in
    a Dataset yielding (real_image, imag_image, label) triples, then run a
    train/val loop with early stopping on validation loss, saving the best
    checkpoint and loss/accuracy curves.

    Parameters
    ----------
    num_classes : int
        Number of classes; used only to split `num_samples` evenly per class.
    num_samples : int
        Total sample budget across all classes.
    path_dataset : str
        Root directory. Each subdirectory is one class and is expected to
        contain a ``<selected_freq>_jpg`` folder of ``*real.png``/``*imag.png`` pairs.
    model_name : str
        Prefix for the results folder and the saved ``.pth`` checkpoint.
    config_name : str
        Suffix of the ``config_<config_name>.yaml`` mlconfig file to load.
    model : torch.nn.Module
        Model that accepts ``[real_batch, imag_batch]`` and returns class logits.
    selected_freq : int
        Frequency whose ``<freq>_jpg`` image folder is used.

    Returns
    -------
    tuple[str, str]
        ``(path_res, model_name)`` — results folder path (trailing '/') and the
        model name prefix.

    NOTE(review): this function reads ``optimizer``, ``criterion``,
    ``scheduler`` and ``device`` from the notebook's global scope (they are
    defined in a later cell) — it raises NameError if called standalone.
    """
    num_samples_per_class = num_samples // num_classes

    #---------- Create a folder for saving the training results --------------
    os.makedirs("models", exist_ok=True)
    ind = 1
    # Probe models/<model_name>1, <model_name>2, ... until an unused index is found.
    while True:
        if os.path.exists("models/" + model_name + str(ind)):
            ind += 1
        else:
            os.mkdir("models/" + model_name + str(ind))
            path_res = "models/" + model_name + str(ind) + '/'
            break
    
    #---------- Create the dataset.csv file used for training --------------
    
    pd_columns = ['file_name']
    df = pd.DataFrame(columns=pd_columns)
    
    subdirs = os.listdir(path_dataset)
    
    # First pass: shrink num_samples_per_class down to the smallest class size
    # so every class contributes the same number of samples.
    for subdir in subdirs:
        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq)+"_jpg")
        if not os.path.isdir(freq_dir):
            # NOTE(review): "Error1" is printed when a class folder has no
            # <freq>_jpg subfolder — a message naming the folder would help.
            print("Error1")
            continue
        
        # files_k is only used for this diagnostic count of all entries.
        files_k=[f for f in os.listdir(freq_dir)]
        print(len(files_k))
        
        # Only the *imag.png half of each pair is listed; the real half is
        # derived later by MyDataset._pair_paths.
        files = [
            f for f in os.listdir(freq_dir)
            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')
        ]
        num_samples_per_class = min(num_samples_per_class, len(files))
        print(f"num_samples per class {subdir} is {num_samples_per_class}")

    # Second pass: sample num_samples_per_class files per class into the CSV.
    for subdir in subdirs:
        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq)+"_jpg")
        if not os.path.isdir(freq_dir):
            print("Error1")
            continue

        files = [
            f for f in os.listdir(freq_dir)
            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('imag.png')
        ]
        random.shuffle(files)
        files_to_process = files[:num_samples_per_class]

        for file in files_to_process:
            # One-row DataFrame per file; concatenated into the full listing.
            row = pd.DataFrame({
                pd_columns[0]: [str(os.path.join(freq_dir, file))]
            })
            df = pd.concat([df, row], ignore_index=True)

    dataset_csv_path = os.path.join(path_res, 'dataset.csv')
    df.to_csv(dataset_csv_path, index=False)

    if not os.path.exists(dataset_csv_path):
        raise RuntimeError(f'dataset.csv was not created: {dataset_csv_path}')
    #---------- Import the training parameters --------------
    
    # NOTE(review): load_function is defined here but never called inside this
    # function (a copy in the notebook's last cell resolves the optimizer etc.).
    def load_function(attr):
        module_, func = attr.rsplit('.', maxsplit=1)
        return getattr(import_module(module_), func)
    
    config = mlconfig.load('config_' + config_name + '.yaml')
    
    #---------- Define the dataset class --------------
    
    class MyDataset(Dataset):
        """Yields (real_image, imag_image, label) triples for file pairs listed in a CSV."""

        def __init__(self, path_dataset, csv_file):
            data=[]
            with open(os.path.join(path_dataset, csv_file), newline='') as csvfile:
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
                # Skip the header row; each remaining row holds one file path.
                for row in list(reader)[1:]:
                    # NOTE(review): the row list is stringified ("['path']") and
                    # the surrounding "['" / "']" sliced off — fragile if a path
                    # ever contains quotes or spaces; csv column access would be safer.
                    row = str(row)
                    data.append(row[2: len(row)-2])
            self.path_dataset = path_dataset
            # Filled in by _validate_files: the (H, W, C) all images are resized to.
            self.target_shape = None
            self.target_hw = None
            self.sig_filenames = self._validate_files(data)

        def _pair_paths(self, filename):
            """Derive the (real, imag) PNG path pair from either half's filename."""
            base = os.path.splitext(filename)[0]
            if base.endswith("real"):
                return base + ".png", base[:-4] + "imag.png"
            if base.endswith("imag"):
                return base[:-4] + "real.png", base + ".png"
            # Name matches neither convention — caller treats this as invalid.
            return None, None

        @staticmethod
        def _read_shape(path):
            """Return the image's (H, W, C) shape, or None if it cannot be read."""
            img = cv2.imread(path)
            if img is None:
                return None
            return img.shape

        def _validate_files(self, filenames):
            """Drop unreadable/mismatched pairs and pick a common target shape.

            Keeps every pair whose two halves load and share a shape; chooses
            (1600, 1600, 3) as the target when present, otherwise the most
            common shape. Pairs with other shapes are kept and resized lazily
            in _read_image_chw. Raises RuntimeError if nothing survives.
            """
            from collections import Counter

            candidates = []
            dropped = []
            shape_counter = Counter()

            for filename in filenames:
                real_path, imag_path = self._pair_paths(filename)
                if real_path is None or imag_path is None:
                    dropped.append((filename, "bad file suffix"))
                    continue

                real_shape = self._read_shape(real_path)
                imag_shape = self._read_shape(imag_path)
                if real_shape is None or imag_shape is None:
                    dropped.append((filename, f"read failed real={real_shape} imag={imag_shape}"))
                    continue
                if real_shape != imag_shape:
                    dropped.append((filename, f"pair shape mismatch real={real_shape} imag={imag_shape}"))
                    continue

                shape_counter[real_shape] += 1
                candidates.append((filename, real_shape))

            if not candidates:
                raise RuntimeError("No valid image pairs left after shape validation")

            # Prefer the canonical 1600x1600 RGB shape when any pair has it.
            preferred_shape = (1600, 1600, 3)
            if shape_counter.get(preferred_shape, 0) > 0:
                target_shape = preferred_shape
            else:
                target_shape = shape_counter.most_common(1)[0][0]

            self.target_shape = target_shape
            self.target_hw = target_shape[:2]
            valid = [filename for filename, _shape in candidates]
            resized_count = sum(1 for _filename, shape in candidates if shape != target_shape)

            print(f"[dataset-shape-filter] shape_distribution={dict(shape_counter)}")
            print(
                f"[dataset-shape-filter] target_shape={target_shape} "
                f"kept={len(valid)} dropped={len(dropped)} will_resize={resized_count}"
            )
            # NOTE(review): `filename` is unpacked but never printed, so every
            # drop is reported as "(unknown)" — including it would aid debugging.
            for filename, reason in dropped[:20]:
                print(f"[dataset-shape-filter] drop (unknown): {reason}")
            if len(dropped) > 20:
                print(f"[dataset-shape-filter] ... {len(dropped) - 20} more dropped")

            return valid

        def __len__(self):
            return len(self.sig_filenames)

        def _read_image_chw(self, path):
            """Read an image, resize to target_hw if needed, return float32 CHW array."""
            img = cv2.imread(path)
            if img is None:
                raise RuntimeError(f"failed to read image: {path}")
            if img.shape[:2] != self.target_hw:
                # cv2.resize takes (width, height), hence the index swap.
                img = cv2.resize(
                    img,
                    (self.target_hw[1], self.target_hw[0]),
                    interpolation=cv2.INTER_AREA,
                )
            # cv2.split turns HWC into a tuple of per-channel planes -> CHW.
            return np.asarray(cv2.split(img), dtype=np.float32)

        def __getitem__(self, idx):
            real_path, imag_path = self._pair_paths(self.sig_filenames[idx])
            image_real = self._read_image_chw(real_path)
            image_imag = self._read_image_chw(imag_path)

            # Labels are inferred from path components: 'drone' -> 0, 'noise' -> 1.
            # NOTE(review): splitting on '/' assumes POSIX paths — confirm this
            # never runs on Windows-style paths.
            path_parts = set(self.sig_filenames[idx].split('/'))
            if 'drone' in path_parts:
                label = torch.tensor(0)
            elif 'noise' in path_parts:
                label = torch.tensor(1)
            else:
                raise RuntimeError(f"cannot infer label from path: {self.sig_filenames[idx]}")

            return image_real, image_imag, label

    #---------- Build the dataset --------------
    
    dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')
    # Fixed-seed 70/30 split so the train/val partition is reproducible.
    train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
    batch_size = config.batch_size
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=False)
    
    dataloaders = {}
    dataloaders['train'] = train_dataloader
    dataloaders['val'] = valid_dataloader
    dataset_sizes = {}
    dataset_sizes['train'] = len(train_set)
    dataset_sizes['val'] = len(valid_set)

    #---------- Train the model --------------

    # Per-epoch history, saved to .npy and plotted at the end.
    val_loss = []
    val_acc = []
    train_loss = []
    train_acc = []
    epochs = config.epoch
    # Minimum val-loss improvement that resets the early-stopping counter.
    min_delta = 1e-4
    
    best_val_loss = float('inf')
    best_model = copy.deepcopy(model.state_dict())
    # limit = early-stopping patience (epochs without improvement).
    limit = config.limit
    ind_limit = 0
    epoch_limit = epochs
    
    start = timeit.default_timer()
    for epoch in range(1, epochs+1):
        print(f"Epoch : {epoch}\n")
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            
            for (img1, img2, label) in tqdm(dataloaders[phase]):
                img1, img2, label = img1.to(device), img2.to(device), label.to(device)
                optimizer.zero_grad()
                
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    output = model([img1, img2])
                    _, pred = torch.max(output.data, 1)
                    loss = criterion(output, label)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # Weight by batch size so the epoch average is per-sample.
                running_loss += loss.item() * img1.size(0)
                running_corrects += torch.sum(pred == label.data)
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)
                # NOTE(review): the last notebook cell constructs a StepLR
                # (step_size/gamma), whose step() expects an epoch index, yet a
                # loss is passed here — this looks like a ReduceLROnPlateau
                # leftover; confirm which scheduler is intended.
                scheduler.step(epoch_loss)

                current_lr = optimizer.param_groups[0]['lr']
                print(f'val lr: {current_lr:.8f}')

                # Early stopping: save the checkpoint on any real improvement,
                # otherwise count a strike toward the patience limit.
                if epoch_loss < (best_val_loss - min_delta):
                    ind_limit = 0
                    best_val_loss = epoch_loss
                    best_model = copy.deepcopy(model.state_dict())
                    torch.save(best_model, path_res + model_name + '.pth')
                    print(f'saved best model with val_loss={best_val_loss:.4f}')
                else:
                    ind_limit += 1
                    print(f'early stopping patience: {ind_limit}/{limit}')
                
                if ind_limit >= limit:
                    break
        
        # Propagate the early stop out of the epoch loop as well.
        if ind_limit >= limit:
            epoch_limit = epoch
            break
        
        print()
    
    end = timeit.default_timer()
    print(f"Total time elapsed = {end - start} seconds")
    # +1 so range(1, epoch_limit) below matches the number of recorded epochs.
    epoch_limit += 1
    
    #---------- Plot and save the training results --------------
    
    # epoch_acc values are tensors; convert to plain floats before saving.
    train_acc = np.asarray(list(map(lambda x: x.item(), train_acc)))
    val_acc = np.asarray(list(map(lambda x: x.item(), val_acc)))
    
    np.save(path_res+'train_acc.npy', train_acc)
    np.save(path_res+'val_acc.npy', val_acc)
    np.save(path_res+'train_loss.npy', train_loss)
    np.save(path_res+'val_loss.npy', val_loss)
    
    plt.figure()
    plt.plot(range(1,epoch_limit), train_loss, color='blue')
    plt.plot(range(1,epoch_limit), val_loss, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss') 
    
    plt.title('Loss Curve')
    plt.legend(['Train Loss', 'Validation Loss'])
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()
    
    plt.figure()
    plt.plot(range(1,epoch_limit), train_acc, color='blue')
    plt.plot(range(1,epoch_limit), val_acc, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend(['Train Accuracy', 'Validation Accuracy'])
    plt.show()
    
    plt.clf()
    plt.cla()
    plt.close()
    # Free GPU/OpenCV resources; `del model` only drops this local reference.
    torch.cuda.empty_cache()
    cv2.destroyAllWindows()
    del model
    gc.collect()

    return path_res, model_name
+ "output_type": "stream", + "text": [ + "/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n", + " warnings.warn(msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error1\n", + "Error1\n", + "Error1\n", + "Error1\n", + "Error1\n", + "Error1\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "No valid image pairs left after shape validation", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 45\u001b[39m\n\u001b[32m 41\u001b[39m model = model.to(device)\n\u001b[32m 42\u001b[39m \n\u001b[32m 43\u001b[39m \u001b[38;5;66;03m#----------Создания датасета и обучение модели--------------\u001b[39;00m\n\u001b[32m 44\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 10000, path_dataset = \"/mnt/data/Dataset_overlay\", \n\u001b[32m 46\u001b[39m selected_freq=\u001b[32m1200\u001b[39m,model_name = config_name+\u001b[33m\"1200_\"\u001b[39m, config_name = config_name, model=model)\n\u001b[32m 47\u001b[39m \n\u001b[32m 48\u001b[39m \n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 183\u001b[39m, in 
\u001b[36mprepare_and_learning_detection\u001b[39m\u001b[34m(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq)\u001b[39m\n\u001b[32m 179\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m image_real, image_imag, label\n\u001b[32m 180\u001b[39m \n\u001b[32m 181\u001b[39m \u001b[38;5;66;03m#----------Создаём датасет--------------\u001b[39;00m\n\u001b[32m 182\u001b[39m \n\u001b[32m--> \u001b[39m\u001b[32m183\u001b[39m dataset = MyDataset(path_dataset=path_res, csv_file=\u001b[33m'dataset.csv'\u001b[39m)\n\u001b[32m 184\u001b[39m train_set, valid_set = torch.utils.data.random_split(dataset, [\u001b[32m0.7\u001b[39m, \u001b[32m0.3\u001b[39m], generator=torch.Generator().manual_seed(\u001b[32m42\u001b[39m))\n\u001b[32m 185\u001b[39m batch_size = config.batch_size\n\u001b[32m 186\u001b[39m train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 83\u001b[39m, in \u001b[36mprepare_and_learning_detection..MyDataset.__init__\u001b[39m\u001b[34m(self, path_dataset, csv_file)\u001b[39m\n\u001b[32m 79\u001b[39m data.append(row[\u001b[32m2\u001b[39m: len(row)-\u001b[32m2\u001b[39m])\n\u001b[32m 80\u001b[39m self.path_dataset = path_dataset\n\u001b[32m 81\u001b[39m self.target_shape = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 82\u001b[39m self.target_hw = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m83\u001b[39m self.sig_filenames = self._validate_files(data)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 126\u001b[39m, in \u001b[36mprepare_and_learning_detection..MyDataset._validate_files\u001b[39m\u001b[34m(self, filenames)\u001b[39m\n\u001b[32m 122\u001b[39m shape_counter[real_shape] += \u001b[32m1\u001b[39m\n\u001b[32m 123\u001b[39m 
# ---------------------------------------------------------------------------
# Build a two-branch ResNet ensemble (real/imag spectrogram images) and train
# it via prepare_and_learning_detection().
#
# NOTE: `config`, `model`, `optimizer`, `criterion` and `scheduler` must stay
# at notebook-global scope — prepare_and_learning_detection() reads
# `optimizer`, `criterion`, `scheduler` and `device` as globals.
# ---------------------------------------------------------------------------
torch.cuda.empty_cache()
cv2.destroyAllWindows()
gc.collect()

config_name = "ensemble"


def load_function(attr):
    """Resolve a dotted path such as 'torch.optim.Adam' to the object itself."""
    module_, func = attr.rsplit('.', maxsplit=1)
    return getattr(import_module(module_), func)


config = mlconfig.load('config_' + config_name + '.yaml')

num_classes = 2

# `weights=None` replaces the deprecated `pretrained=False` (torchvision >= 0.13
# warns that they are equivalent — see the captured stderr of this very cell).
model1 = models.resnet18(weights=None)
model2 = models.resnet50(weights=None)

# Re-head both backbones for the binary drone/noise task.
model1.fc = nn.Linear(model1.fc.in_features, num_classes)
model2.fc = nn.Linear(model2.fc.in_features, num_classes)


class Ensemble(nn.Module):
    """Late-fusion ensemble: each backbone scores one half of a (real, imag) pair.

    forward() expects ``x`` to be a two-element sequence ``[real_batch,
    imag_batch]`` (the format produced by MyDataset) and returns fused logits.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        # Each branch emits num_classes logits; fuse 2*num_classes -> num_classes.
        self.fc = nn.Linear(2 * num_classes, num_classes)

    def forward(self, x):
        x1 = self.model1(x[0])
        x2 = self.model2(x[1])
        return self.fc(torch.cat((x1, x2), dim=1))


model = Ensemble(model1, model2)

# Optimizer / loss / scheduler classes are resolved from the YAML config.
optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)
criterion = load_function(config.loss_function.name)()
scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)

# `.to('cpu')` is a no-op, so no need to guard on the device string.
model = model.to(device)

# ---------- Build the dataset and train the model --------------

path_res, model_name = prepare_and_learning_detection(
    num_classes=num_classes,
    num_samples=10000,
    path_dataset="/mnt/data/Dataset_overlay",
    selected_freq=2400,
    model_name=config_name + "2400_",
    config_name=config_name,
    model=model,
)

torch.cuda.empty_cache()
cv2.destroyAllWindows()
del model
gc.collect()