# trap.frame_emitter — reads video sources and publishes tracked frames over ZMQ.
from argparse import Namespace
|
|
from dataclasses import dataclass, field
|
|
from enum import IntFlag
|
|
from itertools import cycle
|
|
import logging
|
|
from multiprocessing import Event
|
|
from pathlib import Path
|
|
import pickle
|
|
import sys
|
|
import time
|
|
from typing import Iterable, Optional
|
|
import numpy as np
|
|
import cv2
|
|
import zmq
|
|
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
|
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
|
|
|
|
logger = logging.getLogger('trap.frame_emitter')
|
|
|
|
class DetectionState(IntFlag):
    """Lifecycle state of a detection, mirroring the deep_sort track states."""
    Tentative = 1  # state before n_init (see DeepsortTrack)
    Confirmed = 2  # after tentative
    Lost = 4  # lost when DeepsortTrack.time_since_update > 0 but not Deleted

    @classmethod
    def from_deepsort_track(cls, track: DeepsortTrack):
        """Map a deepsort track's state onto a DetectionState.

        Deleted tracks must be filtered out by the caller beforehand.
        """
        state = track.state
        if state == DeepsortTrackState.Tentative:
            return cls.Tentative
        if state == DeepsortTrackState.Confirmed:
            # A confirmed track that missed its latest update counts as Lost.
            return cls.Lost if track.time_since_update > 0 else cls.Confirmed
        raise RuntimeError("Should not run into Deleted entries here")
|
|
|
|
|
|
@dataclass
class Detection:
    """A single tracked bounding box in image space."""
    track_id: str  # deepsort track id association
    l: int  # left - image space
    t: int  # top - image space
    w: int  # width - image space
    h: int  # height - image space
    conf: float  # object detector probablity
    state: DetectionState
    frame_nr: int  # index of the frame this detection was made on

    def get_foot_coords(self) -> list:
        """Bottom-centre of the bounding box — the point where the feet
        touch the ground, used for homography projection."""
        return [self.l + 0.5 * self.w, self.t + self.h]

    @classmethod
    def from_deepsort(cls, dstrack: DeepsortTrack, frame_nr: int = -1):
        """Build a Detection from a deepsort track.

        frame_nr defaults to -1 for callers that have no frame index at hand.
        BUG FIX: the original omitted frame_nr entirely, so the dataclass
        constructor (which requires it) raised TypeError.
        """
        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf,
                   DetectionState.from_deepsort_track(dstrack), frame_nr)

    def get_scaled(self, scale: float = 1):
        """Return a copy of this Detection with the box scaled by `scale`.

        Returns self unchanged when scale == 1.
        """
        if scale == 1:
            return self

        # BUG FIX: the original constructor call dropped frame_nr (a required
        # field), so any scale != 1 raised TypeError.
        return Detection(
            self.track_id,
            self.l * scale,
            self.t * scale,
            self.w * scale,
            self.h * scale,
            self.conf,
            self.state,
            self.frame_nr)

    def to_ltwh(self) -> tuple:
        """Integer (left, top, width, height)."""
        return (int(self.l), int(self.t), int(self.w), int(self.h))

    def to_ltrb(self) -> tuple:
        """Integer (left, top, right, bottom)."""
        return (int(self.l), int(self.t), int(self.l + self.w), int(self.t + self.h))
|
|
|
|
|
|
|
|
@dataclass
class Track:
    """A bit of an haphazardous wrapper around the 'real' tracker to provide
    a history, with which the predictor can work, as we then can deduce velocity
    and acceleration.
    """
    # FIX: annotations corrected — `str = None` needs Optional, and
    # `[Detection]` was a list literal, not a type expression.
    track_id: Optional[str] = None
    history: list = field(default_factory=list)  # list of Detection
    predictor_history: Optional[list] = None  # in image space
    predictions: Optional[list] = None

    def get_projected_history(self, H) -> np.ndarray:
        """Project the history's foot coordinates through homography H.

        Returns an (N, 2) array of projected points, or an empty array when
        there is no history yet.
        """
        foot_coordinates = [d.get_foot_coords() for d in self.history]

        if len(foot_coordinates):
            coords = cv2.perspectiveTransform(np.array([foot_coordinates]), H)
            return coords[0]
        return np.array([])

    def get_projected_history_as_dict(self, H) -> list:
        """Projected history as [{'x': ..., 'y': ...}, ...].

        FIX: the return annotation said `dict` but this has always returned
        a list of dicts; annotation corrected, behavior unchanged.
        """
        coords = self.get_projected_history(H)
        return [{"x": c[0], "y": c[1]} for c in coords]
|
|
|
|
|
|
|
|
@dataclass
class Frame:
    """A single video frame plus associated tracking state; this is the
    object pickled onto the ZMQ frame socket."""
    index: int  # frame number within the video stream
    img: np.ndarray  # image data as read by cv2
    time: float = field(default_factory=time.time)  # capture timestamp
    tracks: Optional[dict] = None  # track_id -> Track
    H: Optional[np.ndarray] = None  # homography to project image -> world space

    def aslist(self) -> dict:
        """Serialise tracks to plain types, keyed by track id.

        FIX: return annotation was `[dict]` (a list literal) although this
        has always returned a dict; annotation corrected. Kept the historical
        method name so callers keep working.
        """
        # Treat "no tracks yet" (tracks is Optional) as an empty mapping
        # instead of crashing on None.
        tracks = self.tracks or {}
        return {
            t.track_id: {
                'id': t.track_id,
                'history': t.get_projected_history(self.H).tolist(),
                'det_conf': t.history[-1].conf,
                'predictions': t.predictions,
            }
            for t in tracks.values()
        }
|
|
|
|
class FrameEmitter:
    '''
    Emit frame in a separate threat so they can be throttled,
    or thrown away when the rest of the system cannot keep up
    '''
    def __init__(self, config: Namespace, is_running: Event) -> None:
        # config must provide: zmq_frame_addr, video_src, video_loop, video_offset
        self.config = config
        # cleared externally to signal emit_video() to stop
        self.is_running = is_running

        context = zmq.Context()
        # TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859
        self.frame_sock = context.socket(zmq.PUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
        self.frame_sock.bind(config.zmq_frame_addr)

        # NOTE(review): message says "Connection" but this side binds — confirm intent
        logger.info(f"Connection socket {config.zmq_frame_addr}")

        # with video_loop, cycle() makes the source list repeat forever
        if self.config.video_loop:
            self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
        else:
            self.video_srcs: list[Path] = self.config.video_src


    def emit_video(self):
        """Read frames from each video source and publish them as pickled
        Frame objects on the ZMQ socket, paced to the source's fps.

        Stops when is_running is cleared or all (non-looping) sources end.
        """
        # frame counter; carried over across consecutive source files
        i = 0
        for video_path in self.video_srcs:
            logger.info(f"Play from '{str(video_path)}'")
            video = cv2.VideoCapture(str(video_path))
            fps = video.get(cv2.CAP_PROP_FPS)
            # NOTE(review): fps can be 0 for some sources, which would make
            # this a ZeroDivisionError — confirm inputs always report fps
            target_frame_duration = 1./fps
            logger.info(f"Emit frames at {fps} fps")

            if self.config.video_offset:
                logger.info(f"Start at frame {self.config.video_offset}")
                video.set(cv2.CAP_PROP_POS_FRAMES, self.config.video_offset)
                i = self.config.video_offset


            # Look for a sidecar homography file named "<stem>-homography.txt";
            # an existing "-<suffix>" part of the stem is replaced.
            if '-' in video_path.stem:
                path_stem = video_path.stem[:video_path.stem.rfind('-')]
            else:
                path_stem = video_path.stem
            path_stem += "-homography"
            homography_path = video_path.with_stem(path_stem).with_suffix('.txt')
            logger.info(f'check homography file {homography_path}')
            if homography_path.exists():
                logger.info(f'Found custom homography file! Using {homography_path}')
                video_H = np.loadtxt(homography_path, delimiter=',')
            else:
                video_H = None

            prev_time = time.time()

            while self.is_running.is_set():
                ret, img = video.read()

                # seek to 0 if video has finished. Infinite loop
                if not ret:
                    # now loading multiple files
                    break
                    # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    # ret, img = video.read()
                    # assert ret is not False # not really error proof...


                if "DATASETS/hof/" in str(video_path):
                    # hack to mask out area
                    cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)

                frame = Frame(index=i, img=img, H=video_H)
                # TODO: this is very dirty, need to find another way.
                # perhaps multiprocessing Array?
                self.frame_sock.send(pickle.dumps(frame))

                # defer next loop: sleep off whatever remains of this frame's
                # time budget, and account for the slept time in `now` so the
                # pacing doesn't drift
                now = time.time()
                time_diff = (now - prev_time)
                if time_diff < target_frame_duration:
                    time.sleep(target_frame_duration - time_diff)
                    now += target_frame_duration - time_diff

                prev_time = now

                i += 1

            if not self.is_running.is_set():
                # if not running, also break out of infinite generator loop
                break


        logger.info("Stopping")
|
|
|
|
|
|
|
|
def run_frame_emitter(config: Namespace, is_running: Event):
    """Entry point: build a FrameEmitter, pump frames until the sources are
    exhausted (or is_running is cleared), then signal shutdown."""
    emitter = FrameEmitter(config, is_running)
    emitter.emit_video()
    is_running.clear()