Fix zmq-conflate and measure timings

This commit is contained in:
Ruben van de Ven 2023-10-13 23:13:12 +02:00
parent 23162da767
commit 27565d919e
3 changed files with 48 additions and 18 deletions

View File

@ -1,13 +1,21 @@
from argparse import Namespace
from dataclasses import dataclass, field
import logging
import pickle
import time
from typing import Optional
import numpy as np
import cv2
import zmq
logger = logging.getLogger('trap.frame_emitter')
@dataclass
class Frame:
img: np.array
time: float= field(default_factory=lambda: time.time())
trajectories: Optional[dict] = None
class FrameEmitter:
'''
Emit frame in a separate threat so they can be throttled,
@ -18,8 +26,8 @@ class FrameEmitter:
context = zmq.Context()
self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_sock.bind(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
logger.info(f"Connection socket {config.zmq_frame_addr}")
def emit_video(self):
@ -29,14 +37,14 @@ class FrameEmitter:
prev_time = time.time()
while True:
ret, frame = video.read()
ret, img = video.read()
# seek to 0 if video has finished. Infinite loop
if not ret:
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = video.read()
ret, img = video.read()
assert ret is not False # not really error proof...
frame = Frame(img=img)
# TODO: this is very dirty, need to find another way.
# perhaps multiprocessing queue?
self.frame_sock.send(pickle.dumps(frame))

View File

@ -2,6 +2,7 @@
import logging
from multiprocessing import Queue
import os
import pickle
import time
import json
import pandas as pd
@ -21,6 +22,8 @@ import matplotlib.pyplot as plt
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.prediction")
@ -111,9 +114,9 @@ class InferenceServer:
context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
@ -220,8 +223,10 @@ class InferenceServer:
# node_data = pd.DataFrame(data_dict, columns=data_columns)
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
data = self.trajectory_socket.recv_string()
trajectory_data = json.loads(data)
data = self.trajectory_socket.recv()
frame: Frame = pickle.loads(data)
trajectory_data = frame.trajectories # TODO: properly refractor
# trajectory_data = json.loads(data)
logger.debug(f"Receive {trajectory_data}")
# class FakeNode:
@ -291,8 +296,8 @@ class InferenceServer:
start = time.time()
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=16, # TODO: make variable
num_samples=3, # TODO: make variable
prediction_horizon=10, # TODO: make variable
num_samples=2, # TODO: make variable
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
@ -343,6 +348,7 @@ class InferenceServer:
}
data = json.dumps(response)
logger.info(f"Frame delay = {time.time()-frame.time}s")
self.prediction_socket.send_string(data)

View File

@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
from deep_sort_realtime.deepsort_tracker import DeepSort
from deep_sort_realtime.deep_sort.track import Track
from trap.frame_emitter import Frame
Detection = [int, int, int, int, float, int]
Detections = [Detection]
@ -23,13 +25,13 @@ class Tracker:
context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.connect(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.frame_sock.connect(config.zmq_frame_addr)
self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.trajectory_socket.bind(config.zmq_trajectory_addr)
# # TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -53,9 +55,19 @@ class Tracker:
def track(self):
while True:
frame = pickle.loads(self.frame_sock.recv())
detections = self.detect_persons(frame)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
msg = self.frame_sock.recv()
# after block, exhaust the queue: (superfluous now that CONFLATE is before the connect)
# i = 1
# while True:
# try:
# msg = self.frame_sock.recv(zmq.NOBLOCK)
# i+=1
# except Exception as e:
# break
frame: Frame = pickle.loads(msg)
logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
detections = self.detect_persons(frame.img)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
TEMP_boxes = [t.to_ltwh() for t in tracks]
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
@ -73,7 +85,11 @@ class Tracker:
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
}
logger.debug(f"{trajectories}")
self.trajectory_socket.send_string(json.dumps(trajectories))
frame.trajectories = trajectories
logger.info(f"trajectory delay = {time.time()-frame.time}s")
self.trajectory_socket.send(pickle.dumps(frame))
# self.trajectory_socket.send_string(json.dumps(trajectories))
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
# TODO: provide a track object that actually keeps history (unlike tracker)