Test renderer to frames
This commit is contained in:
		
							parent
							
								
									3d34263a71
								
							
						
					
					
						commit
						821d06c9cf
					
				
					 4 changed files with 109 additions and 2 deletions
				
			
		| 
						 | 
					@ -1,6 +1,8 @@
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pyparsing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
parser = argparse.ArgumentParser()
 | 
					parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
 | 
				
			||||||
connection_parser = parser.add_argument_group('Connection')
 | 
					connection_parser = parser.add_argument_group('Connection')
 | 
				
			||||||
frame_emitter_parser = parser.add_argument_group('Frame emitter')
 | 
					frame_emitter_parser = parser.add_argument_group('Frame emitter')
 | 
				
			||||||
tracker_parser = parser.add_argument_group('Tracker')
 | 
					tracker_parser = parser.add_argument_group('Tracker')
 | 
				
			||||||
 | 
					render_parser = parser.add_argument_group('Renderer')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
inference_parser.add_argument("--model_dir",
 | 
					inference_parser.add_argument("--model_dir",
 | 
				
			||||||
                    help="directory with the model to use for inference",
 | 
					                    help="directory with the model to use for inference",
 | 
				
			||||||
| 
						 | 
					@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
inference_parser.add_argument("--output_dir",
 | 
					inference_parser.add_argument("--output_dir",
 | 
				
			||||||
                    help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
 | 
					                    help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
 | 
				
			||||||
                    type=str,
 | 
					                    type=Path,
 | 
				
			||||||
                    default='./OUT/test_inference')
 | 
					                    default='./OUT/test_inference')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
 | 
				
			||||||
                    help="File with homography params",
 | 
					                    help="File with homography params",
 | 
				
			||||||
                    type=Path,
 | 
					                    type=Path,
 | 
				
			||||||
                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
 | 
					                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Renderer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# render_parser.add_argument("--output-dir",
 | 
				
			||||||
 | 
					#                     help="Target image dir",
 | 
				
			||||||
 | 
					#                     type=Optional[Path],
 | 
				
			||||||
 | 
					#                     default=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,6 +5,7 @@ import sys
 | 
				
			||||||
from trap.config import parser
 | 
					from trap.config import parser
 | 
				
			||||||
from trap.frame_emitter import run_frame_emitter
 | 
					from trap.frame_emitter import run_frame_emitter
 | 
				
			||||||
from trap.prediction_server import run_prediction_server
 | 
					from trap.prediction_server import run_prediction_server
 | 
				
			||||||
 | 
					from trap.renderer import run_renderer
 | 
				
			||||||
from trap.socket_forwarder import run_ws_forwarder
 | 
					from trap.socket_forwarder import run_ws_forwarder
 | 
				
			||||||
from trap.tracker import run_tracker
 | 
					from trap.tracker import run_tracker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -52,6 +53,7 @@ def start():
 | 
				
			||||||
        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
 | 
					        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
 | 
				
			||||||
        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
 | 
					        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
 | 
				
			||||||
        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
 | 
					        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
 | 
				
			||||||
 | 
					        ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
    if not args.bypass_prediction:
 | 
					    if not args.bypass_prediction:
 | 
				
			||||||
        procs.append(
 | 
					        procs.append(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,8 +4,11 @@ import logging
 | 
				
			||||||
from multiprocessing import Event, Queue
 | 
					from multiprocessing import Event, Queue
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
import pandas as pd
 | 
					import pandas as pd
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import dill
 | 
					import dill
 | 
				
			||||||
| 
						 | 
					@ -301,7 +304,7 @@ class PredictionServer:
 | 
				
			||||||
            start = time.time()
 | 
					            start = time.time()
 | 
				
			||||||
            dists, preds = trajectron.incremental_forward(input_dict,
 | 
					            dists, preds = trajectron.incremental_forward(input_dict,
 | 
				
			||||||
                                                        maps,
 | 
					                                                        maps,
 | 
				
			||||||
                                                        prediction_horizon=10, # TODO: make variable
 | 
					                                                        prediction_horizon=20, # TODO: make variable
 | 
				
			||||||
                                                        num_samples=2, # TODO: make variable
 | 
					                                                        num_samples=2, # TODO: make variable
 | 
				
			||||||
                                                        robot_present_and_future=robot_present_and_future,
 | 
					                                                        robot_present_and_future=robot_present_and_future,
 | 
				
			||||||
                                                        full_dist=True)
 | 
					                                                        full_dist=True)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										91
									
								
								trap/renderer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								trap/renderer.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,91 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from argparse import Namespace
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					from multiprocessing import Event
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import zmq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from trap.frame_emitter import Frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					logger = logging.getLogger("trap.renderer")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Renderer:
 | 
				
			||||||
 | 
					    def __init__(self, config: Namespace, is_running: Event):
 | 
				
			||||||
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.is_running = is_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        context = zmq.Context()
 | 
				
			||||||
 | 
					        self.prediction_sock = context.socket(zmq.SUB)
 | 
				
			||||||
 | 
					        self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
				
			||||||
 | 
					        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
 | 
					        self.prediction_sock.connect(config.zmq_prediction_addr)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.frame_sock = context.socket(zmq.SUB)
 | 
				
			||||||
 | 
					        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
				
			||||||
 | 
					        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
 | 
					        self.frame_sock.connect(config.zmq_frame_addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					        H = np.loadtxt(self.config.homography, delimiter=',')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.inv_H = np.linalg.pinv(H)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.config.output_dir.exists():
 | 
				
			||||||
 | 
					            raise FileNotFoundError("Path does not exist")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def run(self):
 | 
				
			||||||
 | 
					        predictions = {}
 | 
				
			||||||
 | 
					        i=0
 | 
				
			||||||
 | 
					        first_time = None
 | 
				
			||||||
 | 
					        while self.is_running.is_set():
 | 
				
			||||||
 | 
					            i+=1
 | 
				
			||||||
 | 
					            frame: Frame = self.frame_sock.recv_pyobj()
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
 | 
				
			||||||
 | 
					            except zmq.ZMQError as e:
 | 
				
			||||||
 | 
					                logger.debug(f'reuse prediction')
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            img = frame.img
 | 
				
			||||||
 | 
					            for track_id, prediction in predictions.items():
 | 
				
			||||||
 | 
					                if not 'history' in prediction or not len(prediction['history']):
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
 | 
				
			||||||
 | 
					                # logger.warning(f"{coords=}")
 | 
				
			||||||
 | 
					                center = [int(p) for p in coords[-1]]
 | 
				
			||||||
 | 
					                cv2.circle(img, center, 5, (0,255,0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                for ci in range(1, len(coords)):
 | 
				
			||||||
 | 
					                    start = [int(p) for p in coords[ci-1]]
 | 
				
			||||||
 | 
					                    end = [int(p) for p in coords[ci]]
 | 
				
			||||||
 | 
					                    cv2.line(img, start, end, (255,255,255), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not 'predictions' in prediction or not len(prediction['predictions']):
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                for pred in prediction['predictions']:
 | 
				
			||||||
 | 
					                    pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
 | 
				
			||||||
 | 
					                    for ci in range(1, len(pred_coords)):
 | 
				
			||||||
 | 
					                        start = [int(p) for p in pred_coords[ci-1]]
 | 
				
			||||||
 | 
					                        end = [int(p) for p in pred_coords[ci]]
 | 
				
			||||||
 | 
					                        cv2.line(img, start, end, (0,0,255), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if first_time is None:
 | 
				
			||||||
 | 
					                first_time = frame.time
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            cv2.imwrite(str(img_path), img)
 | 
				
			||||||
 | 
					        logger.info('Stopping')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					def run_renderer(config: Namespace, is_running: Event):
 | 
				
			||||||
 | 
					    renderer = Renderer(config, is_running)
 | 
				
			||||||
 | 
					    renderer.run()
 | 
				
			||||||
		Loading…
	
		Reference in a new issue