Test with alternative NN for prediction

This commit is contained in:
Ruben van de Ven 2025-04-03 20:59:40 +02:00
parent ccddc71f83
commit e3224aa47f
4 changed files with 2558 additions and 18 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

68
trap/counter.py Normal file
View file

@ -0,0 +1,68 @@
import collections
import logging
import statistics
import time
from typing import MutableSequence
import zmq
logger = logging.getLogger('counter')
class CounterSender:
def __init__(self, address = "ipc:///tmp/trap-counters"):
# self.name = name
self.context = zmq.Context()
self.sock = self.context.socket(zmq.PUB)
# self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
# self.sock.sndhwm = 1
self.sock.bind(address)
def set(self, name:str, value:float):
try:
self.sock.send_multipart([name.encode('utf8'), str(value).encode("utf8")], flags=zmq.NOBLOCK)
except zmq.ZMQError as e:
logger.warning(f"No space in que to count {name} as {value}")
class CounterLog():
def __init__(self, history = 20):
self.history: MutableSequence[(float, float)] = collections.deque(maxlen=history)
def add(self, value):
self.history.append((time.perf_counter(), value))
def value(self):
if not len(self.history):
return None
return self.history[-1][1]
def avg(self):
if not len(self.history):
return 0.
return statistics.fmean([h[1] for h in self.history])
class CounterListerner():
def __init__(self, address = "ipc:///tmp/trap-counters"):
self.context = zmq.Context()
self.sock = self.context.socket(zmq.SUB)
self.sock.connect(address)
self.sock.subscribe( b'')
self.values: collections.defaultdict[str, CounterLog] = collections.defaultdict(lambda: CounterLog())
def snapshot(self):
messages = []
while self.sock.poll(0) == zmq.POLLIN:
name, value = self.sock.recv_multipart()
name, value = name.decode('utf8'),float(value.decode('utf8'))
self.values[name].add(value)
def get_latest(self):
self.snapshot()
return self.values
def to_string(self):
strs = [f"{k}: {v.value()} ({v.avg()})" for (k,v) in self.values.items()]
return " ".join(strs)

170
trap/face_detector.py Normal file
View file

@ -0,0 +1,170 @@
from argparse import Namespace
from collections import defaultdict
import csv
from dataclasses import dataclass, field
import json
import logging
from math import nan
from multiprocessing import Event
import multiprocessing
from pathlib import Path
import pickle
import time
from typing import DefaultDict, Dict, Optional, List
import jsonlines
import numpy as np
import torch
import torchvision
import ultralytics
import zmq
import cv2
from facenet_pytorch import InceptionResnetV1, MTCNN
from trap.base import Frame
logger = logging.getLogger('trap.face_detector')
class FaceDetector:
def __init__(self, config: Namespace):
self.config = config
self.context = zmq.Context()
self.frame_sock = self.context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(self.config.zmq_frame_addr)
self.face_socket = self.context.socket(zmq.PUB)
self.face_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.face_socket.bind(self.config.zmq_face_addr)
# # TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def track(self, is_running: Event, timer_counter: int = 0):
"""
Live tracking of frames coming in over zmq
"""
self.is_running = is_running
prev_frame_i = -1
# For a model pretrained on CASIA-Webface
# model = InceptionResnetV1(pretrained='casia-webface').eval().to(self.device)
# mtcnn = MTCNN(
# image_size=160, margin=0, min_face_size=10,
# thresholds=[0.3, 0.3, 0.3], factor=0.709, post_process=True,
# device=self.device, keep_all=True
# )
# modelpath = Path("face_detection_yunet_2023mar_int8bq.onnx")
modelpath = Path("face_detection_yunet_2023mar_int8.onnx")
# model = YuNet(modelPath=args.model,
# inputSize=[320, 320],
# confThreshold=args.conf_threshold,
# nmsThreshold=args.nms_threshold,
# topK=args.top_k,
# backendId=backend_id,
# targetId=target_id)
detector = cv2.FaceDetectorYN.create(
str(modelpath),
"",
(320, 320),
.3,
.3,
5000,
cv2.dnn.DNN_BACKEND_CUDA,
target_id=cv2.dnn.DNN_TARGET_CUDA
)
while self.is_running.is_set():
with timer_counter.get_lock():
timer_counter.value += 1
poll_time = time.time()
zmq_ev = self.frame_sock.poll(timeout=2000)
if not zmq_ev:
logger.warning('skip poll after 2000ms')
# when there's no data after timeout, loop so that is_running is checked
continue
start_time = time.time()
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
# print(time.time()- frame.time)
if frame.index > (prev_frame_i+1):
logger.warning(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=}) -- poll time {start_time-poll_time:.5f}")
height, width, channels = frame.img.shape
detector.setInputSize((width//2, height//2))
img = cv2.resize(frame.img, (width//2, height//2))
faces = detector.detect(img)
prev_frame_i = frame.index
# print(f"send to {self.trajectory_socket}, {self.config.zmq_trajectory_addr}")
self.face_socket.send_pyobj(faces) # ditch image for faster passthrough
logger.info('Stopping')
def run_detector(config: Namespace, is_running: Event, timer_counter):
router = FaceDetector(config)
router.track(is_running, timer_counter)
def run():
# Frame emitter
import argparse
argparser = argparse.ArgumentParser()
argparser.add_argument('--zmq-frame-addr',
help='Manually specity communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame")
argparser.add_argument('--zmq-trajectory-addr',
help='Manually specity communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds_traj")
argparser.add_argument("--save-for-training",
help="Specify the path in which to save",
type=Path,
default=None)
argparser.add_argument("--detector",
help="Specify the detector to use",
type=str,
default=DETECTOR_YOLOv8,
choices=DETECTORS)
argparser.add_argument("--tracker",
help="Specify the detector to use",
type=str,
default=TRACKER_BYTETRACK,
choices=TRACKERS)
argparser.add_argument("--smooth-tracks",
help="Smooth the tracker tracks before sending them to the predictor",
action='store_true')
config = argparser.parse_args()
is_running = multiprocessing.Event()
is_running.set()
timer_counter = timer.Timer('frame_emitter')
router = Tracker(config)
router.track(is_running, timer_counter.iterations)
is_running.clear()