You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
DroneDetector/train_scripts/create_dataset_overlay.py

218 lines
7.6 KiB
Python

#!/usr/bin/env python3
import argparse
import csv
import os
import random
import shutil
from pathlib import Path
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
# Frequencies processed when --freqs is not given. Each value is also used to
# build the per-frequency directory names "<freq>" and "<freq>_jpg".
DEFAULT_FREQS = (1200, 2400)
# File-name suffixes of the PNG files that accompany every .npy sample.
# When a sample is saved, the suffix at position i is used for tensor channel i.
PNG_SUFFIXES = ("_real.png", "_imag.png", "_spec.png")
def parse_args():
    """Parse the command-line options that control dataset generation.

    Returns:
        argparse.Namespace with the attributes ``drone_root``,
        ``noise_img_root``, ``output_root``, ``freqs``, ``alpha``,
        ``limit_per_freq``, ``seed``, ``copy_noise``, ``overwrite``,
        ``dry_run`` and ``no_progress``.
    """
    summary = (
        "Build a two-class image dataset with drone signatures overlaid on noise images. "
        "The output is ready for Training_models2pic_val_loss.ipynb."
    )
    # --freqs is a comma-separated string; it is split later by parse_freqs().
    freqs_default = ",".join(str(v) for v in DEFAULT_FREQS)

    # (flag, add_argument keyword arguments), listed in --help order.
    option_table = [
        ("--drone-root", {"default": "/mnt/data/Dataset/drone"}),
        ("--noise-img-root", {"default": "/mnt/data/Dataset_img/noise"}),
        ("--output-root", {"default": "/mnt/data/Dataset_overlay"}),
        ("--freqs", {"default": freqs_default}),
        (
            "--alpha",
            {
                "type": float,
                "default": 1.0,
                "help": "Overlay strength: 1.0 keeps the darkest drone/noise pixels.",
            },
        ),
        (
            "--limit-per-freq",
            {
                "type": int,
                "default": 0,
                "help": "0 means use all available noise images per frequency.",
            },
        ),
        ("--seed", {"type": int, "default": 42}),
        (
            "--copy-noise",
            {
                "action": "store_true",
                "help": "Copy noise files instead of hardlinking them.",
            },
        ),
        ("--overwrite", {"action": "store_true"}),
        ("--dry-run", {"action": "store_true"}),
        (
            "--no-progress",
            {"action": "store_true", "help": "Disable tqdm progress bars."},
        ),
    ]

    cli = argparse.ArgumentParser(description=summary)
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def parse_freqs(value):
    """Turn a comma-separated string such as ``"1200, 2400"`` into a list of ints.

    Surrounding whitespace is ignored and blank items (for example from a
    trailing comma) are skipped. A non-numeric item makes ``int()`` raise
    ``ValueError``.
    """
    trimmed = (piece.strip() for piece in value.split(","))
    return [int(piece) for piece in trimmed if piece]
def collect_drone_npys(root, freq):
    """Return the sorted, de-duplicated drone ``.npy`` files for one frequency.

    Two locations are searched: ``<root>/<freq>`` recursively and
    ``<root>/<freq>_jpg`` non-recursively. Only regular files are kept.

    Args:
        root: Directory that holds the per-frequency drone folders.
        freq: Frequency value used to build the folder names.

    Returns:
        Sorted list of ``pathlib.Path`` objects (empty when nothing matches).
    """
    base = Path(root)
    recursive_hits = (base / str(freq)).rglob("*.npy")
    flat_hits = (base / f"{freq}_jpg").glob("*.npy")

    found = set()
    for hits in (recursive_hits, flat_hits):
        # is_file() drops directories that merely happen to end in ".npy".
        found.update(hit for hit in hits if hit.is_file())
    return sorted(found)
def collect_noise_npys(root, freq):
    """Return the sorted, de-duplicated noise ``.npy`` files for one frequency.

    ``<root>/<freq>_jpg`` is scanned non-recursively and ``<root>/<freq>`` is
    scanned recursively. Only regular files are kept.

    Args:
        root: Directory that holds the per-frequency noise folders.
        freq: Frequency value used to build the folder names.

    Returns:
        Sorted list of ``pathlib.Path`` objects (empty when nothing matches).
    """
    base = Path(root)
    # (folder, search recursively?)
    search_plan = (
        (base / f"{freq}_jpg", False),
        (base / str(freq), True),
    )

    unique = set()
    for folder, recursive in search_plan:
        matches = folder.rglob("*.npy") if recursive else folder.glob("*.npy")
        unique.update(match for match in matches if match.is_file())
    return sorted(unique)
def load_image_tensor(path):
    """Load a ``.npy`` file and normalise it to a channels-first, 3-channel float32 array.

    Layout handling:
      * 2-D input is replicated into three identical channels.
      * 3-D input whose last axis has size 1 or 3 (and whose first axis does
        not) is treated as channels-last and its last axis is moved to the front.
      * A single channel is repeated three times; extra channels beyond the
        third are dropped.

    Args:
        path: Location of the ``.npy`` file, passed straight to ``np.load``.

    Returns:
        ``np.float32`` array of shape ``(3, H, W)``.

    Raises:
        ValueError: If the array is not 2-D/3-D, or ends up with fewer than
            three channels (0 or 2).
    """
    data = np.load(path)

    if data.ndim == 2:
        data = np.stack([data] * 3, axis=0)
    elif data.ndim != 3:
        raise ValueError(f"expected 3D image tensor, got shape={data.shape} path={path}")
    elif data.shape[0] not in (1, 3) and data.shape[-1] in (1, 3):
        # Looks channels-last: bring the channel axis to the front.
        data = np.moveaxis(data, -1, 0)

    n_channels = data.shape[0]
    if n_channels == 1:
        data = np.repeat(data, 3, axis=0)
    elif n_channels < 3:
        raise ValueError(f"expected at least 3 channels, got shape={data.shape} path={path}")

    # copy=False avoids a second allocation when the data is already float32.
    return data[:3].astype(np.float32, copy=False)
def resize_like(arr, shape):
    """Resize a channels-first array so that it matches ``shape``.

    Args:
        arr: Array indexed as ``(channel, row, column)``.
        shape: Target ``(channels, height, width)`` tuple.

    Returns:
        ``arr`` itself when its shape already equals ``shape``. Otherwise a new
        float32 array in which each of the first ``channels`` planes of ``arr``
        has been resized to ``height`` x ``width`` with bilinear interpolation.
    """
    if arr.shape != shape:
        n_channels, target_h, target_w = shape
        # cv2.resize takes the target size as (width, height).
        planes = [
            cv2.resize(plane, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            for plane in arr[:n_channels]
        ]
        return np.asarray(planes, dtype=np.float32)
    return arr
def overlay_tensors(noise, drone, alpha):
    """Blend a drone tensor into a noise tensor by keeping the smaller pixel values.

    The drone tensor is first resized to the noise tensor's shape. The
    element-wise minimum of the two is then mixed with the plain noise:
    ``(1 - alpha) * noise + alpha * min(noise, drone)``.

    Args:
        noise: Background tensor that defines the output shape.
        drone: Tensor to overlay; resized when its shape differs.
        alpha: Mixing weight; 0.0 returns the (clipped) noise, 1.0 returns the
            element-wise minimum.

    Returns:
        float32 array clipped to the range 0..255.
    """
    fitted_drone = resize_like(drone, noise.shape)
    darkest = np.minimum(noise, fitted_drone)
    noise_weight = 1.0 - alpha
    blended = noise_weight * noise + alpha * darkest
    clipped = np.clip(blended, 0, 255)
    return clipped.astype(np.float32)
def to_uint8(channel):
    """Convert a numeric array to ``uint8`` without ever casting NaN.

    Behaviour:
      * No finite values at all (or an empty array): all zeros.
      * Finite values already inside 0..255: values are clipped and truncated
        as they are, without rescaling.
      * Otherwise the finite range ``[min, max]`` is stretched linearly to
        0..255; a constant out-of-range array becomes all zeros.

    NaN entries are mapped to 0. ``+inf`` / ``-inf`` saturate to 255 / 0
    through the final clip.

    Args:
        channel: Array-like of numbers.

    Returns:
        ``np.uint8`` array with the same shape as ``channel``.
    """
    arr = np.asarray(channel, dtype=np.float32)
    finite_mask = np.isfinite(arr)
    if not finite_mask.any():
        return np.zeros(arr.shape, dtype=np.uint8)

    finite = arr[finite_mask]
    min_v = float(finite.min())
    max_v = float(finite.max())
    nan_mask = np.isnan(arr)

    if min_v >= 0.0 and max_v <= 255.0:
        # Casting NaN to an integer type is undefined (and warns on recent
        # NumPy), so replace it with the lowest output level first.
        scaled = np.where(nan_mask, 0.0, arr)
    elif max_v == min_v:
        # Constant array outside 0..255: there is no range to stretch.
        return np.zeros(arr.shape, dtype=np.uint8)
    else:
        # NaN -> min_v, which normalises to 0 below.
        cleaned = np.where(nan_mask, min_v, arr)
        norm = (cleaned - min_v) / (max_v - min_v)
        scaled = norm * 255.0

    # The clip also saturates +/-inf to 255 / 0 before the integer cast.
    return np.clip(scaled, 0, 255).astype(np.uint8)
def save_notebook_style_png(path, channel):
    """Render one channel with pyplot's default ``imshow`` and save the figure.

    The image is drawn with ``plt.imshow(channel)`` and no extra arguments,
    and the whole figure is written with ``plt.savefig(path)``. The saved file
    is therefore pyplot's rendering of the data, not the raw pixel array.

    Args:
        path: Destination file name handed to ``plt.savefig``.
        channel: 2-D array to display.

    Raises:
        Whatever ``plt.imshow`` / ``plt.savefig`` raise. The figure is
        released before the exception propagates.
    """
    fig = plt.figure()
    try:
        # plt.figure() made ``fig`` the current figure, so the pyplot calls
        # below draw into and save exactly this figure.
        plt.imshow(channel)
        plt.savefig(path)
    finally:
        # Always release the figure, even when rendering or saving fails.
        # Otherwise it stays registered with pyplot, and this function is
        # called once per channel of every generated sample.
        fig.clf()
        plt.close(fig)
def save_sample(base_path, tensor):
    """Write one sample as ``<base_path>.npy`` plus one PNG per suffix in PNG_SUFFIXES.

    Args:
        base_path: Output path without extension; the suffixes are appended to
            its string form.
        tensor: Array whose first axis is indexed once for every entry of
            ``PNG_SUFFIXES``; it is stored as float32.
    """
    stem = str(base_path)
    as_float32 = tensor.astype(np.float32)
    np.save(stem + ".npy", as_float32)

    # Channel i is rendered to the file carrying the i-th suffix.
    for channel_idx, png_suffix in enumerate(PNG_SUFFIXES):
        png_path = stem + png_suffix
        save_notebook_style_png(png_path, tensor[channel_idx])
def link_or_copy(src, dst, copy_file):
    """Make ``dst`` available as a hard link to, or a copy of, ``src``.

    The parent directory of ``dst`` is created when needed. An already
    existing ``dst`` is left untouched.

    Args:
        src: Existing source file.
        dst: ``pathlib.Path`` of the destination file.
        copy_file: When true, always copy. Otherwise try a hard link first and
            fall back to copying if linking fails with ``OSError``.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if not dst.exists():
        if copy_file:
            shutil.copy2(src, dst)
        else:
            try:
                os.link(src, dst)
            except OSError:
                # Hard links can fail (for example across file systems);
                # a plain copy still produces the file.
                shutil.copy2(src, dst)
def copy_noise_family(noise_npy, out_dir, copy_file):
    """Place a noise ``.npy`` file and its existing PNG sidecars into ``out_dir``.

    Sidecars are the files next to ``noise_npy`` whose names are the ``.npy``
    name without its extension plus one of ``PNG_SUFFIXES``. Missing sidecars
    are skipped. File names are preserved in ``out_dir``.

    Args:
        noise_npy: ``pathlib.Path`` of the source ``.npy`` file.
        out_dir: ``pathlib.Path`` of the destination directory.
        copy_file: Forwarded to ``link_or_copy`` (copy instead of hard link).
    """
    npy_name = noise_npy.name
    stem = npy_name[:-4] if npy_name.endswith(".npy") else npy_name

    link_or_copy(noise_npy, out_dir / npy_name, copy_file)

    sidecars = (noise_npy.with_name(stem + suffix) for suffix in PNG_SUFFIXES)
    for sidecar in sidecars:
        if not sidecar.exists():
            continue
        link_or_copy(sidecar, out_dir / sidecar.name, copy_file)
def _select_sources(args, freq):
    """Collect and validate the inputs for one frequency without touching the output tree.

    Returns:
        ``(drone_files, noise_total, selected_noise)``: all drone files, the
        number of noise files found, and the shuffled noise files chosen for
        this run (at most ``--limit-per-freq`` of them when that is positive).

    Raises:
        RuntimeError: If no drone or no noise ``.npy`` files exist for ``freq``.
    """
    drone_files = collect_drone_npys(args.drone_root, freq)
    noise_files = collect_noise_npys(args.noise_img_root, freq)
    if not drone_files:
        raise RuntimeError(f"no drone npy files found for freq={freq} under {args.drone_root}")
    if not noise_files:
        raise RuntimeError(f"no noise image npy files found for freq={freq} under {args.noise_img_root}")

    count = len(noise_files) if args.limit_per_freq <= 0 else min(args.limit_per_freq, len(noise_files))
    selected_noise = noise_files[:]
    random.shuffle(selected_noise)
    return drone_files, len(noise_files), selected_noise[:count]


def main():
    """Build the overlay dataset described by the command-line arguments.

    For every requested frequency, each selected noise sample is combined
    with a drone sample (drone files are reused cyclically). The result is
    written under ``<output-root>/drone/<freq>_jpg`` and the untouched noise
    sample is linked or copied into ``<output-root>/noise/<freq>_jpg``. A CSV
    manifest of all overlays is written to
    ``<output-root>/overlay_manifest.csv``.

    All frequencies are resolved and validated *before* ``--overwrite``
    removes an existing output tree, so a bad ``--freqs`` value or a missing
    input directory can no longer destroy a previous dataset and then abort.
    With ``--dry-run`` only the per-frequency summary lines are printed.

    Raises:
        RuntimeError: If a frequency has no drone or no noise ``.npy`` files.
    """
    args = parse_args()
    random.seed(args.seed)
    output_root = Path(args.output_root)

    # Phase 1: read-only planning. Shuffles happen in frequency order, so the
    # selection for a given seed is the same as when done inside the main loop.
    plans = [(freq,) + _select_sources(args, freq) for freq in parse_freqs(args.freqs)]

    # Phase 2: only now is it safe to drop the previous output.
    if args.overwrite and output_root.exists() and not args.dry_run:
        shutil.rmtree(output_root)

    manifest_rows = []
    for freq, drone_files, noise_total, selected_noise in plans:
        count = len(selected_noise)
        out_drone_dir = output_root / "drone" / f"{freq}_jpg"
        out_noise_dir = output_root / "noise" / f"{freq}_jpg"
        print(f"freq={freq}: drone_source={len(drone_files)} noise_source={noise_total} output_per_class={count}")
        if args.dry_run:
            continue

        out_drone_dir.mkdir(parents=True, exist_ok=True)
        out_noise_dir.mkdir(parents=True, exist_ok=True)
        iterator = tqdm(
            enumerate(selected_noise),
            total=count,
            desc=f"overlay freq={freq}",
            unit="sample",
            disable=args.no_progress,
        )
        for idx, noise_path in iterator:
            # Cycle through the drone files when there are fewer of them than noise samples.
            drone_path = drone_files[idx % len(drone_files)]
            noise_tensor = load_image_tensor(noise_path)
            drone_tensor = load_image_tensor(drone_path)
            mixed = overlay_tensors(noise_tensor, drone_tensor, args.alpha)
            out_base = out_drone_dir / f"overlay_{freq}_{idx:06d}"
            save_sample(out_base, mixed)
            # NOTE(review): noise files keep their original names and existing
            # destinations are skipped, so same-named noise files from different
            # source sub-directories end up as a single file in the noise class.
            copy_noise_family(noise_path, out_noise_dir, args.copy_noise)
            manifest_rows.append({
                "freq": freq,
                "output": str(out_base) + ".npy",
                "noise_source": str(noise_path),
                "drone_source": str(drone_path),
                "alpha": args.alpha,
            })

    if args.dry_run:
        return

    # The output root may not exist yet when no frequency produced anything
    # (for example an empty --freqs), so create it before writing the manifest.
    output_root.mkdir(parents=True, exist_ok=True)
    manifest_path = output_root / "overlay_manifest.csv"
    with manifest_path.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=["freq", "output", "noise_source", "drone_source", "alpha"])
        writer.writeheader()
        writer.writerows(manifest_rows)
    print(f"wrote {len(manifest_rows)} overlay samples")
    print(f"manifest: {manifest_path}")
    print(f"dataset: {output_root}")
# Run the dataset builder only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()