Test with alternative NN for prediction
This commit is contained in:
parent
ccddc71f83
commit
e3224aa47f
4 changed files with 2558 additions and 18 deletions
File diff suppressed because one or more lines are too long
2301
test_custom_rnn_lstm-attention.ipynb
Normal file
2301
test_custom_rnn_lstm-attention.ipynb
Normal file
File diff suppressed because one or more lines are too long
68
trap/counter.py
Normal file
68
trap/counter.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import collections
|
||||||
|
import logging
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
from typing import MutableSequence
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
logger = logging.getLogger('counter')
|
||||||
|
|
||||||
|
class CounterSender:
|
||||||
|
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||||
|
# self.name = name
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.sock = self.context.socket(zmq.PUB)
|
||||||
|
# self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||||
|
# self.sock.sndhwm = 1
|
||||||
|
self.sock.bind(address)
|
||||||
|
|
||||||
|
def set(self, name:str, value:float):
|
||||||
|
try:
|
||||||
|
self.sock.send_multipart([name.encode('utf8'), str(value).encode("utf8")], flags=zmq.NOBLOCK)
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
logger.warning(f"No space in que to count {name} as {value}")
|
||||||
|
|
||||||
|
class CounterLog():
|
||||||
|
def __init__(self, history = 20):
|
||||||
|
self.history: MutableSequence[(float, float)] = collections.deque(maxlen=history)
|
||||||
|
|
||||||
|
def add(self, value):
|
||||||
|
self.history.append((time.perf_counter(), value))
|
||||||
|
|
||||||
|
def value(self):
|
||||||
|
if not len(self.history):
|
||||||
|
return None
|
||||||
|
return self.history[-1][1]
|
||||||
|
|
||||||
|
def avg(self):
|
||||||
|
if not len(self.history):
|
||||||
|
return 0.
|
||||||
|
return statistics.fmean([h[1] for h in self.history])
|
||||||
|
|
||||||
|
class CounterListerner():
|
||||||
|
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.sock = self.context.socket(zmq.SUB)
|
||||||
|
self.sock.connect(address)
|
||||||
|
self.sock.subscribe( b'')
|
||||||
|
self.values: collections.defaultdict[str, CounterLog] = collections.defaultdict(lambda: CounterLog())
|
||||||
|
|
||||||
|
def snapshot(self):
|
||||||
|
messages = []
|
||||||
|
while self.sock.poll(0) == zmq.POLLIN:
|
||||||
|
name, value = self.sock.recv_multipart()
|
||||||
|
name, value = name.decode('utf8'),float(value.decode('utf8'))
|
||||||
|
self.values[name].add(value)
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest(self):
|
||||||
|
self.snapshot()
|
||||||
|
return self.values
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
strs = [f"{k}: {v.value()} ({v.avg()})" for (k,v) in self.values.items()]
|
||||||
|
return " ".join(strs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
170
trap/face_detector.py
Normal file
170
trap/face_detector.py
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
from argparse import Namespace
|
||||||
|
from collections import defaultdict
|
||||||
|
import csv
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from math import nan
|
||||||
|
from multiprocessing import Event
|
||||||
|
import multiprocessing
|
||||||
|
from pathlib import Path
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from typing import DefaultDict, Dict, Optional, List
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import ultralytics
|
||||||
|
import zmq
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from facenet_pytorch import InceptionResnetV1, MTCNN
|
||||||
|
|
||||||
|
from trap.base import Frame
|
||||||
|
|
||||||
|
logger = logging.getLogger('trap.face_detector')
|
||||||
|
|
||||||
|
class FaceDetector:
|
||||||
|
def __init__(self, config: Namespace):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.frame_sock = self.context.socket(zmq.SUB)
|
||||||
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.frame_sock.connect(self.config.zmq_frame_addr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.face_socket = self.context.socket(zmq.PUB)
|
||||||
|
self.face_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||||
|
self.face_socket.bind(self.config.zmq_face_addr)
|
||||||
|
|
||||||
|
|
||||||
|
# # TODO: config device
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def track(self, is_running: Event, timer_counter: int = 0):
|
||||||
|
"""
|
||||||
|
Live tracking of frames coming in over zmq
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
|
prev_frame_i = -1
|
||||||
|
|
||||||
|
# For a model pretrained on CASIA-Webface
|
||||||
|
# model = InceptionResnetV1(pretrained='casia-webface').eval().to(self.device)
|
||||||
|
# mtcnn = MTCNN(
|
||||||
|
# image_size=160, margin=0, min_face_size=10,
|
||||||
|
# thresholds=[0.3, 0.3, 0.3], factor=0.709, post_process=True,
|
||||||
|
# device=self.device, keep_all=True
|
||||||
|
|
||||||
|
# )
|
||||||
|
# modelpath = Path("face_detection_yunet_2023mar_int8bq.onnx")
|
||||||
|
modelpath = Path("face_detection_yunet_2023mar_int8.onnx")
|
||||||
|
# model = YuNet(modelPath=args.model,
|
||||||
|
# inputSize=[320, 320],
|
||||||
|
# confThreshold=args.conf_threshold,
|
||||||
|
# nmsThreshold=args.nms_threshold,
|
||||||
|
# topK=args.top_k,
|
||||||
|
# backendId=backend_id,
|
||||||
|
# targetId=target_id)
|
||||||
|
detector = cv2.FaceDetectorYN.create(
|
||||||
|
str(modelpath),
|
||||||
|
"",
|
||||||
|
(320, 320),
|
||||||
|
.3,
|
||||||
|
.3,
|
||||||
|
5000,
|
||||||
|
cv2.dnn.DNN_BACKEND_CUDA,
|
||||||
|
target_id=cv2.dnn.DNN_TARGET_CUDA
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
while self.is_running.is_set():
|
||||||
|
|
||||||
|
with timer_counter.get_lock():
|
||||||
|
timer_counter.value += 1
|
||||||
|
|
||||||
|
poll_time = time.time()
|
||||||
|
zmq_ev = self.frame_sock.poll(timeout=2000)
|
||||||
|
if not zmq_ev:
|
||||||
|
logger.warning('skip poll after 2000ms')
|
||||||
|
# when there's no data after timeout, loop so that is_running is checked
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
|
||||||
|
|
||||||
|
# print(time.time()- frame.time)
|
||||||
|
|
||||||
|
if frame.index > (prev_frame_i+1):
|
||||||
|
logger.warning(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=}) -- poll time {start_time-poll_time:.5f}")
|
||||||
|
|
||||||
|
height, width, channels = frame.img.shape
|
||||||
|
|
||||||
|
detector.setInputSize((width//2, height//2))
|
||||||
|
|
||||||
|
img = cv2.resize(frame.img, (width//2, height//2))
|
||||||
|
|
||||||
|
faces = detector.detect(img)
|
||||||
|
|
||||||
|
prev_frame_i = frame.index
|
||||||
|
|
||||||
|
# print(f"send to {self.trajectory_socket}, {self.config.zmq_trajectory_addr}")
|
||||||
|
self.face_socket.send_pyobj(faces) # ditch image for faster passthrough
|
||||||
|
|
||||||
|
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_detector(config: Namespace, is_running: Event, timer_counter):
|
||||||
|
router = FaceDetector(config)
|
||||||
|
router.track(is_running, timer_counter)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
# Frame emitter
|
||||||
|
import argparse
|
||||||
|
argparser = argparse.ArgumentParser()
|
||||||
|
argparser.add_argument('--zmq-frame-addr',
|
||||||
|
help='Manually specity communication addr for the frame messages',
|
||||||
|
type=str,
|
||||||
|
default="ipc:///tmp/feeds_frame")
|
||||||
|
argparser.add_argument('--zmq-trajectory-addr',
|
||||||
|
help='Manually specity communication addr for the trajectory messages',
|
||||||
|
type=str,
|
||||||
|
default="ipc:///tmp/feeds_traj")
|
||||||
|
|
||||||
|
argparser.add_argument("--save-for-training",
|
||||||
|
help="Specify the path in which to save",
|
||||||
|
type=Path,
|
||||||
|
default=None)
|
||||||
|
argparser.add_argument("--detector",
|
||||||
|
help="Specify the detector to use",
|
||||||
|
type=str,
|
||||||
|
default=DETECTOR_YOLOv8,
|
||||||
|
choices=DETECTORS)
|
||||||
|
argparser.add_argument("--tracker",
|
||||||
|
help="Specify the detector to use",
|
||||||
|
type=str,
|
||||||
|
default=TRACKER_BYTETRACK,
|
||||||
|
choices=TRACKERS)
|
||||||
|
argparser.add_argument("--smooth-tracks",
|
||||||
|
help="Smooth the tracker tracks before sending them to the predictor",
|
||||||
|
action='store_true')
|
||||||
|
config = argparser.parse_args()
|
||||||
|
is_running = multiprocessing.Event()
|
||||||
|
is_running.set()
|
||||||
|
timer_counter = timer.Timer('frame_emitter')
|
||||||
|
|
||||||
|
router = Tracker(config)
|
||||||
|
router.track(is_running, timer_counter.iterations)
|
||||||
|
is_running.clear()
|
||||||
|
|
Loading…
Reference in a new issue