trap/trap/face_detector.py
2025-04-03 20:59:40 +02:00

170 lines
5.6 KiB
Python

from argparse import Namespace
from collections import defaultdict
import csv
from dataclasses import dataclass, field
import json
import logging
from math import nan
from multiprocessing import Event
import multiprocessing
from pathlib import Path
import pickle
import time
from typing import DefaultDict, Dict, Optional, List
import jsonlines
import numpy as np
import torch
import torchvision
import ultralytics
import zmq
import cv2
from facenet_pytorch import InceptionResnetV1, MTCNN
from trap.base import Frame
# Module-level logger; the dotted name makes it a child of the 'trap' logger.
logger = logging.getLogger('trap.face_detector')
class FaceDetector:
    """Detect faces in frames received over ZeroMQ and publish the detections.

    Frames are read from a SUB socket connected to ``config.zmq_frame_addr``;
    detection results are published on a PUB socket bound to
    ``config.zmq_face_addr``. Both sockets set ``zmq.CONFLATE`` so only the
    most recent message is kept.
    """

    def __init__(self, config: Namespace):
        """Create the ZeroMQ context and sockets.

        Args:
            config: parsed arguments; must provide ``zmq_frame_addr`` and
                ``zmq_face_addr``.
        """
        self.config = config
        self.context = zmq.Context()

        self.frame_sock = self.context.socket(zmq.SUB)
        # Only keep the latest frame. NB: CONFLATE must be set BEFORE
        # connect(), otherwise it is ignored.
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.frame_sock.connect(self.config.zmq_frame_addr)

        self.face_socket = self.context.socket(zmq.PUB)
        self.face_socket.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame
        self.face_socket.bind(self.config.zmq_face_addr)

        # TODO: make the device configurable.
        # NOTE(review): self.device is not used by the OpenCV YuNet detector in
        # track(); it looks like a leftover from facenet_pytorch experiments.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def track(self, is_running: Event, timer_counter: Optional["multiprocessing.sharedctypes.Synchronized"] = None):
        """Detect faces in incoming frames until ``is_running`` is cleared.

        Each frame is downscaled to half its width and height before
        detection, so the published face coordinates are relative to that
        half-resolution image.

        Args:
            is_running: event checked once per loop iteration; clearing it
                stops the loop (at the latest after the 2 s poll timeout).
            timer_counter: optional shared counter exposing ``get_lock()`` and
                ``.value`` (e.g. ``multiprocessing.Value``), incremented once
                per loop iteration. ``None`` disables counting.
        """
        self.is_running = is_running
        prev_frame_i = -1

        # YuNet face detector via OpenCV DNN on the CUDA backend.
        # Alternative model variant: "face_detection_yunet_2023mar_int8bq.onnx".
        modelpath = Path("face_detection_yunet_2023mar_int8.onnx")
        detector = cv2.FaceDetectorYN.create(
            str(modelpath),
            "",
            (320, 320),  # initial input size; reset for every frame below
            .3,  # confidence (score) threshold
            .3,  # NMS threshold
            5000,  # top-k
            cv2.dnn.DNN_BACKEND_CUDA,
            target_id=cv2.dnn.DNN_TARGET_CUDA,
        )

        while self.is_running.is_set():
            if timer_counter is not None:
                with timer_counter.get_lock():
                    timer_counter.value += 1

            poll_time = time.time()
            zmq_ev = self.frame_sock.poll(timeout=2000)
            if not zmq_ev:
                logger.warning('skip poll after 2000ms')
                # when there's no data after timeout, loop so that is_running is checked
                continue

            start_time = time.time()
            frame: Frame = self.frame_sock.recv_pyobj()  # frame delivery in current setup: 0.012-0.03s
            if frame.index > (prev_frame_i+1):
                logger.warning(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=}) -- poll time {start_time-poll_time:.5f}")

            # shape[:2] also works for single-channel (2-D) images.
            height, width = frame.img.shape[:2]
            half_size = (width // 2, height // 2)
            detector.setInputSize(half_size)
            img = cv2.resize(frame.img, half_size)
            # FaceDetectorYN.detect returns a (retval, faces) tuple; the whole
            # tuple is published unchanged.
            faces = detector.detect(img)
            prev_frame_i = frame.index

            # Only the detections are sent, not the image.
            self.face_socket.send_pyobj(faces)
        logger.info('Stopping')
def run_detector(config: Namespace, is_running: Event, timer_counter):
    """Process entry point: build a FaceDetector from *config* and run it.

    Blocks until ``is_running`` is cleared (see ``FaceDetector.track``).
    """
    FaceDetector(config).track(is_running, timer_counter)
def run():
    """Command-line entry point: run a standalone face detector.

    Reads frames from ``--zmq-frame-addr`` and publishes detections on
    ``--zmq-face-addr`` until interrupted (Ctrl-C).
    """
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--zmq-frame-addr',
                           help='Manually specify communication addr for the frame messages',
                           type=str,
                           default="ipc:///tmp/feeds_frame")
    # FaceDetector binds its PUB socket to config.zmq_face_addr, so this entry
    # point must provide it.
    # NOTE(review): default chosen to mirror ipc:///tmp/feeds_frame; confirm it
    # matches the address the face consumer connects to.
    argparser.add_argument('--zmq-face-addr',
                           help='Manually specify communication addr for the face messages',
                           type=str,
                           default="ipc:///tmp/feeds_faces")
    config = argparser.parse_args()

    is_running = multiprocessing.Event()
    is_running.set()
    # Shared counter providing get_lock()/.value, as FaceDetector.track expects.
    timer_counter = multiprocessing.Value('i', 0)

    try:
        FaceDetector(config).track(is_running, timer_counter)
    except KeyboardInterrupt:
        # Standalone mode has no other way to stop; treat Ctrl-C as a clean exit.
        logger.info('Interrupted')
    finally:
        is_running.clear()