Fix zmq-conflate and measure timings
This commit is contained in:
		
							parent
							
								
									23162da767
								
							
						
					
					
						commit
						27565d919e
					
				
					 3 changed files with 48 additions and 18 deletions
				
			
		| 
						 | 
					@ -1,13 +1,21 @@
 | 
				
			||||||
from argparse import Namespace
 | 
					from argparse import Namespace
 | 
				
			||||||
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
import cv2
 | 
					import cv2
 | 
				
			||||||
import zmq
 | 
					import zmq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger('trap.frame_emitter')
 | 
					logger = logging.getLogger('trap.frame_emitter')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class Frame:
 | 
				
			||||||
 | 
					    img: np.array
 | 
				
			||||||
 | 
					    time: float= field(default_factory=lambda: time.time())
 | 
				
			||||||
 | 
					    trajectories: Optional[dict] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FrameEmitter:
 | 
					class FrameEmitter:
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    Emit frame in a separate threat so they can be throttled,
 | 
					    Emit frame in a separate threat so they can be throttled,
 | 
				
			||||||
| 
						 | 
					@ -18,8 +26,8 @@ class FrameEmitter:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.frame_sock = context.socket(zmq.PUB)
 | 
					        self.frame_sock = context.socket(zmq.PUB)
 | 
				
			||||||
 | 
					        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
 | 
				
			||||||
        self.frame_sock.bind(config.zmq_frame_addr)
 | 
					        self.frame_sock.bind(config.zmq_frame_addr)
 | 
				
			||||||
        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
 | 
					 | 
				
			||||||
        logger.info(f"Connection socket {config.zmq_frame_addr}")
 | 
					        logger.info(f"Connection socket {config.zmq_frame_addr}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def emit_video(self):
 | 
					    def emit_video(self):
 | 
				
			||||||
| 
						 | 
					@ -29,14 +37,14 @@ class FrameEmitter:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        prev_time = time.time()
 | 
					        prev_time = time.time()
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            ret, frame = video.read()
 | 
					            ret, img = video.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # seek to 0 if video has finished. Infinite loop
 | 
					            # seek to 0 if video has finished. Infinite loop
 | 
				
			||||||
            if not ret:
 | 
					            if not ret:
 | 
				
			||||||
                video.set(cv2.CAP_PROP_POS_FRAMES, 0)
 | 
					                video.set(cv2.CAP_PROP_POS_FRAMES, 0)
 | 
				
			||||||
                ret, frame = video.read()
 | 
					                ret, img = video.read()
 | 
				
			||||||
                assert ret is not False # not really error proof...
 | 
					                assert ret is not False # not really error proof...
 | 
				
			||||||
            
 | 
					            frame = Frame(img=img)
 | 
				
			||||||
            # TODO: this is very dirty, need to find another way.
 | 
					            # TODO: this is very dirty, need to find another way.
 | 
				
			||||||
            # perhaps multiprocessing queue?
 | 
					            # perhaps multiprocessing queue?
 | 
				
			||||||
            self.frame_sock.send(pickle.dumps(frame))
 | 
					            self.frame_sock.send(pickle.dumps(frame))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,6 +2,7 @@
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from multiprocessing import Queue
 | 
					from multiprocessing import Queue
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					import pickle
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import pandas as pd
 | 
					import pandas as pd
 | 
				
			||||||
| 
						 | 
					@ -21,6 +22,8 @@ import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import zmq
 | 
					import zmq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from trap.frame_emitter import Frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger("trap.prediction")
 | 
					logger = logging.getLogger("trap.prediction")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -111,9 +114,9 @@ class InferenceServer:
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
 | 
					        self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
 | 
				
			||||||
        self.trajectory_socket.connect(config.zmq_trajectory_addr)
 | 
					 | 
				
			||||||
        self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
 | 
					        self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
 | 
					        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
 | 
				
			||||||
 | 
					        self.trajectory_socket.connect(config.zmq_trajectory_addr)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
 | 
					        self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
 | 
				
			||||||
        self.prediction_socket.bind(config.zmq_prediction_addr)
 | 
					        self.prediction_socket.bind(config.zmq_prediction_addr)
 | 
				
			||||||
| 
						 | 
					@ -220,8 +223,10 @@ class InferenceServer:
 | 
				
			||||||
            # node_data = pd.DataFrame(data_dict, columns=data_columns)
 | 
					            # node_data = pd.DataFrame(data_dict, columns=data_columns)
 | 
				
			||||||
            # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
 | 
					            # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
            data = self.trajectory_socket.recv_string()
 | 
					            data = self.trajectory_socket.recv()
 | 
				
			||||||
            trajectory_data = json.loads(data)
 | 
					            frame: Frame = pickle.loads(data)
 | 
				
			||||||
 | 
					            trajectory_data = frame.trajectories # TODO: properly refractor
 | 
				
			||||||
 | 
					            # trajectory_data = json.loads(data)
 | 
				
			||||||
            logger.debug(f"Receive {trajectory_data}")
 | 
					            logger.debug(f"Receive {trajectory_data}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # class FakeNode:
 | 
					            # class FakeNode:
 | 
				
			||||||
| 
						 | 
					@ -291,8 +296,8 @@ class InferenceServer:
 | 
				
			||||||
            start = time.time()
 | 
					            start = time.time()
 | 
				
			||||||
            dists, preds = trajectron.incremental_forward(input_dict,
 | 
					            dists, preds = trajectron.incremental_forward(input_dict,
 | 
				
			||||||
                                                        maps,
 | 
					                                                        maps,
 | 
				
			||||||
                                                        prediction_horizon=16, # TODO: make variable
 | 
					                                                        prediction_horizon=10, # TODO: make variable
 | 
				
			||||||
                                                        num_samples=3, # TODO: make variable
 | 
					                                                        num_samples=2, # TODO: make variable
 | 
				
			||||||
                                                        robot_present_and_future=robot_present_and_future,
 | 
					                                                        robot_present_and_future=robot_present_and_future,
 | 
				
			||||||
                                                        full_dist=True)
 | 
					                                                        full_dist=True)
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
| 
						 | 
					@ -343,6 +348,7 @@ class InferenceServer:
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            data = json.dumps(response)
 | 
					            data = json.dumps(response)
 | 
				
			||||||
 | 
					            logger.info(f"Frame delay = {time.time()-frame.time}s")
 | 
				
			||||||
            self.prediction_socket.send_string(data)
 | 
					            self.prediction_socket.send_string(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,6 +12,8 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
 | 
				
			||||||
from deep_sort_realtime.deepsort_tracker import DeepSort
 | 
					from deep_sort_realtime.deepsort_tracker import DeepSort
 | 
				
			||||||
from deep_sort_realtime.deep_sort.track import Track
 | 
					from deep_sort_realtime.deep_sort.track import Track
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from trap.frame_emitter import Frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Detection = [int, int, int, int, float, int]
 | 
					Detection = [int, int, int, int, float, int]
 | 
				
			||||||
Detections = [Detection]
 | 
					Detections = [Detection]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -23,13 +25,13 @@ class Tracker:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.frame_sock = context.socket(zmq.SUB)
 | 
					        self.frame_sock = context.socket(zmq.SUB)
 | 
				
			||||||
        self.frame_sock.connect(config.zmq_frame_addr)
 | 
					        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
				
			||||||
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
					        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
 | 
					        self.frame_sock.connect(config.zmq_frame_addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.trajectory_socket = context.socket(zmq.PUB)
 | 
					        self.trajectory_socket = context.socket(zmq.PUB)
 | 
				
			||||||
        self.trajectory_socket.bind(config.zmq_trajectory_addr)
 | 
					 | 
				
			||||||
        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
 | 
					        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
 | 
				
			||||||
 | 
					        self.trajectory_socket.bind(config.zmq_trajectory_addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # # TODO: config device
 | 
					        # # TODO: config device
 | 
				
			||||||
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
					        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
| 
						 | 
					@ -53,9 +55,19 @@ class Tracker:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def track(self):
 | 
					    def track(self):
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            frame = pickle.loads(self.frame_sock.recv())
 | 
					            msg = self.frame_sock.recv()
 | 
				
			||||||
            detections = self.detect_persons(frame)
 | 
					            # after block, exhaust the queue: (superfluous now that CONFLATE is before the connect)
 | 
				
			||||||
            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
 | 
					            # i = 1
 | 
				
			||||||
 | 
					            # while True:
 | 
				
			||||||
 | 
					            #     try:
 | 
				
			||||||
 | 
					            #         msg = self.frame_sock.recv(zmq.NOBLOCK)
 | 
				
			||||||
 | 
					            #         i+=1
 | 
				
			||||||
 | 
					            #     except Exception as e:
 | 
				
			||||||
 | 
					            #         break
 | 
				
			||||||
 | 
					            frame: Frame = pickle.loads(msg)
 | 
				
			||||||
 | 
					            logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
 | 
				
			||||||
 | 
					            detections = self.detect_persons(frame.img)
 | 
				
			||||||
 | 
					            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            TEMP_boxes = [t.to_ltwh() for t in tracks]
 | 
					            TEMP_boxes = [t.to_ltwh() for t in tracks]
 | 
				
			||||||
            TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
 | 
					            TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
 | 
				
			||||||
| 
						 | 
					@ -73,7 +85,11 @@ class Tracker:
 | 
				
			||||||
                     "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
 | 
					                     "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
 | 
				
			||||||
                 }
 | 
					                 }
 | 
				
			||||||
            logger.debug(f"{trajectories}")
 | 
					            logger.debug(f"{trajectories}")
 | 
				
			||||||
            self.trajectory_socket.send_string(json.dumps(trajectories))
 | 
					            frame.trajectories = trajectories
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            logger.info(f"trajectory delay = {time.time()-frame.time}s")
 | 
				
			||||||
 | 
					            self.trajectory_socket.send(pickle.dumps(frame))
 | 
				
			||||||
 | 
					            # self.trajectory_socket.send_string(json.dumps(trajectories))
 | 
				
			||||||
            # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
 | 
					            # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
 | 
				
			||||||
            # TODO: provide a track object that actually keeps history (unlike tracker)
 | 
					            # TODO: provide a track object that actually keeps history (unlike tracker)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue