Test with alternative NN for prediction
This commit is contained in:
parent
ccddc71f83
commit
e3224aa47f
4 changed files with 2558 additions and 18 deletions
File diff suppressed because one or more lines are too long
2301
test_custom_rnn_lstm-attention.ipynb
Normal file
2301
test_custom_rnn_lstm-attention.ipynb
Normal file
File diff suppressed because one or more lines are too long
68
trap/counter.py
Normal file
68
trap/counter.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import collections
|
||||
import logging
|
||||
import statistics
|
||||
import time
|
||||
from typing import MutableSequence
|
||||
import zmq
|
||||
|
||||
logger = logging.getLogger('counter')
|
||||
|
||||
class CounterSender:
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||
# self.name = name
|
||||
self.context = zmq.Context()
|
||||
self.sock = self.context.socket(zmq.PUB)
|
||||
# self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
# self.sock.sndhwm = 1
|
||||
self.sock.bind(address)
|
||||
|
||||
def set(self, name:str, value:float):
|
||||
try:
|
||||
self.sock.send_multipart([name.encode('utf8'), str(value).encode("utf8")], flags=zmq.NOBLOCK)
|
||||
except zmq.ZMQError as e:
|
||||
logger.warning(f"No space in que to count {name} as {value}")
|
||||
|
||||
class CounterLog():
|
||||
def __init__(self, history = 20):
|
||||
self.history: MutableSequence[(float, float)] = collections.deque(maxlen=history)
|
||||
|
||||
def add(self, value):
|
||||
self.history.append((time.perf_counter(), value))
|
||||
|
||||
def value(self):
|
||||
if not len(self.history):
|
||||
return None
|
||||
return self.history[-1][1]
|
||||
|
||||
def avg(self):
|
||||
if not len(self.history):
|
||||
return 0.
|
||||
return statistics.fmean([h[1] for h in self.history])
|
||||
|
||||
class CounterListerner():
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||
self.context = zmq.Context()
|
||||
self.sock = self.context.socket(zmq.SUB)
|
||||
self.sock.connect(address)
|
||||
self.sock.subscribe( b'')
|
||||
self.values: collections.defaultdict[str, CounterLog] = collections.defaultdict(lambda: CounterLog())
|
||||
|
||||
def snapshot(self):
|
||||
messages = []
|
||||
while self.sock.poll(0) == zmq.POLLIN:
|
||||
name, value = self.sock.recv_multipart()
|
||||
name, value = name.decode('utf8'),float(value.decode('utf8'))
|
||||
self.values[name].add(value)
|
||||
|
||||
|
||||
def get_latest(self):
|
||||
self.snapshot()
|
||||
return self.values
|
||||
|
||||
def to_string(self):
|
||||
strs = [f"{k}: {v.value()} ({v.avg()})" for (k,v) in self.values.items()]
|
||||
return " ".join(strs)
|
||||
|
||||
|
||||
|
||||
|
170
trap/face_detector.py
Normal file
170
trap/face_detector.py
Normal file
|
@ -0,0 +1,170 @@
|
|||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
import csv
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
from math import nan
|
||||
from multiprocessing import Event
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import time
|
||||
from typing import DefaultDict, Dict, Optional, List
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import ultralytics
|
||||
import zmq
|
||||
import cv2
|
||||
|
||||
from facenet_pytorch import InceptionResnetV1, MTCNN
|
||||
|
||||
from trap.base import Frame
|
||||
|
||||
logger = logging.getLogger('trap.face_detector')
|
||||
|
||||
class FaceDetector:
|
||||
def __init__(self, config: Namespace):
|
||||
self.config = config
|
||||
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.frame_sock = self.context.socket(zmq.SUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.frame_sock.connect(self.config.zmq_frame_addr)
|
||||
|
||||
|
||||
|
||||
self.face_socket = self.context.socket(zmq.PUB)
|
||||
self.face_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
self.face_socket.bind(self.config.zmq_face_addr)
|
||||
|
||||
|
||||
# # TODO: config device
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
|
||||
def track(self, is_running: Event, timer_counter: int = 0):
|
||||
"""
|
||||
Live tracking of frames coming in over zmq
|
||||
"""
|
||||
|
||||
self.is_running = is_running
|
||||
|
||||
prev_frame_i = -1
|
||||
|
||||
# For a model pretrained on CASIA-Webface
|
||||
# model = InceptionResnetV1(pretrained='casia-webface').eval().to(self.device)
|
||||
# mtcnn = MTCNN(
|
||||
# image_size=160, margin=0, min_face_size=10,
|
||||
# thresholds=[0.3, 0.3, 0.3], factor=0.709, post_process=True,
|
||||
# device=self.device, keep_all=True
|
||||
|
||||
# )
|
||||
# modelpath = Path("face_detection_yunet_2023mar_int8bq.onnx")
|
||||
modelpath = Path("face_detection_yunet_2023mar_int8.onnx")
|
||||
# model = YuNet(modelPath=args.model,
|
||||
# inputSize=[320, 320],
|
||||
# confThreshold=args.conf_threshold,
|
||||
# nmsThreshold=args.nms_threshold,
|
||||
# topK=args.top_k,
|
||||
# backendId=backend_id,
|
||||
# targetId=target_id)
|
||||
detector = cv2.FaceDetectorYN.create(
|
||||
str(modelpath),
|
||||
"",
|
||||
(320, 320),
|
||||
.3,
|
||||
.3,
|
||||
5000,
|
||||
cv2.dnn.DNN_BACKEND_CUDA,
|
||||
target_id=cv2.dnn.DNN_TARGET_CUDA
|
||||
)
|
||||
|
||||
|
||||
|
||||
while self.is_running.is_set():
|
||||
|
||||
with timer_counter.get_lock():
|
||||
timer_counter.value += 1
|
||||
|
||||
poll_time = time.time()
|
||||
zmq_ev = self.frame_sock.poll(timeout=2000)
|
||||
if not zmq_ev:
|
||||
logger.warning('skip poll after 2000ms')
|
||||
# when there's no data after timeout, loop so that is_running is checked
|
||||
continue
|
||||
|
||||
start_time = time.time()
|
||||
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
|
||||
|
||||
# print(time.time()- frame.time)
|
||||
|
||||
if frame.index > (prev_frame_i+1):
|
||||
logger.warning(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=}) -- poll time {start_time-poll_time:.5f}")
|
||||
|
||||
height, width, channels = frame.img.shape
|
||||
|
||||
detector.setInputSize((width//2, height//2))
|
||||
|
||||
img = cv2.resize(frame.img, (width//2, height//2))
|
||||
|
||||
faces = detector.detect(img)
|
||||
|
||||
prev_frame_i = frame.index
|
||||
|
||||
# print(f"send to {self.trajectory_socket}, {self.config.zmq_trajectory_addr}")
|
||||
self.face_socket.send_pyobj(faces) # ditch image for faster passthrough
|
||||
|
||||
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
|
||||
def run_detector(config: Namespace, is_running: Event, timer_counter):
|
||||
router = FaceDetector(config)
|
||||
router.track(is_running, timer_counter)
|
||||
|
||||
def run():
|
||||
# Frame emitter
|
||||
import argparse
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--zmq-frame-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame")
|
||||
argparser.add_argument('--zmq-trajectory-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
|
||||
argparser.add_argument("--save-for-training",
|
||||
help="Specify the path in which to save",
|
||||
type=Path,
|
||||
default=None)
|
||||
argparser.add_argument("--detector",
|
||||
help="Specify the detector to use",
|
||||
type=str,
|
||||
default=DETECTOR_YOLOv8,
|
||||
choices=DETECTORS)
|
||||
argparser.add_argument("--tracker",
|
||||
help="Specify the detector to use",
|
||||
type=str,
|
||||
default=TRACKER_BYTETRACK,
|
||||
choices=TRACKERS)
|
||||
argparser.add_argument("--smooth-tracks",
|
||||
help="Smooth the tracker tracks before sending them to the predictor",
|
||||
action='store_true')
|
||||
config = argparser.parse_args()
|
||||
is_running = multiprocessing.Event()
|
||||
is_running.set()
|
||||
timer_counter = timer.Timer('frame_emitter')
|
||||
|
||||
router = Tracker(config)
|
||||
router.track(is_running, timer_counter.iterations)
|
||||
is_running.clear()
|
||||
|
Loading…
Reference in a new issue