Fix zmq-conflate and measure timings

This commit is contained in:
Ruben van de Ven 2023-10-13 23:13:12 +02:00
parent 23162da767
commit 27565d919e
3 changed files with 48 additions and 18 deletions

View file

@ -1,13 +1,21 @@
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field
import logging import logging
import pickle import pickle
import time import time
from typing import Optional
import numpy as np
import cv2 import cv2
import zmq import zmq
logger = logging.getLogger('trap.frame_emitter') logger = logging.getLogger('trap.frame_emitter')
@dataclass
class Frame:
img: np.array
time: float= field(default_factory=lambda: time.time())
trajectories: Optional[dict] = None
class FrameEmitter: class FrameEmitter:
''' '''
Emit frame in a separate threat so they can be throttled, Emit frame in a separate threat so they can be throttled,
@ -18,8 +26,8 @@ class FrameEmitter:
context = zmq.Context() context = zmq.Context()
self.frame_sock = context.socket(zmq.PUB) self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_sock.bind(config.zmq_frame_addr) self.frame_sock.bind(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
logger.info(f"Connection socket {config.zmq_frame_addr}") logger.info(f"Connection socket {config.zmq_frame_addr}")
def emit_video(self): def emit_video(self):
@ -29,14 +37,14 @@ class FrameEmitter:
prev_time = time.time() prev_time = time.time()
while True: while True:
ret, frame = video.read() ret, img = video.read()
# seek to 0 if video has finished. Infinite loop # seek to 0 if video has finished. Infinite loop
if not ret: if not ret:
video.set(cv2.CAP_PROP_POS_FRAMES, 0) video.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = video.read() ret, img = video.read()
assert ret is not False # not really error proof... assert ret is not False # not really error proof...
frame = Frame(img=img)
# TODO: this is very dirty, need to find another way. # TODO: this is very dirty, need to find another way.
# perhaps multiprocessing queue? # perhaps multiprocessing queue?
self.frame_sock.send(pickle.dumps(frame)) self.frame_sock.send(pickle.dumps(frame))

View file

@ -2,6 +2,7 @@
import logging import logging
from multiprocessing import Queue from multiprocessing import Queue
import os import os
import pickle
import time import time
import json import json
import pandas as pd import pandas as pd
@ -21,6 +22,8 @@ import matplotlib.pyplot as plt
import zmq import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.prediction") logger = logging.getLogger("trap.prediction")
@ -111,9 +114,9 @@ class InferenceServer:
context = zmq.Context() context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB) self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'') self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB) self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr) self.prediction_socket.bind(config.zmq_prediction_addr)
@ -220,8 +223,10 @@ class InferenceServer:
# node_data = pd.DataFrame(data_dict, columns=data_columns) # node_data = pd.DataFrame(data_dict, columns=data_columns)
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data) # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
data = self.trajectory_socket.recv_string() data = self.trajectory_socket.recv()
trajectory_data = json.loads(data) frame: Frame = pickle.loads(data)
trajectory_data = frame.trajectories # TODO: properly refractor
# trajectory_data = json.loads(data)
logger.debug(f"Receive {trajectory_data}") logger.debug(f"Receive {trajectory_data}")
# class FakeNode: # class FakeNode:
@ -291,8 +296,8 @@ class InferenceServer:
start = time.time() start = time.time()
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=16, # TODO: make variable prediction_horizon=10, # TODO: make variable
num_samples=3, # TODO: make variable num_samples=2, # TODO: make variable
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)
end = time.time() end = time.time()
@ -343,6 +348,7 @@ class InferenceServer:
} }
data = json.dumps(response) data = json.dumps(response)
logger.info(f"Frame delay = {time.time()-frame.time}s")
self.prediction_socket.send_string(data) self.prediction_socket.send_string(data)

View file

@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
from deep_sort_realtime.deepsort_tracker import DeepSort from deep_sort_realtime.deepsort_tracker import DeepSort
from deep_sort_realtime.deep_sort.track import Track from deep_sort_realtime.deep_sort.track import Track
from trap.frame_emitter import Frame
Detection = [int, int, int, int, float, int] Detection = [int, int, int, int, float, int]
Detections = [Detection] Detections = [Detection]
@ -23,13 +25,13 @@ class Tracker:
context = zmq.Context() context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB) self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.connect(config.zmq_frame_addr) self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'') self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame self.frame_sock.connect(config.zmq_frame_addr)
self.trajectory_socket = context.socket(zmq.PUB) self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.trajectory_socket.bind(config.zmq_trajectory_addr)
# # TODO: config device # # TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -53,9 +55,19 @@ class Tracker:
def track(self): def track(self):
while True: while True:
frame = pickle.loads(self.frame_sock.recv()) msg = self.frame_sock.recv()
detections = self.detect_persons(frame) # after block, exhaust the queue: (superfluous now that CONFLATE is before the connect)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame) # i = 1
# while True:
# try:
# msg = self.frame_sock.recv(zmq.NOBLOCK)
# i+=1
# except Exception as e:
# break
frame: Frame = pickle.loads(msg)
logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
detections = self.detect_persons(frame.img)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
TEMP_boxes = [t.to_ltwh() for t in tracks] TEMP_boxes = [t.to_ltwh() for t in tracks]
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes]) TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
@ -73,7 +85,11 @@ class Tracker:
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
} }
logger.debug(f"{trajectories}") logger.debug(f"{trajectories}")
self.trajectory_socket.send_string(json.dumps(trajectories)) frame.trajectories = trajectories
logger.info(f"trajectory delay = {time.time()-frame.time}s")
self.trajectory_socket.send(pickle.dumps(frame))
# self.trajectory_socket.send_string(json.dumps(trajectories))
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}} # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
# TODO: provide a track object that actually keeps history (unlike tracker) # TODO: provide a track object that actually keeps history (unlike tracker)