You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
1.8 KiB
Python

from __future__ import annotations

import logging
import os
from contextlib import nullcontext
from typing import List

import numpy as np
import torch
from ultralytics import YOLO

from . import config
logger = logging.getLogger("PTZTracker")

# Optional mixed-precision support. torch.amp.autocast may be missing on older PyTorch
# releases; in that case autocast stays None and yolo_forward runs without autocast.
# Only ImportError is caught so unrelated failures are not mistaken for "unsupported".
try:
    from torch.amp import autocast  # type: ignore
except ImportError:
    autocast = None  # type: ignore
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Path to the YOLO weights file; the YOLO_WEIGHTS environment variable overrides the default.
WEIGHT_PATH = os.environ.get(
    "YOLO_WEIGHTS",
    "/home/electro-pribory/Desktop/weight/electro-pribory.pt",
)
# One YOLO instance, loaded at import time and shared by every caller of this module.
model_shared = YOLO(WEIGHT_PATH).to(device)

# The optimisations below are best-effort: each one is skipped if it fails, but the reason
# is logged so a silently slower model can still be diagnosed.
try:
    model_shared.fuse()
except Exception as e:
    logger.debug("Layer fusion skipped: %s", e)

try:
    # Switch the underlying torch module to channels_last memory format.
    model_shared.model.to(memory_format=torch.channels_last)  # type: ignore
except Exception as e:
    logger.debug("channels_last memory format not applied: %s", e)

if device == "cuda":
    try:
        # NOTE(review): the Ultralytics predictor may re-cast the model's dtype according
        # to its own `half` argument when it first wraps the model; confirm this manual
        # conversion survives the first prediction (yolo_forward also uses autocast).
        model_shared.model.half()  # type: ignore
        logger.info("Model set to FP16")
    except Exception as e:
        logger.info("FP16 not supported, using FP32 (%s)", e)

if hasattr(torch, "compile"):
    try:
        # torch.compile defers compilation to the first forward call, so this try only
        # catches errors raised while wrapping the module; compilation errors appear on
        # the first inference instead.
        model_shared.model = torch.compile(model_shared.model, mode="reduce-overhead")  # type: ignore
        logger.info("torch.compile enabled")
    except Exception as e:
        logger.info("torch.compile not used: %s", e)

model_shared.model.eval()
logger.info("Single shared model loaded")
def _warmup() -> None:
    """Run one dummy inference so first-call setup cost is paid at import time.

    Uses an all-zero square uint8 frame of side ``config.INFER_IMGSZ``. On CUDA it
    synchronizes so the warmup work has actually finished before returning. The results
    are discarded (not bound to any module-level name) so they are not kept alive.
    """
    dummy = np.zeros((config.INFER_IMGSZ, config.INFER_IMGSZ, 3), dtype=np.uint8)
    model_shared([dummy], imgsz=config.INFER_IMGSZ, conf=0.1, iou=0.45, verbose=False)
    if device == "cuda":
        torch.cuda.synchronize()


# Warmup is best-effort: a failure must not prevent the module from loading, but it is
# logged because it may foreshadow the same error on the first real inference.
try:
    _warmup()
except Exception as e:
    logger.warning("Model warmup failed, continuing without it: %s", e)
# Keyword arguments passed to the shared model on every yolo_forward() call.
INFER_ARGS = {
    "imgsz": config.INFER_IMGSZ,
    "conf": config.MODEL_CONF,
    "iou": 0.45,
    "verbose": False,
}
@torch.inference_mode()
def yolo_forward(frames: List[np.ndarray]):
    """Run the shared YOLO model on a batch of frames with the module's INFER_ARGS.

    Args:
        frames: Images passed straight to the Ultralytics model call (presumably HxWx3
            uint8 arrays like the warmup frame — TODO confirm with callers).

    Returns:
        Whatever the model call returns for the batch, or an empty list when ``frames``
        is empty.
    """
    # len() rather than truthiness so an ndarray batch does not raise on `not frames`.
    if len(frames) == 0:
        # Skip the predictor entirely instead of handing it an empty batch.
        return []
    # FP16 autocast only on CUDA and only when torch.amp.autocast was importable.
    use_amp = device == "cuda" and autocast is not None
    ctx = autocast("cuda", dtype=torch.float16) if use_amp else nullcontext()  # type: ignore[misc]
    with ctx:
        return model_shared(frames, **INFER_ARGS)