switch to not predict video, but training data
This commit is contained in:
parent
f3ac903555
commit
2171dd459a
1 changed files with 69 additions and 60 deletions
|
@ -214,13 +214,13 @@ class PredictionServer:
|
||||||
prev_run_time = 0
|
prev_run_time = 0
|
||||||
while self.is_running.is_set():
|
while self.is_running.is_set():
|
||||||
timestep += 1
|
timestep += 1
|
||||||
this_run_time = time.time()
|
|
||||||
logger.debug(f'test {prev_run_time - this_run_time}')
|
|
||||||
time.sleep(max(0, prev_run_time - this_run_time + .5))
|
|
||||||
prev_run_time = time.time()
|
|
||||||
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
|
||||||
|
|
||||||
# input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
# this_run_time = time.time()
|
||||||
|
# logger.debug(f'test {prev_run_time - this_run_time}')
|
||||||
|
# time.sleep(max(0, prev_run_time - this_run_time + .5))
|
||||||
|
# prev_run_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
# TODO: see process_data.py on how to create a node, the provide nodes + incoming data columns
|
# TODO: see process_data.py on how to create a node, the provide nodes + incoming data columns
|
||||||
# data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
|
# data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
|
||||||
# x = node_values[:, 0]
|
# x = node_values[:, 0]
|
||||||
|
@ -239,62 +239,66 @@ class PredictionServer:
|
||||||
|
|
||||||
# node_data = pd.DataFrame(data_dict, columns=data_columns)
|
# node_data = pd.DataFrame(data_dict, columns=data_columns)
|
||||||
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
|
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
|
||||||
|
if self.config.predict_training_data:
|
||||||
|
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
||||||
|
else:
|
||||||
|
data = self.trajectory_socket.recv()
|
||||||
|
frame: Frame = pickle.loads(data)
|
||||||
|
trajectory_data = frame.trajectories # TODO: properly refractor
|
||||||
|
# trajectory_data = json.loads(data)
|
||||||
|
logger.debug(f"Receive {trajectory_data}")
|
||||||
|
|
||||||
data = self.trajectory_socket.recv()
|
# class FakeNode:
|
||||||
frame: Frame = pickle.loads(data)
|
# def __init__(self, node_type: NodeType):
|
||||||
trajectory_data = frame.trajectories # TODO: properly refractor
|
# self.type = node_type
|
||||||
# trajectory_data = json.loads(data)
|
|
||||||
logger.debug(f"Receive {trajectory_data}")
|
|
||||||
|
|
||||||
# class FakeNode:
|
input_dict = {}
|
||||||
# def __init__(self, node_type: NodeType):
|
for identifier, trajectory in trajectory_data.items():
|
||||||
# self.type = node_type
|
# if len(trajectory['history']) < 7:
|
||||||
|
# # TODO: these trajectories should still be in the output, but without predictions
|
||||||
|
# continue
|
||||||
|
|
||||||
input_dict = {}
|
# TODO: modify this into a mapping function between JS data an the expected Node format
|
||||||
for identifier, trajectory in trajectory_data.items():
|
# node = FakeNode(online_env.NodeType.PEDESTRIAN)
|
||||||
# if len(trajectory['history']) < 7:
|
history = [[h['x'], h['y']] for h in trajectory['history']]
|
||||||
# # TODO: these trajectories should still be in the output, but without predictions
|
history = np.array(history)
|
||||||
# continue
|
x = history[:, 0]
|
||||||
|
y = history[:, 1]
|
||||||
|
# TODO: calculate dt based on input
|
||||||
|
vx = derivative_of(x, 0.2) #eval_scene.dt
|
||||||
|
vy = derivative_of(y, 0.2)
|
||||||
|
ax = derivative_of(vx, 0.2)
|
||||||
|
ay = derivative_of(vy, 0.2)
|
||||||
|
|
||||||
# TODO: modify this into a mapping function between JS data an the expected Node format
|
data_dict = {('position', 'x'): x[:],
|
||||||
# node = FakeNode(online_env.NodeType.PEDESTRIAN)
|
('position', 'y'): y[:],
|
||||||
history = [[h['x'], h['y']] for h in trajectory['history']]
|
('velocity', 'x'): vx[:],
|
||||||
history = np.array(history)
|
('velocity', 'y'): vy[:],
|
||||||
x = history[:, 0]
|
('acceleration', 'x'): ax[:],
|
||||||
y = history[:, 1]
|
('acceleration', 'y'): ay[:]}
|
||||||
# TODO: calculate dt based on input
|
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
|
||||||
vx = derivative_of(x, 0.2) #eval_scene.dt
|
|
||||||
vy = derivative_of(y, 0.2)
|
|
||||||
ax = derivative_of(vx, 0.2)
|
|
||||||
ay = derivative_of(vy, 0.2)
|
|
||||||
|
|
||||||
data_dict = {('position', 'x'): x[:],
|
node_data = pd.DataFrame(data_dict, columns=data_columns)
|
||||||
('position', 'y'): y[:],
|
node = Node(
|
||||||
('velocity', 'x'): vx[:],
|
node_type=online_env.NodeType.PEDESTRIAN,
|
||||||
('velocity', 'y'): vy[:],
|
node_id=identifier,
|
||||||
('acceleration', 'x'): ax[:],
|
data=node_data,
|
||||||
('acceleration', 'y'): ay[:]}
|
first_timestep=timestep
|
||||||
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
|
)
|
||||||
|
|
||||||
node_data = pd.DataFrame(data_dict, columns=data_columns)
|
input_dict[node] = np.array([x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
|
||||||
node = Node(
|
|
||||||
node_type=online_env.NodeType.PEDESTRIAN,
|
|
||||||
node_id=identifier,
|
|
||||||
data=node_data,
|
|
||||||
first_timestep=timestep
|
|
||||||
)
|
|
||||||
|
|
||||||
input_dict[node] = np.array([x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
|
# print(input_dict)
|
||||||
|
|
||||||
# print(input_dict)
|
if not len(input_dict):
|
||||||
|
# skip if our input is empty
|
||||||
|
# TODO: we want to send out empty result...
|
||||||
|
# And want to update the network
|
||||||
|
|
||||||
if not len(input_dict):
|
data = json.dumps({})
|
||||||
# skip if our input is empty
|
self.prediction_socket.send_string(data)
|
||||||
# TODO: we want to send out empty result...
|
|
||||||
|
|
||||||
data = json.dumps({})
|
continue
|
||||||
self.prediction_socket.send_string(data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
maps = None
|
maps = None
|
||||||
if hyperparams['use_map_encoding']:
|
if hyperparams['use_map_encoding']:
|
||||||
|
@ -311,12 +315,14 @@ class PredictionServer:
|
||||||
# robot_present_and_future += adjustment
|
# robot_present_and_future += adjustment
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
dists, preds = trajectron.incremental_forward(input_dict,
|
with warnings.catch_warnings():
|
||||||
maps,
|
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
|
||||||
prediction_horizon=20, # TODO: make variable
|
dists, preds = trajectron.incremental_forward(input_dict,
|
||||||
num_samples=2, # TODO: make variable
|
maps,
|
||||||
robot_present_and_future=robot_present_and_future,
|
prediction_horizon=25, # TODO: make variable
|
||||||
full_dist=True)
|
num_samples=20, # TODO: make variable
|
||||||
|
robot_present_and_future=robot_present_and_future,
|
||||||
|
full_dist=True)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
||||||
1. / (end - start), len(trajectron.nodes),
|
1. / (end - start), len(trajectron.nodes),
|
||||||
|
@ -365,7 +371,10 @@ class PredictionServer:
|
||||||
}
|
}
|
||||||
|
|
||||||
data = json.dumps(response)
|
data = json.dumps(response)
|
||||||
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
if self.config.predict_training_data:
|
||||||
|
logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
|
||||||
|
else:
|
||||||
|
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
||||||
self.prediction_socket.send_string(data)
|
self.prediction_socket.send_string(data)
|
||||||
logger.info('Stopping')
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue