You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
476 lines
22 KiB
Plaintext
476 lines
22 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5a13ad6b-56c9-4381-b376-1765f6dd7553",
|
|
"metadata": {
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"# Импортирование библиотек"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "7311cb4a-5bf3-4268-b431-43eea10e9ed6",
|
|
"metadata": {
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/sibscience-4/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n",
|
|
" warnings.warn(\"Unable to import Axes3D. This may be due to multiple versions of \"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"191"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
# ---------- Imports (stdlib, third-party, and environment setup) ----------
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch import default_generator, randperm
from torch.utils.data.dataset import Subset
import torchvision.transforms as transforms
from torchvision.io import read_image
from importlib import import_module
import matplotlib.pyplot as plt
from torchvision import models
import torch, torchvision
from pathlib import Path
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib
import os, shutil  # FIX: shutil was previously imported a second time below; kept once
import mlconfig
import random
import timeit
import copy
import time
import cv2
import csv
import sys
import io
import gc

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)  # fixed seed for reproducible weight init / shuffling
# matplotlib.use('Agg')  # enable for headless (no-display) rendering
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Free any GPU memory / OpenCV windows left over from a previous run.
torch.cuda.empty_cache()
cv2.destroyAllWindows()
gc.collect()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "384de097-82c6-41f5-bda9-b2f54bc99593",
|
|
"metadata": {
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"# Подготовка данных и обучение модели детектирования"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "46e4dc99-6994-4fee-a32e-f3983bd991bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def prepare_and_learning_detection(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq):
    """Build a balanced dataset CSV and train `model` for drone detection.

    Parameters
    ----------
    num_classes : int
        Number of target classes; used to balance samples per class.
    num_samples : int
        Total number of samples to draw across all classes.
    path_dataset : str
        Root directory; expects ``<path_dataset>/<class>/<selected_freq>/*.npy``
        with companion ``*_real.png`` / ``*_imag.png`` / ``*_spec.png`` images.
    model_name : str
        Prefix for the results directory and the saved weights file.
    config_name : str
        Training parameters are loaded from ``config_<config_name>.yaml``.
    model : nn.Module
        The (already device-placed) model to train.
    selected_freq :
        Frequency subdirectory name to sample files from.

    Returns
    -------
    (path_res, model_name) : results directory path and the model name.

    NOTE(review): relies on globals `optimizer` and `criterion` defined by the
    calling cell — consider passing them explicitly. The caller also builds a
    `scheduler` that is never stepped here; confirm whether that is intended.
    """
    num_samples_per_class = num_samples // num_classes

    # ---------- Create a folder to store the training results ----------
    os.makedirs("models", exist_ok=True)
    ind = 1
    while True:
        if os.path.exists("models/" + model_name + str(ind)):
            ind += 1
        else:
            os.mkdir("models/" + model_name + str(ind))
            path_res = "models/" + model_name + str(ind) + '/'
            break

    # ---------- Build dataset.csv ----------
    pd_columns = ['file_name']

    subdirs = os.listdir(path_dataset)
    print(subdirs)

    # Scan each class directory once; cap the per-class sample count by the
    # smallest class so the resulting dataset stays balanced.
    files_per_class = {}
    for subdir in subdirs:
        freq_dir = os.path.join(path_dataset, subdir, str(selected_freq))
        if not os.path.isdir(freq_dir):
            continue
        files = [
            f for f in os.listdir(freq_dir)
            if os.path.isfile(os.path.join(freq_dir, f)) and f.endswith('.npy')
        ]
        files_per_class[freq_dir] = files
        num_samples_per_class = min(num_samples_per_class, len(files))

    # Collect rows first, then build the DataFrame once (avoids quadratic concat).
    records = []
    for freq_dir, files in files_per_class.items():
        random.shuffle(files)
        for file in files[:num_samples_per_class]:
            records.append({pd_columns[0]: str(os.path.join(freq_dir, file))})

    df = pd.DataFrame(records, columns=pd_columns)
    dataset_csv_path = os.path.join(path_res, 'dataset.csv')
    df.to_csv(dataset_csv_path, index=False)

    if not os.path.exists(dataset_csv_path):
        raise RuntimeError(f'dataset.csv was not created: {dataset_csv_path}')
    if df.empty:
        # FIX: fail early with a clear message instead of the opaque
        # "num_samples should be a positive integer value" from the DataLoader.
        raise RuntimeError(
            f'dataset.csv is empty: no .npy files found under {path_dataset}/*/{selected_freq}')

    # ---------- Import training parameters ----------
    config = mlconfig.load('config_' + config_name + '.yaml')

    # ---------- Dataset class ----------
    class MyDataset(Dataset):
        """Loads the (real, imag, spec) image triple for each CSV entry."""

        def __init__(self, path_dataset, csv_file):
            # dataset.csv has a single 'file_name' column; skip the header row.
            # FIX: read the column value directly instead of slicing str(row),
            # which broke on paths containing commas or quotes.
            data = []
            with open(os.path.join(path_dataset, csv_file), newline='') as csvfile:
                reader = csv.reader(csvfile)
                for row in list(reader)[1:]:
                    if row:
                        data.append(row[0])

            self.sig_filenames = data
            self.path_dataset = path_dataset

        def __len__(self):
            return len(self.sig_filenames)

        def __getitem__(self, idx):
            # FIX: removed leftover debug `exit(1)` that aborted every item access.
            base = os.path.splitext(self.sig_filenames[idx])[0]

            # Each sample is stored as three companion PNGs next to the .npy file.
            image_real = np.asarray(cv2.split(cv2.imread(base + '_real.png')), dtype=np.float32)
            image_imag = np.asarray(cv2.split(cv2.imread(base + '_imag.png')), dtype=np.float32)
            image_spec = np.asarray(cv2.split(cv2.imread(base + '_spec.png')), dtype=np.float32)

            # The class is encoded as a directory component of the file path.
            path_parts = self.sig_filenames[idx].split('/')
            if 'drone' in path_parts:
                label = torch.tensor(0)
            elif 'noise' in path_parts:
                label = torch.tensor(1)
            else:
                # FIX: previously `label` stayed unbound here (NameError at return).
                raise ValueError(f'Cannot infer label from path: {self.sig_filenames[idx]}')
            return image_real, image_imag, image_spec, label

    # ---------- Build the dataset and loaders ----------
    # FIX: removed leftover debug `exit(1)` that prevented dataset creation.
    dataset = MyDataset(path_dataset=path_res, csv_file='dataset.csv')
    train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
    batch_size = config.batch_size
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=True, drop_last=True)

    dataloaders = {'train': train_dataloader, 'val': valid_dataloader}
    dataset_sizes = {'train': len(train_set), 'val': len(valid_set)}

    # ---------- Train the model ----------
    val_loss, val_acc = [], []
    train_loss, train_acc = [], []
    epochs = config.epoch

    best_acc = 0.0
    best_model = copy.deepcopy(model.state_dict())
    limit = config.limit      # early-stopping patience: epochs without val improvement
    epoch_limit = epochs
    ind_limit = 0             # FIX: was unbound until the first val improvement (NameError)

    start = timeit.default_timer()
    for epoch in range(1, epochs + 1):
        print(f"Epoch : {epoch}\n")

        for phase in ['train', 'val']:
            # FIX: switch batchnorm/dropout behaviour per phase.
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            for (img1, img2, img3, label) in tqdm(dataloaders[phase]):
                img1, img2, img3, label = img1.to(device), img2.to(device), img3.to(device), label.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    output = model([img1, img2, img3])
                    _, pred = torch.max(output.data, 1)
                    loss = criterion(output, label)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # NOTE(review): the *3 factor triples the reported loss (three
                # input images per sample) — kept as-is for comparability with
                # earlier runs; confirm it is intentional.
                running_loss += loss.item() * 3 * img1.size(0)
                running_corrects += torch.sum(pred == label.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)
                if val_acc[-1] > best_acc:
                    ind_limit = 0
                    best_acc = val_acc[-1]
                    best_model = copy.deepcopy(model.state_dict())
                    torch.save(best_model, path_res + model_name + '.pth')
                else:
                    ind_limit += 1

            if ind_limit >= limit:
                break

        if ind_limit >= limit:
            epoch_limit = epoch
            break

        print()

    end = timeit.default_timer()
    print(f"Total time elapsed = {end - start} seconds")
    epoch_limit += 1

    # ---------- Plot curves and persist the training history ----------
    train_acc = np.asarray([x.item() for x in train_acc])
    val_acc = np.asarray([x.item() for x in val_acc])

    np.save(path_res + 'train_acc.npy', train_acc)
    np.save(path_res + 'val_acc.npy', val_acc)
    np.save(path_res + 'train_loss.npy', train_loss)
    np.save(path_res + 'val_loss.npy', val_loss)

    plt.figure()
    plt.plot(range(1, epoch_limit), train_loss, color='blue')
    plt.plot(range(1, epoch_limit), val_loss, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend(['Train Loss', 'Validation Loss'])
    plt.show()
    plt.close()

    plt.figure()
    plt.plot(range(1, epoch_limit), train_acc, color='blue')
    plt.plot(range(1, epoch_limit), val_acc, color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend(['Train Accuracy', 'Validation Accuracy'])
    plt.show()
    plt.close()

    # Release GPU memory and the (caller-owned) model reference.
    torch.cuda.empty_cache()
    cv2.destroyAllWindows()
    del model
    gc.collect()

    return path_res, model_name
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "93c136ee",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Ensemble"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "52e8d4c5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['noise', 'drone']\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/tmp/ipykernel_3754538/2835334388.py:103: UserWarning: Length of split at index 0 is 0. This might result in an empty dataset.\n",
|
|
" train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n",
|
|
"/tmp/ipykernel_3754538/2835334388.py:103: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.\n",
|
|
" train_set, valid_set = torch.utils.data.random_split(dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42))\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "num_samples should be a positive integer value, but got num_samples=0",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
"\u001b[31mValueError\u001b[39m Traceback (most recent call last)",
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[13]\u001b[39m\u001b[32m, line 49\u001b[39m\n\u001b[32m 45\u001b[39m model = model.to(device)\n\u001b[32m 46\u001b[39m \n\u001b[32m 47\u001b[39m \u001b[38;5;66;03m#----------Создания датасета и обучение модели--------------\u001b[39;00m\n\u001b[32m 48\u001b[39m \n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m path_res, model_name = prepare_and_learning_detection(num_classes = num_classes, num_samples = 10000, path_dataset = \"/mnt/data/Dataset_img\", \n\u001b[32m 50\u001b[39m selected_freq=\u001b[32m2400\u001b[39m,model_name = config_name+\u001b[33m\"_2.4_jpg_\"\u001b[39m, config_name = config_name, model=model)\n\u001b[32m 51\u001b[39m \n\u001b[32m 52\u001b[39m \n",
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 105\u001b[39m, in \u001b[36mprepare_and_learning_detection\u001b[39m\u001b[34m(num_classes, num_samples, path_dataset, model_name, config_name, model, selected_freq)\u001b[39m\n\u001b[32m 101\u001b[39m exit(\u001b[32m1\u001b[39m)\n\u001b[32m 102\u001b[39m dataset = MyDataset(path_dataset=path_res, csv_file=\u001b[33m'dataset.csv'\u001b[39m)\n\u001b[32m 103\u001b[39m train_set, valid_set = torch.utils.data.random_split(dataset, [\u001b[32m0.7\u001b[39m, \u001b[32m0.3\u001b[39m], generator=torch.Generator().manual_seed(\u001b[32m42\u001b[39m))\n\u001b[32m 104\u001b[39m batch_size = config.batch_size\n\u001b[32m--> \u001b[39m\u001b[32m105\u001b[39m train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m 106\u001b[39m valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m 107\u001b[39m \n\u001b[32m 108\u001b[39m dataloaders = {}\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32m~/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torch/utils/data/dataloader.py:394\u001b[39m, in \u001b[36mDataLoader.__init__\u001b[39m\u001b[34m(self, dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context, generator, prefetch_factor, persistent_workers, pin_memory_device, in_order)\u001b[39m\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m: \u001b[38;5;66;03m# map-style\u001b[39;00m\n\u001b[32m 393\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m shuffle:\n\u001b[32m--> \u001b[39m\u001b[32m394\u001b[39m sampler = \u001b[30;43mRandomSampler\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mdataset\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mgenerator\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mgenerator\u001b[39;49m\u001b[30;43m)\u001b[39;49m \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 395\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 396\u001b[39m sampler = SequentialSampler(dataset) \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32m~/from_ssh/DroneDetector/.venv-train/lib/python3.12/site-packages/torch/utils/data/sampler.py:149\u001b[39m, in \u001b[36mRandomSampler.__init__\u001b[39m\u001b[34m(self, data_source, replacement, num_samples, generator)\u001b[39m\n\u001b[32m 144\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 145\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mreplacement should be a boolean value, but got replacement=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.replacement\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 146\u001b[39m )\n\u001b[32m 148\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m.num_samples, \u001b[38;5;28mint\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.num_samples <= \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 150\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mnum_samples should be a positive integer value, but got num_samples=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.num_samples\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 151\u001b[39m )\n",
|
|
"\u001b[31mValueError\u001b[39m: num_samples should be a positive integer value, but got num_samples=0"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# ---------- Initialize the model and training parameters ----------

torch.cuda.empty_cache()
cv2.destroyAllWindows()
gc.collect()

num_classes = 2  # drone / noise  (FIX: was assigned twice in this cell)
config_name = "ensemble"

def load_function(attr):
    """Resolve a dotted path such as 'torch.optim.Adam' to the attribute."""
    module_, func = attr.rsplit('.', maxsplit=1)
    return getattr(import_module(module_), func)

config = mlconfig.load('config_' + config_name + '.yaml')

# Two backbones with their classification heads resized to `num_classes`.
model1 = models.resnet18(pretrained=False)
model2 = models.resnet50(pretrained=False)

model1.fc = nn.Linear(model1.fc.in_features, num_classes)
model2.fc = nn.Linear(model2.fc.in_features, num_classes)

class Ensemble(nn.Module):
    """Late-fusion ensemble: concatenates both backbones' logits and re-projects
    them through a small linear head to `num_classes` outputs."""

    def __init__(self, model1, model2):
        super(Ensemble, self).__init__()
        self.model1 = model1
        self.model2 = model2
        self.fc = nn.Linear(2 * num_classes, num_classes)

    def forward(self, x):
        # `x` is a list of input tensors; only the first two are consumed here,
        # even though the training loop passes three (real, imag, spec).
        x1 = self.model1(x[0])
        x2 = self.model2(x[1])
        x = torch.cat((x1, x2), dim=1)
        x = self.fc(x)
        return x

model = Ensemble(model1, model2)

optimizer = load_function(config.optimizer.name)(model.parameters(), lr=config.optimizer.lr)
criterion = load_function(config.loss_function.name)()
scheduler = load_function(config.scheduler.name)(optimizer, step_size=config.scheduler.step_size, gamma=config.scheduler.gamma)

# `.to('cpu')` is a no-op, so the previous `if device != 'cpu'` guard was unnecessary.
model = model.to(device)

# ---------- Build the dataset and train the model ----------

path_res, model_name = prepare_and_learning_detection(num_classes=num_classes, num_samples=10000, path_dataset="/mnt/data/Dataset_img",
                                                      selected_freq=2400, model_name=config_name + "_2.4_jpg_", config_name=config_name, model=model)

torch.cuda.empty_cache()
cv2.destroyAllWindows()
del model
gc.collect()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4234ee26",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"celltoolbar": "Отсутствует",
|
|
"kernelspec": {
|
|
"display_name": ".venv-train",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|