diff --git a/trap/frame_emitter.py b/trap/frame_emitter.py index e0a50f1..bbcee05 100644 --- a/trap/frame_emitter.py +++ b/trap/frame_emitter.py @@ -1,13 +1,21 @@ from argparse import Namespace +from dataclasses import dataclass, field import logging import pickle import time - +from typing import Optional +import numpy as np import cv2 import zmq logger = logging.getLogger('trap.frame_emitter') +@dataclass +class Frame: + img: np.array + time: float= field(default_factory=lambda: time.time()) + trajectories: Optional[dict] = None + class FrameEmitter: ''' Emit frame in a separate threat so they can be throttled, @@ -18,8 +26,8 @@ class FrameEmitter: context = zmq.Context() self.frame_sock = context.socket(zmq.PUB) + self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind self.frame_sock.bind(config.zmq_frame_addr) - self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame logger.info(f"Connection socket {config.zmq_frame_addr}") def emit_video(self): @@ -29,14 +37,14 @@ class FrameEmitter: prev_time = time.time() while True: - ret, frame = video.read() + ret, img = video.read() # seek to 0 if video has finished. Infinite loop if not ret: video.set(cv2.CAP_PROP_POS_FRAMES, 0) - ret, frame = video.read() + ret, img = video.read() assert ret is not False # not really error proof... - + frame = Frame(img=img) # TODO: this is very dirty, need to find another way. # perhaps multiprocessing queue? self.frame_sock.send(pickle.dumps(frame)) diff --git a/trap/prediction_server.py b/trap/prediction_server.py index 10ac0b2..4e35c30 100644 --- a/trap/prediction_server.py +++ b/trap/prediction_server.py @@ -2,6 +2,7 @@ import logging from multiprocessing import Queue import os +import pickle import time import json import pandas as pd @@ -21,6 +22,8 @@ import matplotlib.pyplot as plt import zmq +from trap.frame_emitter import Frame + logger = logging.getLogger("trap.prediction") @@ -111,9 +114,9 @@ class InferenceServer: context = zmq.Context() self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB) - self.trajectory_socket.connect(config.zmq_trajectory_addr) self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'') - self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg + self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect! + self.trajectory_socket.connect(config.zmq_trajectory_addr) self.prediction_socket: zmq.Socket = context.socket(zmq.PUB) self.prediction_socket.bind(config.zmq_prediction_addr) @@ -220,8 +223,10 @@ class InferenceServer: # node_data = pd.DataFrame(data_dict, columns=data_columns) # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data) - data = self.trajectory_socket.recv_string() - trajectory_data = json.loads(data) + data = self.trajectory_socket.recv() + frame: Frame = pickle.loads(data) + trajectory_data = frame.trajectories # TODO: properly refractor + # trajectory_data = json.loads(data) logger.debug(f"Receive {trajectory_data}") # class FakeNode: @@ -291,8 +296,8 @@ class InferenceServer: start = time.time() dists, preds = trajectron.incremental_forward(input_dict, maps, - prediction_horizon=16, # TODO: make variable - num_samples=3, # TODO: make variable + prediction_horizon=10, # TODO: make variable + num_samples=2, # TODO: make variable robot_present_and_future=robot_present_and_future, full_dist=True) end = time.time() @@ -343,6 +348,7 @@ class InferenceServer: } data = json.dumps(response) + logger.info(f"Frame delay = {time.time()-frame.time}s") self.prediction_socket.send_string(data) diff --git a/trap/tracker.py b/trap/tracker.py index 3f94a77..a05eaa1 100644 --- a/trap/tracker.py +++ b/trap/tracker.py @@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re from deep_sort_realtime.deepsort_tracker import DeepSort from deep_sort_realtime.deep_sort.track import Track +from trap.frame_emitter import Frame + Detection = [int, int, int, int, float, int] Detections = [Detection] @@ -23,13 +25,13 @@ class Tracker: context = zmq.Context() self.frame_sock = context.socket(zmq.SUB) - self.frame_sock.connect(config.zmq_frame_addr) + self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'') - self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame + self.frame_sock.connect(config.zmq_frame_addr) self.trajectory_socket = context.socket(zmq.PUB) - self.trajectory_socket.bind(config.zmq_trajectory_addr) self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame + self.trajectory_socket.bind(config.zmq_trajectory_addr) # # TODO: config device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -53,9 +55,19 @@ class Tracker: def track(self): while True: - frame = pickle.loads(self.frame_sock.recv()) - detections = self.detect_persons(frame) - tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame) + msg = self.frame_sock.recv() + # after block, exhaust the queue: (superfluous now that CONFLATE is before the connect) + # i = 1 + # while True: + # try: + # msg = self.frame_sock.recv(zmq.NOBLOCK) + # i+=1 + # except Exception as e: + # break + frame: Frame = pickle.loads(msg) + logger.info(f"Frame delivery delay = {time.time()-frame.time}s") + detections = self.detect_persons(frame.img) + tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img) TEMP_boxes = [t.to_ltwh() for t in tracks] TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes]) @@ -73,7 +85,11 @@ class Tracker: "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test } logger.debug(f"{trajectories}") - self.trajectory_socket.send_string(json.dumps(trajectories)) + frame.trajectories = trajectories + + logger.info(f"trajectory delay = {time.time()-frame.time}s") + self.trajectory_socket.send(pickle.dumps(frame)) + # self.trajectory_socket.send_string(json.dumps(trajectories)) # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}} # TODO: provide a track object that actually keeps history (unlike tracker)