Fix zmq-conflate and measure timings
parent 23162da767
commit 27565d919e
3 changed files with 48 additions and 18 deletions
@@ -1,13 +1,21 @@
 from argparse import Namespace
+from dataclasses import dataclass, field
 import logging
 import pickle
 import time
+from typing import Optional
+import numpy as np
 import cv2
 import zmq
 
 logger = logging.getLogger('trap.frame_emitter')
 
+@dataclass
+class Frame:
+    img: np.array
+    time: float= field(default_factory=lambda: time.time())
+    trajectories: Optional[dict] = None
 
 class FrameEmitter:
     '''
     Emit frame in a separate threat so they can be throttled,
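The time field above uses a default_factory, so every Frame is stamped at the moment it is constructed rather than once when the class is defined; the delay logs added further down in this commit rely on exactly that. A minimal sketch of the pattern (the Stamped class and names are illustrative, not part of the commit):

import time
from dataclasses import dataclass, field

@dataclass
class Stamped:
    # default_factory runs on every construction, so each instance records its
    # own creation time; a plain "time: float = time.time()" default would be
    # evaluated once at import and frozen.
    time: float = field(default_factory=lambda: time.time())

a = Stamped()
time.sleep(0.1)
b = Stamped()
assert b.time > a.time
print(f"delay since a = {time.time() - a.time:.3f}s")  # the same pattern as the new delay logs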
@@ -18,8 +26,8 @@ class FrameEmitter:
 
         context = zmq.Context()
         self.frame_sock = context.socket(zmq.PUB)
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
         self.frame_sock.bind(config.zmq_frame_addr)
-        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
         logger.info(f"Connection socket {config.zmq_frame_addr}")
 
     def emit_video(self):
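The commit's own comments name the pitfall this hunk fixes: per those comments, ZMQ_CONFLATE is silently ignored unless it is set before bind()/connect(), so the option now comes ahead of the bind here and ahead of the connect in the subscribers below. A minimal, self-contained sketch of the ordering on both ends (the address and payloads are illustrative; in the project they come from config):

import pickle
import time
import zmq

ADDR = "tcp://127.0.0.1:5560"            # illustrative address

context = zmq.Context()

pub = context.socket(zmq.PUB)
pub.setsockopt(zmq.CONFLATE, 1)          # option first, then bind (mirrors FrameEmitter)
pub.bind(ADDR)

sub = context.socket(zmq.SUB)
sub.setsockopt(zmq.SUBSCRIBE, b'')       # subscribe to everything
sub.setsockopt(zmq.CONFLATE, 1)          # keep only the newest message; must precede connect()
sub.connect(ADDR)

time.sleep(0.2)                          # crude wait for the PUB/SUB pair to join
for i in range(5):
    pub.send(pickle.dumps({"seq": i}))
time.sleep(0.2)
print(pickle.loads(sub.recv()))          # only the latest message is left in the receive queue

Setting the option after connect is not an error, it just has no effect, which is why the bug is easy to miss.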
@@ -29,14 +37,14 @@ class FrameEmitter:
 
         prev_time = time.time()
         while True:
-            ret, frame = video.read()
+            ret, img = video.read()
 
             # seek to 0 if video has finished. Infinite loop
            if not ret:
                 video.set(cv2.CAP_PROP_POS_FRAMES, 0)
-                ret, frame = video.read()
+                ret, img = video.read()
                 assert ret is not False # not really error proof...
+            frame = Frame(img=img)
             # TODO: this is very dirty, need to find another way.
             # perhaps multiprocessing queue?
             self.frame_sock.send(pickle.dumps(frame))
 
@@ -2,6 +2,7 @@
 import logging
 from multiprocessing import Queue
 import os
+import pickle
 import time
 import json
 import pandas as pd
@@ -21,6 +22,8 @@ import matplotlib.pyplot as plt
 
 import zmq
 
+from trap.frame_emitter import Frame
+
 logger = logging.getLogger("trap.prediction")
 
 
@@ -111,9 +114,9 @@ class InferenceServer:
 
         context = zmq.Context()
         self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
-        self.trajectory_socket.connect(config.zmq_trajectory_addr)
         self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
-        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
+        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
+        self.trajectory_socket.connect(config.zmq_trajectory_addr)
 
         self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
         self.prediction_socket.bind(config.zmq_prediction_addr)
@@ -220,8 +223,10 @@ class InferenceServer:
             # node_data = pd.DataFrame(data_dict, columns=data_columns)
             # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
 
-            data = self.trajectory_socket.recv_string()
-            trajectory_data = json.loads(data)
+            data = self.trajectory_socket.recv()
+            frame: Frame = pickle.loads(data)
+            trajectory_data = frame.trajectories # TODO: properly refractor
+            # trajectory_data = json.loads(data)
             logger.debug(f"Receive {trajectory_data}")
 
             # class FakeNode:
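With the tracker now publishing pickled Frame objects instead of JSON strings, the capture timestamp and image travel together with the trajectories, which is what makes the new delay measurements possible on this side. A rough sketch of the old versus new payloads (the Frame here is a trimmed, illustrative stand-in for trap.frame_emitter.Frame):

import json
import pickle
import time
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Frame:                      # trimmed stand-in, image field omitted for brevity
    time: float = field(default_factory=lambda: time.time())
    trajectories: Optional[dict] = None

trajectories = {"1": {"id": 1, "history": [{"x": 10.0, "y": 20.0}]}}

old_payload = json.dumps(trajectories)             # old wire format: trajectories only
trajectory_data = json.loads(old_payload)          # no capture time available here

new_payload = pickle.dumps(Frame(trajectories=trajectories))   # new wire format
frame = pickle.loads(new_payload)
trajectory_data = frame.trajectories               # what the inference server reads now
print(f"Frame delay = {time.time()-frame.time}s")  # the log line added further down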
@@ -291,8 +296,8 @@ class InferenceServer:
             start = time.time()
             dists, preds = trajectron.incremental_forward(input_dict,
                                maps,
-                               prediction_horizon=16, # TODO: make variable
-                               num_samples=3, # TODO: make variable
+                               prediction_horizon=10, # TODO: make variable
+                               num_samples=2, # TODO: make variable
                                robot_present_and_future=robot_present_and_future,
                                full_dist=True)
             end = time.time()
@@ -343,6 +348,7 @@ class InferenceServer:
             }
 
             data = json.dumps(response)
+            logger.info(f"Frame delay = {time.time()-frame.time}s")
             self.prediction_socket.send_string(data)
 
 
@@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
 from deep_sort_realtime.deepsort_tracker import DeepSort
 from deep_sort_realtime.deep_sort.track import Track
 
+from trap.frame_emitter import Frame
+
 Detection = [int, int, int, int, float, int]
 Detections = [Detection]
 
@@ -23,13 +25,13 @@ class Tracker:
 
         context = zmq.Context()
         self.frame_sock = context.socket(zmq.SUB)
-        self.frame_sock.connect(config.zmq_frame_addr)
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
         self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
-        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
+        self.frame_sock.connect(config.zmq_frame_addr)
 
         self.trajectory_socket = context.socket(zmq.PUB)
-        self.trajectory_socket.bind(config.zmq_trajectory_addr)
         self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
+        self.trajectory_socket.bind(config.zmq_trajectory_addr)
 
         # # TODO: config device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -53,9 +55,19 @@ class Tracker:
 
     def track(self):
         while True:
-            frame = pickle.loads(self.frame_sock.recv())
-            detections = self.detect_persons(frame)
-            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
+            msg = self.frame_sock.recv()
+            # after block, exhaust the queue: (superfluous now that CONFLATE is before the connect)
+            # i = 1
+            # while True:
+            #     try:
+            #         msg = self.frame_sock.recv(zmq.NOBLOCK)
+            #         i+=1
+            #     except Exception as e:
+            #         break
+            frame: Frame = pickle.loads(msg)
+            logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
+            detections = self.detect_persons(frame.img)
+            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
 
             TEMP_boxes = [t.to_ltwh() for t in tracks]
             TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
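The commented-out block kept in this hunk documents the alternative that the CONFLATE fix makes superfluous: block on one message, then drain whatever backlog is still queued with non-blocking reads. A sketch of that fallback as a standalone helper (recv_latest is a hypothetical name, not part of the codebase), in case conflation cannot be used:

import zmq

def recv_latest(sock: zmq.Socket) -> bytes:
    # Block for at least one message, then keep replacing it with anything
    # still queued; zmq.Again is raised once the receive queue is empty.
    msg = sock.recv()
    while True:
        try:
            msg = sock.recv(zmq.NOBLOCK)
        except zmq.Again:
            break
    return msg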
@@ -73,7 +85,11 @@ class Tracker:
                     "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
                 }
             logger.debug(f"{trajectories}")
-            self.trajectory_socket.send_string(json.dumps(trajectories))
+            frame.trajectories = trajectories
+
+            logger.info(f"trajectory delay = {time.time()-frame.time}s")
+            self.trajectory_socket.send(pickle.dumps(frame))
+            # self.trajectory_socket.send_string(json.dumps(trajectories))
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
             # TODO: provide a track object that actually keeps history (unlike tracker)
 