# trap/trap/frame_emitter.py

from argparse import Namespace
from dataclasses import dataclass, field
from enum import IntFlag
from itertools import cycle
import logging
from multiprocessing import Event
from pathlib import Path
import pickle
import sys
import time
from typing import Iterable, Optional
import numpy as np
import cv2
import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
logger = logging.getLogger('trap.frame_emitter')


class DetectionState(IntFlag):
    Tentative = 1  # state before n_init (see DeepsortTrack)
    Confirmed = 2  # after tentative
    Lost = 4       # lost when DeepsortTrack.time_since_update > 0, but not yet Deleted

    @classmethod
    def from_deepsort_track(cls, track: DeepsortTrack):
        if track.state == DeepsortTrackState.Tentative:
            return cls.Tentative
        if track.state == DeepsortTrackState.Confirmed:
            if track.time_since_update > 0:
                return cls.Lost
            return cls.Confirmed
        raise RuntimeError("Should not run into Deleted entries here")
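

# A minimal usage sketch (illustrative, not used elsewhere in this module):
# because DetectionState is an IntFlag, states can be combined into a bitmask
# and tested against, e.g. when filtering detections for the predictor. The
# mask and helper below are hypothetical.
ACTIVE_DETECTION_STATES = DetectionState.Confirmed | DetectionState.Lost  # hypothetical mask


def is_active(state: DetectionState) -> bool:
    """Illustrative helper: True for Confirmed or Lost, False for Tentative."""
    return bool(state & ACTIVE_DETECTION_STATES)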


@dataclass
class Detection:
    track_id: str  # deepsort track id association
    l: int         # left - image space
    t: int         # top - image space
    w: int         # width - image space
    h: int         # height - image space
    conf: float    # object detector probability
    state: DetectionState
    frame_nr: int

    def get_foot_coords(self) -> list[float]:
        # bottom-centre of the bounding box: where the feet touch the ground
        return [self.l + 0.5 * self.w, self.t + self.h]

    @classmethod
    def from_deepsort(cls, dstrack: DeepsortTrack, frame_nr: int):
        # the deepsort track does not carry a frame number, so it is passed in
        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack), frame_nr)

    def get_scaled(self, scale: float = 1):
        if scale == 1:
            return self

        return Detection(
            self.track_id,
            self.l * scale,
            self.t * scale,
            self.w * scale,
            self.h * scale,
            self.conf,
            self.state,
            self.frame_nr)

    def to_ltwh(self) -> tuple[int, int, int, int]:
        return (int(self.l), int(self.t), int(self.w), int(self.h))

    def to_ltrb(self) -> tuple[int, int, int, int]:
        return (int(self.l), int(self.t), int(self.l + self.w), int(self.t + self.h))
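

# Illustrative debug helper (a sketch, not part of the pipeline): to_ltrb()
# yields the two-corner format that OpenCV drawing calls expect.
def draw_detection(img: np.ndarray, detection: Detection) -> None:
    """Draw a detection's bounding box onto `img` in place."""
    l, t, r, b = detection.to_ltrb()
    color = (0, 255, 0) if detection.state == DetectionState.Confirmed else (0, 0, 255)
    cv2.rectangle(img, (l, t), (r, b), color, 1)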


@dataclass
class Track:
    """A bit of a haphazard wrapper around the 'real' tracker to provide
    a history, with which the predictor can work, as we can then deduce
    velocity and acceleration.
    """
    track_id: Optional[str] = None
    history: list[Detection] = field(default_factory=list)
    predictor_history: Optional[list] = None  # in image space
    predictions: Optional[list] = None

    def get_projected_history(self, H) -> np.ndarray:
        foot_coordinates = [d.get_foot_coords() for d in self.history]

        if len(foot_coordinates):
            coords = cv2.perspectiveTransform(np.array([foot_coordinates]), H)
            return coords[0]
        return np.array([])

    def get_projected_history_as_dict(self, H) -> list[dict]:
        coords = self.get_projected_history(H)
        return [{"x": c[0], "y": c[1]} for c in coords]


@dataclass
class Frame:
    index: int
    img: np.ndarray
    time: float = field(default_factory=time.time)
    tracks: Optional[dict[str, Track]] = None
    H: Optional[np.ndarray] = None

    def aslist(self) -> dict:
        # note: despite its name, this returns a dict keyed by track id
        return {
            t.track_id: {
                'id': t.track_id,
                'history': t.get_projected_history(self.H).tolist(),
                'det_conf': t.history[-1].conf,
                'predictions': t.predictions,
            } for t in self.tracks.values()
        }
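

# Usage sketch: the dict returned by Frame.aslist() is JSON-serialisable once
# tracks are populated, e.g. (assuming `frame.tracks` is set and `frame.H` is
# a valid homography):
#
#   import json
#   payload = json.dumps(frame.aslist())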


class FrameEmitter:
    '''
    Emit frames in a separate thread so they can be throttled,
    or thrown away when the rest of the system cannot keep up.
    '''
    def __init__(self, config: Namespace, is_running: Event) -> None:
        self.config = config
        self.is_running = is_running

        context = zmq.Context()
        # TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859
        self.frame_sock = context.socket(zmq.PUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame; make sure to set BEFORE connect/bind
        self.frame_sock.bind(config.zmq_frame_addr)

        logger.info(f"Bound frame socket to {config.zmq_frame_addr}")

        if self.config.video_loop:
            self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
        else:
            self.video_srcs: list[Path] = self.config.video_src

    def emit_video(self):
        i = 0
        for video_path in self.video_srcs:
            logger.info(f"Play from '{str(video_path)}'")
            video = cv2.VideoCapture(str(video_path))
            fps = video.get(cv2.CAP_PROP_FPS)
            target_frame_duration = 1. / fps
            logger.info(f"Emit frames at {fps} fps")

            if self.config.video_offset:
                logger.info(f"Start at frame {self.config.video_offset}")
                video.set(cv2.CAP_PROP_POS_FRAMES, self.config.video_offset)
                i = self.config.video_offset

            # look for a sibling homography file, e.g. video-homography.txt next to video-01.mp4
            if '-' in video_path.stem:
                path_stem = video_path.stem[:video_path.stem.rfind('-')]
            else:
                path_stem = video_path.stem
            path_stem += "-homography"
            homography_path = video_path.with_stem(path_stem).with_suffix('.txt')
            logger.info(f'check homography file {homography_path}')
            if homography_path.exists():
                logger.info(f'Found custom homography file! Using {homography_path}')
                video_H = np.loadtxt(homography_path, delimiter=',')
            else:
                video_H = None

            prev_time = time.time()
            while self.is_running.is_set():
                ret, img = video.read()

                # video has finished: break to the next file (or, with
                # video_loop, to the next pass of the cycle)
                if not ret:
                    break

                if "DATASETS/hof/" in str(video_path):
                    # hack to mask out area
                    cv2.rectangle(img, (0, 0), (800, 200), (0, 0, 0), -1)

                frame = Frame(index=i, img=img, H=video_H)
                # TODO: this is very dirty, need to find another way.
                # perhaps multiprocessing Array?
                self.frame_sock.send(pickle.dumps(frame))

                # throttle to the source fps, compensating for the time this
                # iteration itself has taken
                now = time.time()
                time_diff = (now - prev_time)
                if time_diff < target_frame_duration:
                    time.sleep(target_frame_duration - time_diff)
                    now += target_frame_duration - time_diff
                prev_time = now

                i += 1

            if not self.is_running.is_set():
                # if not running, also break out of the (possibly infinite) source loop
                break

        logger.info("Stopping")


def run_frame_emitter(config: Namespace, is_running: Event):
    emitter = FrameEmitter(config, is_running)
    emitter.emit_video()
    is_running.clear()
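

# Launch sketch (illustrative): run_frame_emitter is meant to run in its own
# process, sharing the is_running Event with the rest of the system. The
# Namespace fields mirror the config attributes used above; the address and
# video path are hypothetical.
if __name__ == "__main__":
    from multiprocessing import Process

    is_running = Event()
    is_running.set()
    config = Namespace(
        zmq_frame_addr="ipc:///tmp/frames",  # hypothetical address
        video_src=[Path("test.mp4")],        # hypothetical source
        video_loop=False,
        video_offset=0,
    )
    p = Process(target=run_frame_emitter, args=(config, is_running))
    p.start()
    try:
        p.join()
    except KeyboardInterrupt:
        is_running.clear()
        p.join()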