# Shared YOLO model loader for the PTZ tracker.
# Loads a single detector instance at import time, applies best-effort
# inference optimizations (fuse, channels-last, FP16, torch.compile),
# warms it up, and exposes yolo_forward() for batched inference.
from __future__ import annotations

# Standard library
import logging
import os
from typing import List

# Third-party
import numpy as np
import torch
from ultralytics import YOLO

# Local
from . import config

# Module-wide logger; named after the application's tracker component.
logger = logging.getLogger("PTZTracker")
# torch.amp.autocast only exists on sufficiently recent torch builds.
# Probe for it once and keep a sentinel so callers can feature-test
# (`autocast is not None`) instead of re-importing.
autocast = None  # type: ignore
try:
    from torch.amp import autocast  # type: ignore
except Exception:
    pass

# Pick the inference device once at import time.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Detector weights: override the baked-in default path via the YOLO_WEIGHTS env var.
WEIGHT_PATH = os.getenv("YOLO_WEIGHTS", "/home/electro-pribory/Desktop/weight/electro-pribory.pt")

# Single model instance shared by every consumer of this module (loaded once at import).
model_shared = YOLO(WEIGHT_PATH).to(device)

# Best-effort: fuse conv+bn layers for faster inference; skipped when unsupported.
try:
    model_shared.fuse()
except Exception:
    pass

# Best-effort: channels-last memory format can speed up convolution workloads.
try:
    model_shared.model.to(memory_format=torch.channels_last)  # type: ignore
except Exception:
    pass

# On GPU, try half precision; fall back to FP32 when .half() is not supported.
if device == "cuda":
    try:
        model_shared.model.half()  # type: ignore
        logger.info("Model set to FP16")
    except Exception:
        logger.info("FP16 not supported, using FP32")

# Optional torch.compile (torch >= 2.0); failure is logged, not fatal.
# NOTE(review): the original indentation was lost in extraction; this try-block
# is assumed to sit at module level (not nested inside the CUDA branch) — confirm.
try:
    if hasattr(torch, "compile"):
        model_shared.model = torch.compile(model_shared.model, mode="reduce-overhead")  # type: ignore
        logger.info("torch.compile enabled")
except Exception as e:
    logger.info("torch.compile not used: %s", e)

# Inference-only mode for the shared model.
model_shared.model.eval()
logger.info("Single shared model loaded")
# Warm-up pass: push one dummy frame through the model so CUDA context setup,
# kernel autotuning and (if enabled) torch.compile tracing happen at import
# time instead of on the first real frame. Any failure here is non-fatal.
try:
    _warm = np.zeros((config.INFER_IMGSZ, config.INFER_IMGSZ, 3), dtype=np.uint8)
    # NOTE(review): warm-up uses conf=0.1 while real calls use config.MODEL_CONF
    # (see INFER_ARGS below) — presumably intentional for warm-up only; confirm.
    _ = model_shared([_warm], imgsz=config.INFER_IMGSZ, conf=0.1, iou=0.45, verbose=False)
    if device == "cuda":
        torch.cuda.synchronize()  # ensure the async warm-up actually completed
    del _warm
except Exception:
    pass

# Keyword arguments shared by every real inference call in yolo_forward().
INFER_ARGS = dict(imgsz=config.INFER_IMGSZ, conf=config.MODEL_CONF, iou=0.45, verbose=False)
@torch.inference_mode()
def yolo_forward(frames: List[np.ndarray]):
    """Run the shared YOLO model on a batch of frames.

    Gradient tracking is disabled via ``torch.inference_mode``. On CUDA,
    when the torch.amp autocast API is available, the forward pass runs
    under FP16 autocast; otherwise a plain forward pass is used.
    Frames are HWC uint8 arrays (same layout as the warm-up frame).
    """
    amp_available = device == "cuda" and autocast is not None
    if not amp_available:
        return model_shared(frames, **INFER_ARGS)
    with autocast("cuda", dtype=torch.float16):  # type: ignore[misc]
        return model_shared(frames, **INFER_ARGS)