Refactor: prediction_server produces Frame objects
This commit is contained in:
parent
a3e42b4501
commit
44a618a5ee
4 changed files with 143 additions and 109 deletions
|
@ -11,15 +11,65 @@ from typing import Iterable, Optional
|
|||
import numpy as np
|
||||
import cv2
|
||||
import zmq
|
||||
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||
|
||||
logger = logging.getLogger('trap.frame_emitter')
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
track_id: str # deepsort track id association
|
||||
l: int # left - image space
|
||||
t: int # top - image space
|
||||
w: int # width - image space
|
||||
h: int # height - image space
|
||||
conf: float # object detector probablity
|
||||
|
||||
def get_foot_coords(self):
|
||||
return [self.l + 0.5 * self.w, self.t+self.h]
|
||||
|
||||
@classmethod
|
||||
def from_deepsort(cls, dstrack: DeepsortTrack):
|
||||
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
|
||||
|
||||
def to_ltwh(self):
|
||||
return (int(self.l), int(self.t), int(self.w), int(self.h))
|
||||
|
||||
def to_ltrb(self):
|
||||
return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Track:
|
||||
"""A bit of an haphazardous wrapper around the 'real' tracker to provide
|
||||
a history, with which the predictor can work, as we then can deduce velocity
|
||||
and acceleration.
|
||||
"""
|
||||
track_id: str = None
|
||||
history: [Detection] = field(default_factory=lambda: [])
|
||||
predictor_history: Optional[list] = None # in image space
|
||||
predictions: Optional[list] = None
|
||||
|
||||
def get_projected_history(self, H) -> np.array:
|
||||
foot_coordinates = [d.get_foot_coords() for d in self.history]
|
||||
|
||||
if len(foot_coordinates):
|
||||
coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
|
||||
return coords[0]
|
||||
return np.array([])
|
||||
|
||||
def get_projected_history_as_dict(self, H) -> dict:
|
||||
coords = self.get_projected_history(H)
|
||||
return [{"x":c[0], "y":c[1]} for c in coords]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
index: int
|
||||
img: np.array
|
||||
time: float= field(default_factory=lambda: time.time())
|
||||
trajectories: Optional[dict] = None
|
||||
tracks: Optional[dict[str, Track]] = None
|
||||
H: Optional[np.array] = None
|
||||
|
||||
class FrameEmitter:
|
||||
'''
|
||||
|
|
|
@ -27,6 +27,7 @@ import matplotlib.pyplot as plt
|
|||
import zmq
|
||||
|
||||
from trap.frame_emitter import Frame
|
||||
from trap.tracker import Track
|
||||
|
||||
logger = logging.getLogger("trap.prediction")
|
||||
|
||||
|
@ -249,23 +250,23 @@ class PredictionServer:
|
|||
|
||||
data = self.trajectory_socket.recv()
|
||||
frame: Frame = pickle.loads(data)
|
||||
trajectory_data = frame.trajectories # TODO: properly refractor
|
||||
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
|
||||
# trajectory_data = json.loads(data)
|
||||
logger.debug(f"Receive {trajectory_data}")
|
||||
logger.debug(f"Receive {frame.index}")
|
||||
|
||||
# class FakeNode:
|
||||
# def __init__(self, node_type: NodeType):
|
||||
# self.type = node_type
|
||||
|
||||
input_dict = {}
|
||||
for identifier, trajectory in trajectory_data.items():
|
||||
for identifier, track in frame.tracks.items():
|
||||
# if len(trajectory['history']) < 7:
|
||||
# # TODO: these trajectories should still be in the output, but without predictions
|
||||
# continue
|
||||
|
||||
# TODO: modify this into a mapping function between JS data an the expected Node format
|
||||
# node = FakeNode(online_env.NodeType.PEDESTRIAN)
|
||||
history = [[h['x'], h['y']] for h in trajectory['history']]
|
||||
history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
|
||||
history = np.array(history)
|
||||
x = history[:, 0]
|
||||
y = history[:, 1]
|
||||
|
@ -301,7 +302,7 @@ class PredictionServer:
|
|||
# And want to update the network
|
||||
|
||||
data = json.dumps({})
|
||||
self.prediction_socket.send_string(data)
|
||||
self.prediction_socket.send_pyobj(frame)
|
||||
|
||||
continue
|
||||
|
||||
|
@ -364,26 +365,29 @@ class PredictionServer:
|
|||
|
||||
for node in histories_dict:
|
||||
history = histories_dict[node]
|
||||
# future = futures_dict[node]
|
||||
# future = futures_dict[node] # ground truth dict
|
||||
predictions = prediction_dict[node]
|
||||
|
||||
if not len(history) or np.isnan(history[-1]).any():
|
||||
continue
|
||||
|
||||
response[node.id] = {
|
||||
'id': node.id,
|
||||
'det_conf': trajectory_data[node.id]['det_conf'],
|
||||
'bbox': trajectory_data[node.id]['bbox'],
|
||||
'history': history.tolist(),
|
||||
'predictions': predictions[0].tolist() # use batch 0
|
||||
}
|
||||
# response[node.id] = {
|
||||
# 'id': node.id,
|
||||
# 'det_conf': trajectory_data[node.id]['det_conf'],
|
||||
# 'bbox': trajectory_data[node.id]['bbox'],
|
||||
# 'history': history.tolist(),
|
||||
# 'predictions': predictions[0].tolist() # use batch 0
|
||||
# }
|
||||
|
||||
data = json.dumps(response)
|
||||
frame.tracks[node.id].predictor_history = history.tolist()
|
||||
frame.tracks[node.id].predictions = predictions[0].tolist() # use batch 0
|
||||
|
||||
# data = json.dumps(response)
|
||||
if self.config.predict_training_data:
|
||||
logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
|
||||
else:
|
||||
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
||||
self.prediction_socket.send_string(data)
|
||||
self.prediction_socket.send_pyobj(frame)
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ class Renderer:
|
|||
|
||||
|
||||
def run(self):
|
||||
predictions = {}
|
||||
prediction_frame = None
|
||||
i=0
|
||||
first_time = None
|
||||
while self.is_running.is_set():
|
||||
|
@ -91,7 +91,7 @@ class Renderer:
|
|||
|
||||
frame: Frame = self.frame_sock.recv_pyobj()
|
||||
try:
|
||||
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
|
||||
prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError as e:
|
||||
logger.debug(f'reuse prediction')
|
||||
|
||||
|
@ -106,48 +106,58 @@ class Renderer:
|
|||
# warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
|
||||
# cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
|
||||
|
||||
if not prediction_frame:
|
||||
cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
continue
|
||||
else:
|
||||
for track_id, track in prediction_frame.tracks.items():
|
||||
if not len(track.history):
|
||||
continue
|
||||
|
||||
for track_id, prediction in predictions.items():
|
||||
if not 'history' in prediction or not len(prediction['history']):
|
||||
continue
|
||||
# coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||
coords = [d.get_foot_coords() for d in track.history]
|
||||
# logger.warning(f"{coords=}")
|
||||
|
||||
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||
# logger.warning(f"{coords=}")
|
||||
for ci in range(1, len(coords)):
|
||||
start = [int(p) for p in coords[ci-1]]
|
||||
end = [int(p) for p in coords[ci]]
|
||||
cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
|
||||
|
||||
for ci in range(1, len(coords)):
|
||||
start = [int(p) for p in coords[ci-1]]
|
||||
end = [int(p) for p in coords[ci]]
|
||||
cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
|
||||
if not track.predictions or not len(track.predictions):
|
||||
continue
|
||||
|
||||
if not 'predictions' in prediction or not len(prediction['predictions']):
|
||||
continue
|
||||
for pred_i, pred in enumerate(track.predictions):
|
||||
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||
color = (0,0,255) if pred_i == 1 else (100,100,100)
|
||||
for ci in range(1, len(pred_coords)):
|
||||
start = [int(p) for p in pred_coords[ci-1]]
|
||||
end = [int(p) for p in pred_coords[ci]]
|
||||
cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
|
||||
|
||||
for pred_i, pred in enumerate(prediction['predictions']):
|
||||
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||
color = (0,0,255) if pred_i == 1 else (100,100,100)
|
||||
for ci in range(1, len(pred_coords)):
|
||||
start = [int(p) for p in pred_coords[ci-1]]
|
||||
end = [int(p) for p in pred_coords[ci]]
|
||||
cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
|
||||
|
||||
for track_id, prediction in predictions.items():
|
||||
# draw tracker marker and track id last so it lies over the trajectories
|
||||
# this goes is a second loop so it overlays over _all_ trajectories
|
||||
coords = cv2.perspectiveTransform(np.array([[prediction['history'][-1]]]), self.inv_H)[0]
|
||||
|
||||
center = [int(p) for p in coords[-1]]
|
||||
cv2.circle(img, center, 5, (0,255,0))
|
||||
p1 = (prediction['bbox'][0], prediction['bbox'][1])
|
||||
p2 = (p1[0] + prediction['bbox'][2], p1[1] + prediction['bbox'][3])
|
||||
cv2.rectangle(img, p1, p2, (255,0,0), 1)
|
||||
cv2.putText(img, f"{track_id} ({(prediction['det_conf'] or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
|
||||
for track_id, track in prediction_frame.tracks.items():
|
||||
# draw tracker marker and track id last so it lies over the trajectories
|
||||
# this goes is a second loop so it overlays over _all_ trajectories
|
||||
# coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
|
||||
coords = track.history[-1].get_foot_coords()
|
||||
|
||||
center = [int(p) for p in coords]
|
||||
cv2.circle(img, center, 5, (0,255,0))
|
||||
(l, t, r, b) = track.history[-1].to_ltrb()
|
||||
p1 = (l, t)
|
||||
p2 = (r, b)
|
||||
cv2.rectangle(img, p1, p2, (255,0,0), 1)
|
||||
cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
|
||||
|
||||
if first_time is None:
|
||||
first_time = frame.time
|
||||
|
||||
cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
cv2.putText(img, f"{frame.time - first_time:.3f}s", (100,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
|
||||
if prediction_frame:
|
||||
cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
|
||||
|
||||
|
||||
|
||||
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from multiprocessing import Event
|
|||
from pathlib import Path
|
||||
import pickle
|
||||
import time
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
@ -21,7 +22,7 @@ from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
|||
from ultralytics import YOLO
|
||||
from ultralytics.engine.results import Results as YOLOResult
|
||||
|
||||
from trap.frame_emitter import Frame
|
||||
from trap.frame_emitter import Frame, Detection, Track
|
||||
|
||||
# Detection = [int, int, int, int, float, int]
|
||||
# Detections = [Detection]
|
||||
|
@ -41,41 +42,6 @@ DETECTOR_YOLOv8 = 'ultralytics'
|
|||
|
||||
DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
track_id: str
|
||||
l: int # left
|
||||
t: int # top
|
||||
w: int # width
|
||||
h: int # height
|
||||
conf: float #probablity
|
||||
|
||||
def get_foot_coords(self):
|
||||
return [self.l + 0.5 * self.w, self.t+self.h]
|
||||
|
||||
@classmethod
|
||||
def from_deepsort(cls, dstrack: DeepsortTrack):
|
||||
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
|
||||
|
||||
def to_ltwh(self):
|
||||
return (int(self.l), int(self.t), int(self.w), int(self.h))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Track:
|
||||
"""A bit of an haphazardous wrapper around the 'real' tracker to provide
|
||||
a history, with which the predictor can work, as we then can deduce velocity
|
||||
and acceleration.
|
||||
"""
|
||||
track_id: str = None
|
||||
history: [Detection]= field(default_factory=lambda: [])
|
||||
|
||||
def get_projected_history(self, H):
|
||||
foot_coordinates = [d.get_foot_coords() for d in self.history]
|
||||
|
||||
if len(foot_coordinates):
|
||||
return cv2.perspectiveTransform(np.array([foot_coordinates]),H)
|
||||
return np.array([])
|
||||
|
||||
|
||||
|
||||
|
@ -168,14 +134,15 @@ class Tracker:
|
|||
# prev_run_time = time.time()
|
||||
|
||||
start_time = time.time()
|
||||
msg = self.frame_sock.recv()
|
||||
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
|
||||
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
|
||||
|
||||
if frame.index > (prev_frame_i+1):
|
||||
logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
|
||||
|
||||
|
||||
prev_frame_i = frame.index
|
||||
# load homography into frame (TODO: should this be done in emitter?)
|
||||
frame.H = self.H
|
||||
|
||||
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||
|
||||
|
@ -200,27 +167,30 @@ class Tracker:
|
|||
# track.history.pop(0)
|
||||
|
||||
|
||||
trajectories = {}
|
||||
for detection in detections:
|
||||
tid = str(detection.track_id)
|
||||
track = self.tracks[detection.track_id]
|
||||
coords = track.get_projected_history(self.H) # get full history
|
||||
trajectories[tid] = {
|
||||
"id": tid,
|
||||
"det_conf": detection.conf,
|
||||
"bbox": detection.to_ltwh(),
|
||||
"history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
|
||||
}
|
||||
# trajectories = {}
|
||||
# for detection in detections:
|
||||
# tid = str(detection.track_id)
|
||||
# track = self.tracks[detection.track_id]
|
||||
# coords = track.get_projected_history(self.H) # get full history
|
||||
# trajectories[tid] = {
|
||||
# "id": tid,
|
||||
# "det_conf": detection.conf,
|
||||
# "bbox": detection.to_ltwh(),
|
||||
# "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
|
||||
# }
|
||||
active_track_ids = [d.track_id for d in detections]
|
||||
active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
|
||||
# logger.info(f"{trajectories}")
|
||||
frame.trajectories = trajectories
|
||||
frame.tracks = active_tracks
|
||||
|
||||
if self.config.bypass_prediction:
|
||||
self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
else:
|
||||
self.trajectory_socket.send(pickle.dumps(frame))
|
||||
# if self.config.bypass_prediction:
|
||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
# else:
|
||||
# self.trajectory_socket.send(pickle.dumps(frame))
|
||||
self.trajectory_socket.send_pyobj(frame)
|
||||
|
||||
current_time = time.time()
|
||||
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||
logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||
|
||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
||||
|
@ -234,8 +204,8 @@ class Tracker:
|
|||
'track_id': t['id'],
|
||||
'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
|
||||
'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
|
||||
} for t in trajectories.values()])
|
||||
training_frames += len(trajectories)
|
||||
} for t in active_tracks.values()])
|
||||
training_frames += len(active_tracks)
|
||||
# print(time.time() - start_time)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue