170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
from argparse import Namespace
|
|
from collections import defaultdict
|
|
import csv
|
|
from dataclasses import dataclass, field
|
|
import json
|
|
import logging
|
|
from math import nan
|
|
from multiprocessing import Event
|
|
import multiprocessing
|
|
from pathlib import Path
|
|
import pickle
|
|
import time
|
|
from typing import DefaultDict, Dict, Optional, List
|
|
import jsonlines
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
import ultralytics
|
|
import zmq
|
|
import cv2
|
|
|
|
from facenet_pytorch import InceptionResnetV1, MTCNN
|
|
|
|
from trap.base import Frame
|
|
|
|
# Module-level logger, registered under the fixed name "trap.face_detector".
logger = logging.getLogger("trap.face_detector")
|
|
|
|
class FaceDetector:
    """Detect faces in frames received over ZeroMQ and publish the results.

    Frames are read from a SUB socket connected to ``config.zmq_frame_addr``;
    detection results are published on a PUB socket bound to
    ``config.zmq_face_addr``.
    """

    # YuNet ONNX model loaded by cv2.FaceDetectorYN. A relative path is
    # resolved against the process's current working directory.
    MODEL_PATH = Path("face_detection_yunet_2023mar_int8.onnx")

    def __init__(self, config: Namespace):
        """Create the ZeroMQ sockets.

        Args:
            config: parsed options; must provide ``zmq_frame_addr`` and
                ``zmq_face_addr``.
        """
        self.config = config

        self.context = zmq.Context()

        self.frame_sock = self.context.socket(zmq.SUB)
        # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.frame_sock.connect(self.config.zmq_frame_addr)

        self.face_socket = self.context.socket(zmq.PUB)
        self.face_socket.setsockopt(zmq.CONFLATE, 1)  # only keep latest result
        self.face_socket.bind(self.config.zmq_face_addr)

        # TODO: make the device configurable. Currently unused by track(),
        # which runs detection through OpenCV's DNN module instead of torch.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _create_detector(self):
        """Build the YuNet face detector on the CUDA DNN backend/target."""
        return cv2.FaceDetectorYN.create(
            str(self.MODEL_PATH),
            "",          # no separate config file
            (320, 320),  # initial input size; track() overrides it per frame
            .3,          # score threshold
            .3,          # NMS threshold
            5000,        # top_k
            cv2.dnn.DNN_BACKEND_CUDA,
            target_id=cv2.dnn.DNN_TARGET_CUDA,
        )

    def track(self, is_running: Event, timer_counter=None):
        """Live-detect faces in frames coming in over zmq.

        Loops until ``is_running`` is cleared. Each received frame is
        downscaled to half size, passed to the detector, and the raw
        ``detect()`` result is published on ``self.face_socket``.

        Args:
            is_running: event checked once per iteration; clearing it stops
                the loop (at the latest after the 2000 ms poll timeout).
            timer_counter: optional shared counter exposing ``get_lock()``
                and ``value`` (e.g. ``multiprocessing.Value``), incremented
                once per loop iteration. ``None`` disables counting.
        """
        self.is_running = is_running
        detector = self._create_detector()
        prev_frame_i = -1

        while self.is_running.is_set():
            if timer_counter is not None:
                with timer_counter.get_lock():
                    timer_counter.value += 1

            poll_time = time.time()
            if not self.frame_sock.poll(timeout=2000):
                logger.warning('skip poll after 2000ms')
                # when there's no data after timeout, loop so that is_running is checked
                continue

            start_time = time.time()
            frame: Frame = self.frame_sock.recv_pyobj()  # frame delivery in current setup: 0.012-0.03s

            if frame.index > (prev_frame_i + 1):
                logger.warning(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=}) -- poll time {start_time-poll_time:.5f}")

            height, width = frame.img.shape[:2]

            # Detect on a half-resolution copy of the frame.
            # NOTE(review): detections are therefore in the half-size image's
            # pixel coordinates, not the original frame's — confirm consumers
            # rescale accordingly.
            input_size = (width // 2, height // 2)
            detector.setInputSize(input_size)
            img = cv2.resize(frame.img, input_size)
            faces = detector.detect(img)

            prev_frame_i = frame.index

            self.face_socket.send_pyobj(faces)  # ditch image for faster passthrough

        logger.info('Stopping')
|
|
|
|
|
|
|
|
def run_detector(config: Namespace, is_running: Event, timer_counter):
    """Construct a FaceDetector from ``config`` and run its tracking loop.

    Returns once ``is_running`` has been cleared and the loop has exited.
    """
    FaceDetector(config).track(is_running, timer_counter)
|
|
|
|
def run():
    """Command-line entry point: run a FaceDetector in this process.

    Parses the ZeroMQ addresses required by ``FaceDetector`` and runs its
    ``track`` loop until it returns, clearing the running flag afterwards.
    """
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--zmq-frame-addr',
                           help='Manually specify communication addr for the frame messages',
                           type=str,
                           default="ipc:///tmp/feeds_frame")
    # NOTE(review): default chosen to match the other ipc:///tmp/feeds_* sockets;
    # confirm it matches the address the face consumer connects to.
    argparser.add_argument('--zmq-face-addr',
                           help='Manually specify communication addr for the face detection messages',
                           type=str,
                           default="ipc:///tmp/feeds_faces")
    config = argparser.parse_args()

    is_running = multiprocessing.Event()
    is_running.set()

    # Shared iteration counter: FaceDetector.track increments `.value`
    # under `get_lock()` once per loop iteration.
    timer_counter = multiprocessing.Value('i', 0)

    try:
        FaceDetector(config).track(is_running, timer_counter)
    finally:
        is_running.clear()
|
|
|