Fix zmq-conflate and measure timings
This commit is contained in:
parent
23162da767
commit
27565d919e
3 changed files with 48 additions and 18 deletions
|
@ -1,13 +1,21 @@
|
|||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import cv2
|
||||
import zmq
|
||||
|
||||
logger = logging.getLogger('trap.frame_emitter')
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
img: np.array
|
||||
time: float= field(default_factory=lambda: time.time())
|
||||
trajectories: Optional[dict] = None
|
||||
|
||||
class FrameEmitter:
|
||||
'''
|
||||
Emit frame in a separate threat so they can be throttled,
|
||||
|
@ -18,8 +26,8 @@ class FrameEmitter:
|
|||
|
||||
context = zmq.Context()
|
||||
self.frame_sock = context.socket(zmq.PUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
|
||||
self.frame_sock.bind(config.zmq_frame_addr)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
logger.info(f"Connection socket {config.zmq_frame_addr}")
|
||||
|
||||
def emit_video(self):
|
||||
|
@ -29,14 +37,14 @@ class FrameEmitter:
|
|||
|
||||
prev_time = time.time()
|
||||
while True:
|
||||
ret, frame = video.read()
|
||||
ret, img = video.read()
|
||||
|
||||
# seek to 0 if video has finished. Infinite loop
|
||||
if not ret:
|
||||
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
ret, frame = video.read()
|
||||
ret, img = video.read()
|
||||
assert ret is not False # not really error proof...
|
||||
|
||||
frame = Frame(img=img)
|
||||
# TODO: this is very dirty, need to find another way.
|
||||
# perhaps multiprocessing queue?
|
||||
self.frame_sock.send(pickle.dumps(frame))
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import logging
|
||||
from multiprocessing import Queue
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import json
|
||||
import pandas as pd
|
||||
|
@ -21,6 +22,8 @@ import matplotlib.pyplot as plt
|
|||
|
||||
import zmq
|
||||
|
||||
from trap.frame_emitter import Frame
|
||||
|
||||
logger = logging.getLogger("trap.prediction")
|
||||
|
||||
|
||||
|
@ -111,9 +114,9 @@ class InferenceServer:
|
|||
|
||||
context = zmq.Context()
|
||||
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
||||
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
||||
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
|
||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
|
||||
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
||||
|
||||
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
||||
self.prediction_socket.bind(config.zmq_prediction_addr)
|
||||
|
@ -220,8 +223,10 @@ class InferenceServer:
|
|||
# node_data = pd.DataFrame(data_dict, columns=data_columns)
|
||||
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
|
||||
|
||||
data = self.trajectory_socket.recv_string()
|
||||
trajectory_data = json.loads(data)
|
||||
data = self.trajectory_socket.recv()
|
||||
frame: Frame = pickle.loads(data)
|
||||
trajectory_data = frame.trajectories # TODO: properly refractor
|
||||
# trajectory_data = json.loads(data)
|
||||
logger.debug(f"Receive {trajectory_data}")
|
||||
|
||||
# class FakeNode:
|
||||
|
@ -291,8 +296,8 @@ class InferenceServer:
|
|||
start = time.time()
|
||||
dists, preds = trajectron.incremental_forward(input_dict,
|
||||
maps,
|
||||
prediction_horizon=16, # TODO: make variable
|
||||
num_samples=3, # TODO: make variable
|
||||
prediction_horizon=10, # TODO: make variable
|
||||
num_samples=2, # TODO: make variable
|
||||
robot_present_and_future=robot_present_and_future,
|
||||
full_dist=True)
|
||||
end = time.time()
|
||||
|
@ -343,6 +348,7 @@ class InferenceServer:
|
|||
}
|
||||
|
||||
data = json.dumps(response)
|
||||
logger.info(f"Frame delay = {time.time()-frame.time}s")
|
||||
self.prediction_socket.send_string(data)
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
|
|||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
from deep_sort_realtime.deep_sort.track import Track
|
||||
|
||||
from trap.frame_emitter import Frame
|
||||
|
||||
Detection = [int, int, int, int, float, int]
|
||||
Detections = [Detection]
|
||||
|
||||
|
@ -23,13 +25,13 @@ class Tracker:
|
|||
|
||||
context = zmq.Context()
|
||||
self.frame_sock = context.socket(zmq.SUB)
|
||||
self.frame_sock.connect(config.zmq_frame_addr)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
self.frame_sock.connect(config.zmq_frame_addr)
|
||||
|
||||
self.trajectory_socket = context.socket(zmq.PUB)
|
||||
self.trajectory_socket.bind(config.zmq_trajectory_addr)
|
||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
self.trajectory_socket.bind(config.zmq_trajectory_addr)
|
||||
|
||||
# # TODO: config device
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
@ -53,9 +55,19 @@ class Tracker:
|
|||
|
||||
def track(self):
|
||||
while True:
|
||||
frame = pickle.loads(self.frame_sock.recv())
|
||||
detections = self.detect_persons(frame)
|
||||
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
|
||||
msg = self.frame_sock.recv()
|
||||
# after block, exhaust the queue: (superfluous now that CONFLATE is before the connect)
|
||||
# i = 1
|
||||
# while True:
|
||||
# try:
|
||||
# msg = self.frame_sock.recv(zmq.NOBLOCK)
|
||||
# i+=1
|
||||
# except Exception as e:
|
||||
# break
|
||||
frame: Frame = pickle.loads(msg)
|
||||
logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||
detections = self.detect_persons(frame.img)
|
||||
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
|
||||
|
||||
TEMP_boxes = [t.to_ltwh() for t in tracks]
|
||||
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
|
||||
|
@ -73,7 +85,11 @@ class Tracker:
|
|||
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
|
||||
}
|
||||
logger.debug(f"{trajectories}")
|
||||
self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
frame.trajectories = trajectories
|
||||
|
||||
logger.info(f"trajectory delay = {time.time()-frame.time}s")
|
||||
self.trajectory_socket.send(pickle.dumps(frame))
|
||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
||||
# TODO: provide a track object that actually keeps history (unlike tracker)
|
||||
|
||||
|
|
Loading…
Reference in a new issue