Trajectron-plus-plus/experiments/nuScenes/NuScenes Qualitative.ipynb

646 lines
720 KiB
Text
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"sys.path.append('../../trajectron')\n",
"import os\n",
"import numpy as np\n",
"import torch\n",
"import dill\n",
"import json\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"import matplotlib.patheffects as pe\n",
"from helper import *\n",
"import visualization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load nuScenes SDK and data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Root directory of the nuScenes data set (passed to NuScenesMap as `dataroot`).\n",
"# Set the NUSCENES_DATA_PATH environment variable, or replace the placeholder default below.\n",
"nuScenes_data_path = os.environ.get('NUSCENES_DATA_PATH', '/path/to/nuScenes')\n",
"nuScenes_devkit_path = './devkit/python-sdk/'\n",
"sys.path.append(nuScenes_devkit_path)\n",
"from nuscenes.map_expansion.map_api import NuScenesMap\n",
"nusc_map = NuScenesMap(dataroot=nuScenes_data_path, map_name='boston-seaport')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Hex colour palette. NOTE(review): not referenced by the cells shown in this notebook.\n",
"line_colors = [\n",
"    '#375397',\n",
"    '#80CBE5',\n",
"    '#ABCB51',\n",
"    '#F05F78',\n",
"    '#C8B0B0',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Map Encoding Demo"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessed nuScenes test environment and grab its scenes.\n",
"# NOTE: dill (like pickle) can execute arbitrary code on load -- only open trusted files.\n",
"with open('../processed/nuScenes_test_full.pkl', 'rb') as env_file:\n",
"    eval_env = dill.load(env_file, encoding='latin1')\n",
"\n",
"eval_scenes = eval_env.scenes"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Prediction horizon handed to `predict` and the plotting helpers (presumably in timesteps)\n",
"ph = 6\n",
"# Directory containing one sub-directory per trained model variant\n",
"log_dir = './models'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Loading from ./models/int_ee_me/model_registrar-12.pt\n",
"Loaded!\n",
"\n"
]
}
],
"source": [
"# Model variant used for the map-encoding plot: './models/int_ee_me', checkpoint iteration 12\n",
"model_dir = os.path.join(log_dir, 'int_ee_me')\n",
"eval_stg, hyp = load_model(model_dir, eval_env, ts=12)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'105'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fixed evaluation scene used by every qualitative plot below; show its name\n",
"scene = eval_scenes[25]\n",
"scene.name"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Region of interest in nuScenes map coordinates; later cells render the\n",
"# map patch (x_min, y_min, x_max, y_max) behind the predictions.\n",
"x_min, x_max = 773.0, 1100.0\n",
"y_min, y_max = 1231.0, 1510.0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Map layers passed to nusc_map.render_map_patch, in drawing order\n",
"layers = ('drivable_area road_segment lane ped_crossing '\n",
"          'walkway stop_line road_divider lane_divider').split()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction including Map Encoding"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAt4AAAJ0CAYAAAAlEyTXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZQk6Vmf+8SWEbnWvnZX9VY9M92zSRqQZ4QkFvkI8BWGe0DYBywjg5CNJYHEIgGS0GjB8oZsuBoLri0DMvha+N5jQBhsMLaMELq+Ntpmn16ra9+yKjMyY4/47h9ZmbVvPVWVS3/POXG6qyoz4quoyIjf937v+3sVIYRAIpFIJBKJRCKRnChqswcgkUgkEolEIpHcC0jhLZFIJBKJRCKRnAJSeEskEolEIpFIJKeAFN4SiUQikUgkEskpIIW3RCKRSCQSiURyCuj7/fDLM7dPaRjNRwjByuwSpdIaQyOjZLuzKIrS7GG1BZWSw/SNKXRDZ2xiDDOdavaQOpZq1eH6My/wTU98K6nU1vNs2zZVu4zv2PRk0/R2F7Asq7EBBEFAHMfEUUQcBFRsm+L8PHEYoakKhqahqSq6qqIrKqquYaRS6JqGoRvouoaqqqAoqKqKoiioioqqKvf85yVJEoIgIIgi/EQQqwqpdJpULodpmqRSqdq5OyRRFOF5HsVikemFJRLVpNDTS6Gr61DvD4KAwPcIPJc4dEnpGmZKJ5VKoSgKy2s2XX0jO66j0yAIAhy7TOi5ZFMWmUz6SOdGcm8RRRFVx8GLI6xcnkw2h6ZpzR5WR1OpVvGKCzx85T5KU1P05XLNHtKJU65UoLeXnt7eEz2Osp+d4L0kvKEmvt1ylYWZebKFHH0jA/LDfUiiIGLm1gyVcoWz589Q6O+654XYSfD8159hfOgCY+fO7/maJEmoVqs4TpUoCEjCgCSJMFRImyaZTJp8Jo1lWZSKRYa6u0mn00RR1NjCMKz9GwREQUAchoRhSByGKLWDoCoKSpKgAIqqoSgKuq6hGwa6rqOo6oZI1zS09dfURbpSF+6qioLSduI9jmOCIMCPIkIhiDUNM5utbaaJYRiH+n2SJMHzPBzHwXEc7KpDxfFwgxAjZaGbabp6e7FM8+Dx+D6B7xB4LroiSKU0zFSqIbY343kepWpI7+BI00RvFEU4FRu/WiVtGGTSaXR933iQ5B4mjmMcx8UJA8xslkwuL6+XE2J+doazPRm6u7sJ5uboyuebPaQTJYoiVoKA4fPnT/x+KIX3LkRByML0PFEYMXJulJS1/wNPUkMIQXFhldk7s/T09jB6cRRVbR8h1eoszS+wOr/K44+/7q7eH0URnu8T+B5xGFGplPFWlunKZjB0hYxlYZkm2bSFaZqNKO3myacQgiRJiOO48W8URSRRVBPnQUC0LtBJklou2/q/CtQE96ZoOYAAEiEQovYiZV2o116noajrIl1V0VRtd+G+HnlX1t9/EkRRhB8EBHFMKATCMLByOcxMhlQqhWEYB+7D87yGyLYdl3LFwfU8UA1008RI1c69lUmTMvaPRAshCHwf33cJPZckCbB0vRbVNs1DBQ3Ktk0gUvT0DRz6PJwESZLgVCu4to2pqGQy6aZE4iXtQZIkOI5L1fcw0mky+YK8Xo6RJEmYuvEiT7zqEaq2jV4qkclkmj2sE6Vo21jDw+RPYYIhhfceCCFYW1ihWCwyMDRIvldGcA+LU3GZvjmNSBLGJsbJ5NLNHlLbE8cxz/zlV3nFQ68+tmWw1eUlMqpCOp0hjmN83ycIfKIgJI5DkigiiQJ0rSbKM2mL9Logr297RQY2C/S6SI/WxXldpEdhCEKgURPlmqKgJgkKSk1sKwoKoCoKoJCIpLZfIRD1jdpDImn8X6Com4S7ptX2p6mNCHw9NWa3lJl6RF4IQRiGBGFIIAShEGipVC1tJF0ThftF2qIownEcPM+jXKlSdVxKlS
qoGrphohkmpmWRWp/sHDbCsjV9xCOlq430kcMI/91YKa6iZ3rI5Qt39f7jRAiB6zhUy2UMEjJWupEmJZFsRwiB6zpUPA81ZZHN5zHl9fKysW0bUSnyyENXWZyeJp8kHT2x8X2fMjA0Pn4qOk8K7wNwbYeFmTnMjMngmRGZenJIkkQwd2ue4vIKI+Oj9A33yInLy+DWjZuYWDz04KPHsr8oilidn2Owt/fAv0sYhYRB2BDlUegTRxEkEYamkrZMMmmLbDrdSLHYT5RvRgixJXre2Naj5kkU1QR6kqCtC/H6+zaPWwgB9Qj6+j7rorw2CRAIkSBELbqeiKQRZRdCEItkPfJe+1rVVFLpNMPnz5PN5XZE/jefxyAIcByHarWKXXWxqw5BItB0Az1loadSWOkMlmkeeVl8e/qIRoJp6numj9wNSZKwVCyR7x3GPCCd5TTxXBenYiPCgKxpkk5n5D1Esiee61L1PBJNJ1soYFqWvF7ukpmpO1we7WNwcJCZGzcYTHd2DcZS2aZ7fOzUJvlSeB+COIpZnJrD932Gx89gZeSM+rCUlstM35oik88xdvEMekrm4x2V/Qoq75ZyaQ09CMnlsi9rP2EUEngBfuCRRDFx6BOti3JT18lYJqaVIr0uxDVNQ9d1NK1WpLn5//uRJEljq7P5obrbA3b79w7zdX0C4Ps+c8vLdA8MNIomvSAkCEKC9ZSdMIxB09A0A0U3SJmHTxPZi83pI4HnIu4ifeRuCIKAYtmld3C05YILjUJM1yVrykJMyf74vo/jugRAttBFOiMnbEchSRJmbr3EE696FEVRWLx5k8EOzu92HAcvnWZgZOTUjimF9yERQrC2vEpxaZm+gX66+mUE97D4XsDMjWk8z2fs0jj57pcn9u41nn/6WcYHz+9bUHkUkiRhZX6W/kLXiYqsLaI8TkAIkiRCxDEiEcRJDCSwLqYNTUXTDVK6iqHrGIaOpqqkDKMh2HVdXy/i1Bui/TC/b70Qsp6T3kgjiWoRZT9KiKKAKExAVfDDGE9VyXf3outGrThU1zCMFJqu14pINf1YBGAYhgSeR+C7RIF7LOkjd0OlWsWNNHr6Blry3iYLMSVHobYS5eKLhHS+5oQiJ2wHs7a6ihFWePDKA7VUuenpjnU0SZKEpUqFgQsXTvVeK4X3EfEcl4WpeVKmzsDoMHrq9P5Y7YxIBEszSyzMLdI/1M/w2BCKLLw8kJdbULkbTrVKbJfpOqQt3WkgkoQoiUkSQRxFJHFMFEfEUVzL5U5ikiSGJEEkgiSJESKGuCaUTV1HN3RSmoaqqYRRRBAleL5fO4CqoigamqahaBqqqqOui2dN09FTBrqmNYT80mqRfP/gieQ17pY+UncfMU2zqaJ3dbUEZo5CV0/TxnAQshBTchQ2WxGa2RzZXL7lVnVaiZk7t7lyboTe3l4qlUpHO5qcln3gdmS44IhYmTRnJ8ZZnFlg6vpths6OkinICO5BKKrC4NggmUKO6RtTVMtV6fl9AHEcM3tnmlc89Opj3W+1XKYn01oFr4qqYtSjUUcUUXEck4iEKIpJophIJGiWRk7X6d4kpg+L7/sIVT92MWeX1/BdB5GEWLpGOqXT1dNafsRdXXmWi2t4KQsr3VrXSB1VVcnlC2RzeVzHYbVcwnCqshBTsiu6rtNVKJBbtyIsVuYwsxkyuYJcMdlGFEWI0KO7uxuA0PcxWuj+dJxEUYSrqgw1IQC1b8T7v3z5K/QMdrfksmOzEUJgr9oszc3T3dND70i/PE+HJAoiZifnsNfK0vN7H27fuEnqGAsqAXzPwyku09dzujP8dqJUKqHlC2SyxzehTpKE5blJ+rrzLR+dDcOQlVKVnoHRthEmnutStcsQhbIQU7IvSZLgui4VT1oRbqdcKqH5ZR688gAAi1NT5IXoyPNTLJexRkZOxT5wO/smPM3dmWXq2jRRGJ3WeNoGRVEo9BYYu3SOarXC9M07REHY7GG1BXpKZ2ziLMPjI9
y5Nc3MjdlaDrCkQbXqUF4ucd/lK8e6X6dik5FRwT1JkgQviY892quqKoaZ3VIc2qoYhkFXJsXayhL7xGVaCiudpm9
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ph = 6\n",
"with torch.no_grad():\n",
" timestep = np.array([2])\n",
" predictions = eval_stg.predict(scene,\n",
" timestep,\n",
" ph,\n",
" num_samples=500)\n",
"\n",
" predictions_mm = eval_stg.predict(scene,\n",
" timestep,\n",
" ph,\n",
" num_samples=1,\n",
" z_mode=True,\n",
" gmm_mode=True)\n",
"\n",
" # Plot predicted timestep for random scene in map\n",
" my_patch = (x_min, y_min, x_max, y_max)\n",
" fig, ax = nusc_map.render_map_patch(my_patch, layers, figsize=(10, 10), alpha=0.1, render_egoposes_range=False)\n",
"\n",
" ax.plot([], [], 'ko-',\n",
" zorder=620,\n",
" markersize=4,\n",
" linewidth=2, alpha=0.7, label='Ours (MM)')\n",
"\n",
" ax.plot([],\n",
" [],\n",
" 'w--o', label='Ground Truth',\n",
" linewidth=3,\n",
" path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()])\n",
"\n",
" plot_vehicle_nice(ax,\n",
" predictions,\n",
" scene.dt,\n",
" max_hl=10,\n",
" ph=ph,\n",
" map=None, x_min=x_min, y_min=y_min)\n",
"\n",
" plot_vehicle_mm(ax,\n",
" predictions_mm,\n",
" scene.dt,\n",
" max_hl=10,\n",
" ph=ph,\n",
" map=None, x_min=x_min, y_min=y_min)\n",
"\n",
" ax.set_ylim((1385, 1435))\n",
" ax.set_xlim((850, 900))\n",
" leg = ax.legend(loc='upper right', fontsize=20, frameon=True)\n",
" ax.axis('off')\n",
" for lh in leg.legendHandles:\n",
" lh.set_alpha(.5)\n",
" ax.get_legend().remove()\n",
" fig.show()\n",
" fig.savefig('plots/qual_nuScenes_map_pos.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction without Map Encoding"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Loading from ./models/int_ee/model_registrar-12.pt\n",
"Loaded!\n",
"\n"
]
}
],
"source": [
"# Model variant without map encoding: './models/int_ee', checkpoint iteration 12\n",
"model_dir = os.path.join(log_dir, 'int_ee')\n",
"eval_stg_nm, hyp = load_model(model_dir, eval_env, ts=12)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAt4AAAJ0CAYAAAAlEyTXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9d5Rk213f+zmhcuwcpif33DszNwkuQZEgYS0M2PJayNgPbEsGIcOTcMKSQAmJKyNAxrKxLgiwEdZ7No/ntEDYzyQbSxgts4ykK918J/V0TtUVTz57vz+qq7qrQ3X1THdXVc/+rLXXTFc4Z/fpqnO++3d+v+9Pk1JKFAqFQqFQKBQKxbGid3sCCoVCoVAoFArFg4AS3gqFQqFQKBQKxQmghLdCoVAoFAqFQnECKOGtUCgUCoVCoVCcAEp4KxQKhUKhUCgUJ4DZ7skvzd85oWl0Hykl6wurlEpFxiYmSeVTaJrW7Wn1BdWSxdzNWcyIydnps8QS0W5P6dRSq1ncePZFXveabycabT3OlUqFWqWMa1UYSCUYzGeJx+PNAeB5HmEYEgYBoedRrVQoLC0R+gGGrhExDAxdx9R1TE1HNw0i0SimYRAxI5imga7roGnouo6maeiajq5rD/z3RQiB53l4QYArJKGuEU0kiKbTxGIxotFo/dh1SBAEOI5DoVBgbnkVocfIDgySzeU6er/neXiug+fYhL5N1DSIRU2i0SiaprFWrJAbmtj1OToJPM/DqpTxHZtUNE4ymTjUsVE8WARBQM2ycMKAeDpDMpXGMIxuT+tUU63VcArLPHbtIUqzswyl092e0rFTrlZhcJCBwcFj3Y/Wzk7wQRLeUBffdrnG8vwSqWyaoYkR9eXukMALmL89T7VcZerCGbLDuQdeiB0HL3z1Wc6NXeTs+Qv7vkYIQa1Ww7JqBJ6H8D2ECIjokIjFSCYTZJIJ4vE4pUKBsXyeRCJBEATN4ft+/V/PI/A8Qt/H931C30er7wRd09CEQAM03UDTNEzTwIxEME0TTde3RLphYGy+piHStYZw13U0tL4T72EY4nkebhDgS0loGMRSqfqIxYhEIh39PkIIHMfBsiwsy6JSs6haDrbnE4nGMWMJcoODxGOxg+fjuniuhefYmJokGjWIRaNNsb0dx3Eo1XwGRye6JnqDIMCqVnBrNRKRCMlEAtNsGw9SPMCEYYhl2Vi+RyyVIpnOqM/LMbG0MM/UQJJ8Po+3uEguk+n2lI6VIAhY9zzGL1w49vOhEt57EHg+y3NLBH7AxPlJovH2FzxFHSklheUNFu4uMDA4wOSlSXS9f4RUr7O6tMzG0gavfvUb7un9QRDguC6e6xD6AdVqGWd9jVwqScTUSMbjxGMxUok4sVisGaXdvviUUiKEIAzD5r9BECCCoC7OPY9gU6AjRD2XbfNfDeqCe1u0HEACQkqkrL9I2xTq9dcZaPqmSNd1DN3YW7hvRt61zfcfB0EQ4HoeXhjiS4mMRIin08SSSaLRKJFI5MBtOI7TFNkVy6ZctbAdB/QIZixGJFo/9vFkgmikfSRaSonnuriuje/YCOERN816VDsW6yhoUK5U8GSUgaGRjo/DcSCEwKpVsSsVYppOMpnoSiRe0R8IIbAsm5rrEEkkSGay6vNyhAghmL35Eq/5+sepVSqYpRLJZLLb0zpWCpUK8fFxMiewwFDCex+klBSX1ykUCoyMjZIZVBHcTrGqNnO35pBCcHb6HMl0ottT6nvCMOTZP/8Kr3r0m47sNtjG2ipJXSORSBKGIa7r4nkugecThj4iCBCBh2nURXkyESexKcgbY7/IwHaB3hDpwaY4b4j0wPdBSgzqotzQNHQh0NDqYlvT0ABd0wANIUV9u1IiG4P6RUI0/y/R9G3C3TDq2zP0ZgS+kRqzV8pMIyIvpcT3fTzfx5MSX0qMaLSeNpKoi8J2kbYgCLAsC8dxKFdr1CybUr
UGuoEZiWFEYsTicaKbi51OIyyt6SMOUVNvpo90Ivz3Yr2wgZkcIJ3J3tP7jxIpJbZlUSuXiSBIxhPNNCmFYidSSmzbouo46NE4qUyGmPq83DeVSgVZLfD4o9dZmZsjI8SpXti4rksZGDt37kR0nhLeB2BXLJbnF4klY4yemVCpJx0ihGTx9hKFtXUmzk0yND6gFi73we2bt4gR59FHnjiS7QVBwMbSIqODgwf+XfzAx/f8pigPfJcwCEAERAydRDxGMhEnlUg0UyzaifLtSClboufNsRk1F0FQF+hCYGwK8cb7ts9bSgmNCPrmNhuivL4IkEgpkLIeXRdSNKPsUkpCKTYj7/WfdUMnmkgwfuECqXR6V+R/+3H0PA/LsqjValRqNpWahSckhhnBjMYxo1HiiSTxWOzQt8V3po8YCGIxc9/0kXtBCMFqoURmcJzYAeksJ4lj21jVCtL3SMViJBJJdQ5R7Itj29QcB2GYpLJZYvG4+rzcI/Ozd7kyOcTo6CjzN28ymjjdNRir5Qr5c2dPbJGvhHcHhEHIyuwirusyfu4M8aRaUXdKaa3M3O1Zkpk0Zy+dwYyqfLzD0q6g8l4pl4qYnk86nbqv7fiBj+d4uJ6DCEJC3yXYFOUx0yQZjxGLR0lsCnHDMDBNE8OoF2lu/387hBDN0WD7RXWvC+zOxzr5ubEAcF2XxbU18iMjzaJJx/PxPB9vM2XH90MwDAwjgmZGiMY6TxPZj+3pI55jI+8hfeRe8DyPQtlmcHSy54ILzUJM2yYVU4WYiva4rotl23hAKpsjkVQLtsMghGD+9su85uufQNM0Vm7dYvQU53dbloWTSDAyMXFi+1TCu0OklBTXNiisrjE0MkxuWEVwO8V1POZvzuE4LmcvnyOTvz+x96Dxwtee49zohbYFlYdBCMH60gLD2dyxiqwWUR4KkBIhAmQYIoUkFCEgYFNMRwwdw4wQNXUipkkkYmLoOtFIpCnYTdPcLOI0m6K9k9+3UQjZyElvppEE9YiyGwiCwCPwBegarh/i6DqZ/CCmGakXh5oGkUgUwzTrRaSGeSQC0Pd9PMfBc20Czz6S9JF7oVqrYQcGA0MjPXluU4WYisNQvxNl40pBIlN3QlELtoMpbmwQ8as8cu1qPVVubu7UOpoIIVitVhm5ePFEz7VthfdC8SWWar1z67EXcCyb5dklojGTkclxzOjJ/bH6GSkkq/OrLC+uMDw2zPjZMTRVeHkg91tQuRdWrUZYKZPr0JbuJJBCEIgQISRhECDCkCAMCIOwnsstQoQIQQikkAgRImUIYV0ox0wTM2ISNQx0Q8cPArxA4LhufQe6jqYZGIaBZhjouom+KZ4Nw8SMRjANoynkVzcKZIZHjyWvca/0kYb7SCwW66ro3dgoQSxNNjfQtTkchCrEVByG7VaEsVSaVDrTc3d1eon5u3e4dn6CwcFBqtXqqXY0OSn7wJ0cGC4YT7lKfG8jnkwwNX2OlfllZm/cYWxqkmRWRXAPQtM1Rs+Oksymmbs5S61cU57fBxCGIQt353jVo990pNutlcsMJHur4FXTdSKNaNQhRVQYhggpCIIQEYQEUmDEDdKmSX6bmO4U13WRunnkYq5SLuLaFlL4xE2DRNQkN9BbfsS5XIa1QhEnGiee6K3PSANd10lnsqTSGWzLYqNcImLVVCGmYk9M0ySXzZLetCIsVBeJpZIk01l1x2QHQRAgfYd8Pg+A77pEeuj8dJQEQYCt64x1IQDVNuI9u/Y8hrl10JUA30JKSWWjwuriEvmBAQYnhnvy9mwvEngBCzOLVIpl5fndhjs3bxE9woJKANdxsAprDA2c7Aq/nyiVShiZLMnU0S2ohRCsLc4wlM/0fHTW933WSzUGRib7Rpg4tk2tUobAV4WYirYIIbBtm6qjrAh3Ui6VMNwyj1y7CsDK7CwZKU/l8SmUy8QnJk7EPnAnbROeNK2K53o0tPl4ymU85Z7IxHodTdPIDmY5e/k8tVqVuV
t3CTy/29PqC8yoydnpKcbPTXD39hzzNxfqOcCKJrWaRXmtxENXrh3pdq1qhaSKCu6LEAJHhEce7dV1nUgs1VIc2qt
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict at timestep 2 of the selected scene with the no-map model: 500 sampled\n",
"# futures plus one prediction with z_mode/gmm_mode enabled, drawn over the map ROI.\n",
"timestep = np.array([2])\n",
"os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories\n",
"with torch.no_grad():\n",
"    predictions = eval_stg_nm.predict(scene,\n",
"                                      timestep,\n",
"                                      ph,\n",
"                                      num_samples=500)\n",
"\n",
"    predictions_mm = eval_stg_nm.predict(scene,\n",
"                                         timestep,\n",
"                                         ph,\n",
"                                         num_samples=1,\n",
"                                         z_mode=True,\n",
"                                         gmm_mode=True)\n",
"\n",
"    # Plot the predictions for the selected scene on top of the map patch\n",
"    my_patch = (x_min, y_min, x_max, y_max)\n",
"    fig, ax = nusc_map.render_map_patch(my_patch, layers, figsize=(10, 10), alpha=0.1, render_egoposes_range=False)\n",
"\n",
"    plot_vehicle_nice(ax,\n",
"                      predictions,\n",
"                      scene.dt,\n",
"                      max_hl=10,\n",
"                      ph=ph,\n",
"                      map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    plot_vehicle_mm(ax,\n",
"                    predictions_mm,\n",
"                    scene.dt,\n",
"                    max_hl=10,\n",
"                    ph=ph,\n",
"                    map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    # Zoom in on the vehicles of interest. No in-figure legend is drawn:\n",
"    # the shared legend is saved separately by the last cell of the notebook.\n",
"    ax.set_ylim((1385, 1435))\n",
"    ax.set_xlim((850, 900))\n",
"    ax.axis('off')\n",
"    fig.show()\n",
"    fig.savefig('plots/qual_nuScenes_no_map_pos.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction using velocity output (without Map Encoding)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Loading from ./models/ee_vel/model_registrar-12.pt\n",
"Loaded!\n",
"\n"
]
}
],
"source": [
"# Velocity-output model variant without map encoding: './models/ee_vel', checkpoint iteration 12\n",
"model_dir = os.path.join(log_dir, 'ee_vel')\n",
"eval_stg_vel, hyp = load_model(model_dir, eval_env, ts=12)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAt4AAAJ0CAYAAAAlEyTXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZQk61ne+YslIzIit8qsvbqqum9XVd/bfTeBQBvCLJojA8YenxFgBsYjJEBgrsTYCBAIXVniXls2QjJidAFhJBsvYBgfMAx4sIXw8QEsMx6EdO/tu/VaXfuWVbnFHvHNH1lVXdXdtXR3VeXS3++cPF1VmRXxdVZkxBPv9z7PpwghBBKJRCKRSCQSieRYUVs9AIlEIpFIJBKJ5EFACm+JRCKRSCQSieQEkMJbIpFIJBKJRCI5AaTwlkgkEolEIpFITgApvCUSiUQikUgkkhNA3+/JL81dP6FhtB4hBGvzK1QqGwwOj5DpyaAoSquH1RHUKw6zV2bQUzpjk2OYltHqIXUtjYbD5Rdf4eve/E0Yxu73uVar0ahV8Z0axYxFqSdPOp3efgAEQUAcx8RRRBwE1Gs1youLxGGEpiqkNA1NVdFVFV1RUXWNlGGgaxopPYWua6iqCoqCqqooioKqqKiq8sB/XpIkIQgCgijCTwSxqmBYFkY2i2maGIbRfO8OSRRFeJ5HuVxmdmmFRDXJF0vkC4VD/X4QBAS+R+C5xKGLoWuYho5hGCiKwupGjULv8G3H0UkQBAFOrUrouWSMNLZt3dV7I3mwiKKIhuPgxRHpbA47k0XTtFYPq6upNxp45SUeP3+OyswMvdlsq4d07FTrdSiVKJZKx7ofZb84wQdJeENTfLvVBktzi2TyWXqH++WH+5BEQcTctTnq1TqjZ06R7ys88ELsOHj5+RcZH3yIsdNn9nxNkiQ0Gg0cp0EUBCRhQJJEpFSwTBPbtsjZFul0mkq5zGBPD5ZlEUXR9iMMw+a/QUAUBMRhSBiGxGGI0twJqqKgJAkKoKgaiqKg6xp6KoWu6yiqelOkaxra5mu2RLqyJdxVFQWl48R7HMcEQYAfRYRCEGsaZibTfJgmqVTqUP+fJEnwPA/HcXAch1rDoe54uEFIykijmxaFUom0aR48Ht8n8B0Cz0VXBIahYRrGttjeied5VBohpYHhloneKIpw6jX8RgMrlcK2LHR933qQ5AEmjmMcx8UJA8xMBjubk8fLMbE4P8do0aanp4dgYYFCLtfqIR0rURSxFgQMnTlz7OdDKbzvQBSELM0uEoURw6dHMNL7X/AkTYQQlJfWmb8xT7FUZOTsCKraOUKq3VlZXGJ9cZ03venr7+n3oyjC830C3yMOI+r1Kt7aKoWMTUpXsNNp0qZJxkpjmuZ2lXbnzacQgiRJiON4+98oikiiqCnOg4BoU6CTJM1ets1/FWgK7h3VcgABJEIgRPNFyqZQb75OQ1E3RbqqoqnanYX7ZuVd2fz94yCKIvwgIIhjQiEQqRTpbBbTtjEMg1QqdeA2PM/bFtk1x6Vad3A9D9QUummSMprvfdq2MFL7V6KFEAS+j++7hJ5LkgSkdb1Z1TbNQxUNqrUagTAo9vYf+n04DpIkwWnUcWs1TEXFtq2WVOIlnUGSJDiOS8P3SFkWdi4vj5cjJEkSZq68ypu/+gkatRp6pYJt260e1rFSrtVIDw2RO4EbDCm890AIwcbSGuVymf7BAXIlWcE9LE7dZfbqLCJJGJscx85arR5SxxPHMS/+5Zd53WNvOLJpsPXVFWxVwbJs4jjG932CwCcKQuI4JIkikihA15qi3LbSWJuCfOuxV2Vgp0DfEunRpjjfEulRGIIQaDRFuaYoqEmCgtIU24qCAqiKAigkImluVwjE1oPmRSLZ/lqgqDuEu6Y1t6ep2xX4rdaYO7XMbFXkhRCEYUgQhgRCEAqBZhjNthGrKQr3q7RFUYTjOHieR7
XeoOG4VOoNUDX0lImWMjHTaYzNm53DVlh2t494GLq63T5yGOF/J9bK6+h2kWwuf0+/f5QIIXAdh0a1SooEO21tt0lJJLcihMB1Heqeh2qkyeRymPJ4uW9qtRqiXuaJxy6wPDtLLkm6+sbG932qwOD4+InoPCm8D8CtOSzNLWDaJgOnhmXrySFJEsHCtUXKq2sMj4/QO1SUNy73wbUrVzFJ89ijTx7J9qIoYn1xgYFS6cC/SxiFhEG4Lcqj0CeOIkgiUpqKlTaxrTQZy9pusdhPlO9ECLGrer792KyaJ1HUFOhJgrYpxLd+b+e4hRCwVUHf3OaWKG/eBAiESBCiWV1PRLJdZRdCEItks/Le/F7VVAzLYujMGTLZ7G2V/53vYxAEOI5Do9Gg1nCpNRyCRKDpKXQjjW4YpC2btGne9bT4re0jGgmmqe/ZPnIvJEnCSrlCrjSEeUA7y0niuS5OvYYIAzKmiWXZ8hwi2RPPdWl4Hommk8nnMdNpebzcI3MzN5ga6WVgYIC5K1cYsLrbg7FSrdEzPnZiN/lSeB+COIpZnlnA932Gxk+RtuUd9WGprFaZvTaDncsydvYUuiH78e6W/QyV90q1soEehGSzmfvaThiFBF6AH3gkUUwc+kSbotzUdey0iZk2sDaFuKZp6LqOpjVNmju/3o8kSbYfW+y8qN7pAnvrzw7z/dYNgO/7LKyu0tPfv22a9IKQIAgJNlt2wjAGTUPTUih6CsM8fJvIXuxsHwk8F3EP7SP3QhAElKsupYGRtisubBsxXZeMKY2Ykv3xfR/HdQmATL6AZcsbtrshSRLmrr3Gm7/6SRRFYfnqVQa6uL/bcRw8y6J/ePjE9imF9yERQrCxuk55ZZXe/j4KfbKCe1h8L2Duyiye5zM2MU6u5/7E3oPGyy9cZHzgzL6GyrshSRLWFufpyxeOVWTtEuVxAkKQJBEijhGJIE5iIIFNMZ3SVDQ9haGrpHSdVEpHU1WMVGpbsOu6vmni1LdF+2H+v1tGyK2e9O02kqhZUfajhCgKiMIEVAU/jPFUlVxPCV1PNc2hukYqZaDpetNEqulHIgDDMCTwPALfJQrcI2kfuRfqjQZupFHs7W/Lc5s0YkruhuZMlIsvEqxcMwlF3rAdzMb6OqmwzqPnH2m2ys3Odm2iSZIkrNTr9D/00Imea6Xwvks8x2VpZhHD1OkfGUI3Tu6P1cmIRLAyt8LSwjJ9g30MjQ2iSOPlgdyvofJOOI0Gca1K4ZCxdCeBSBKiJCZJBHEUkcQxURwRR3GzlzuJSZIYkgSRCJIkRogY4qZQNnUdPaVjaBqqphJGEUGU4Pl+cweqiqJoaJqGommoqo66KZ41TUc3Uuiati3kV9bL5PoGjqWv8U7tI1vpI6ZptlT0rq9XwMySLxRbNoaDkEZMyd2wM4rQzGTJZHNtN6vTTszduM7508OUSiXq9XpXJ5qcVHzgrchywV2Sti1GJ8dZnlti5vJ1BkdHsPOygnsQiqowMDaAnc8ye2WGRrUhM78PII5j5m/M8rrH3nCk221UqxTt9jK8KqpKaqsadZciKo5jEpEQRTFJFBOJBC2tkdV1enaI6cPi+z5C1Y9czNWqG/iug0hC0rqGZegUiu2VR1wo5Fgtb+AZadJWex0jW6iqSjaXJ5PN4ToO69UKKachjZiSO6LrOoV8nuxmFGG5voCZsbGzeTljcgtRFCFCj56eHgBC3yfVRuenoySKIlxVZbAFBah9K95//KW/ojjQ05bTjq1GCEFtvcbKwiI9xSKl4T75Ph2SKIiYn16gtlGVmd/7cP3KVYwjNFQC+J6HU16lt3iyd/idRKVSQcvlsTNHd0OdJAmrC9P09uTavjobhiFrlQbF/pGOESae69KoVSEKpRFTsi9JkuC6LnVPRhHeSrVSQfOrPHr+EQCWZ2bICdGV70+5WiU9PHwi8YG3sm/D08KNeWYuzRKF0UmNp2NQFIV8Kc/YxGkajTqzV28QBWGrh9UR6I
bO2OQoQ+PD3Lg2y9yV+WYPsGSbRsOhulrh3NT5I92uU69hy6rgniRJgpfER17tVVWVlJnZZQ5tV1KpFAXbYGNthX3
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict at timestep 2 of the selected scene with the velocity-output model: 500\n",
"# sampled futures plus one prediction with z_mode/gmm_mode enabled, drawn over the map ROI.\n",
"timestep = np.array([2])\n",
"os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories\n",
"with torch.no_grad():\n",
"    predictions = eval_stg_vel.predict(scene,\n",
"                                       timestep,\n",
"                                       ph,\n",
"                                       num_samples=500)\n",
"\n",
"    predictions_mm = eval_stg_vel.predict(scene,\n",
"                                          timestep,\n",
"                                          ph,\n",
"                                          num_samples=1,\n",
"                                          z_mode=True,\n",
"                                          gmm_mode=True)\n",
"\n",
"    # Plot the predictions for the selected scene on top of the map patch\n",
"    my_patch = (x_min, y_min, x_max, y_max)\n",
"    fig, ax = nusc_map.render_map_patch(my_patch, layers, figsize=(10, 10), alpha=0.1, render_egoposes_range=False)\n",
"\n",
"    plot_vehicle_nice(ax,\n",
"                      predictions,\n",
"                      scene.dt,\n",
"                      max_hl=10,\n",
"                      ph=ph,\n",
"                      map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    plot_vehicle_mm(ax,\n",
"                    predictions_mm,\n",
"                    scene.dt,\n",
"                    max_hl=10,\n",
"                    ph=ph,\n",
"                    map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    # Zoom in on the vehicles of interest. No in-figure legend is drawn:\n",
"    # the shared legend is saved separately by the last cell of the notebook.\n",
"    ax.set_ylim((1385, 1435))\n",
"    ax.set_xlim((850, 900))\n",
"    ax.axis('off')\n",
"    fig.show()\n",
"    fig.savefig('plots/qual_nuScenes_no_map_vel.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction using velocity output and Map Encoding"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Loading from ./models/me_vel/model_registrar-12.pt\n",
"Loaded!\n",
"\n"
]
}
],
"source": [
"# Velocity-output model variant with map encoding: './models/me_vel', checkpoint iteration 12\n",
"model_dir = os.path.join(log_dir, 'me_vel')\n",
"eval_stg_vel_map, hyp = load_model(model_dir, eval_env, ts=12)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAt4AAAJ0CAYAAAAlEyTXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZRk6Vne+btL3CX2JffMyqWq1V3dYoTMIgHGK/bMsc9ogDEyHmyLOccCme4WEmhBMmqE1C0bARIaI8nGxttgwFgcGzBeQAd7BAgjyQgL6K5eas2q3GNf737nj8jMriUrK5fIiBu3vt85daoqMuPeLzIjvvvc93u+55XCMAwRCAQCgUAgEAgEZ4o86gEIBAKBQCAQCAQPA0J4CwQCgUAgEAgEQ0AIb4FAIBAIBAKBYAgI4S0QCAQCgUAgEAwBIbwFAoFAIBAIBIIhoB72xS+vXR/SME6H7/msXVtFN3SmFmaRJGnUQxoonutx/aXryCgsP76ErDyc90tf+eIf8DVf/Y3kcjkAPM+jsrOD020wXcxTzOdIJpMkEglc18V1XTzLYmNtDbvVQgYSioKuqiQUBTWRQNM0dE1HVRUkWUaWZWRJRpalsXwfBUFAz7KwfB9PUTCyWZLpNLquH/p6giCg2+1y49YalWaP/MQU2d2f892EYYhj2/S6LVy7i5FQSJoGlm3jyybZfHGgrykMQzrtFr1mk4yhk0ymBnp8wfHpdDq0bZtkNkcqkxn1cEZGtVohhUNagulUKtJzRhAE7PR6zKysIMsP5zVEIIgC0mFxguMivAE8x2Xt2k3MdJLJuelIT4AnwXM8rr54DVVWWX5iGVmO1+t7EOWdMvWNGm94wzfd8zXP86jXa9i9Hr5jQ+CSS6fIZVKYhkGv2WRpbg7XdXEcB9d1sR0Hp9vFtW0c2yb0fWQk5DDoLwNJMmpCRdd1FFVFluW+MFcUVEV59f8REOonEdtBENBut2k2m9SbbarNFpKiYWZzFIulAy/MnufR67axOi0SckjS0DBMc/8cQRCwXW1QmJxHVQ+9pz8RruvSqteRPYdsOnMm5xAcHc/zaLXbeIpCrlh6KH8fazdXeWS2SNBsMhXxG5B2p4Ofy1GcmBj1UASCh5rYCG/YE9+3MFMmk/PxFd+KpLLy2odLfF/64+e5MP8oM3NzD/zeIAjo9nr0Oh0a9Spho0E6ZZBNmWRSSTLpNKZpouv6/nN838fzPHzf71fKHQfXsrB7FqHnIgNSECCFYV9syzKKohCGIX4QEoQhstIX5vKeMFcUVEVFlqU7hLoiK6f+3R1XbN8utKv1JvV2ByVhoOomZjJJMpU6UDgFQYDV69HrtAh9m6SRwDSM+4qsTqeDFajki5Onen2H0e106DRqpBI6qVQydp/zcaPX69LsWZjZLKl05qH6faxeeYnXX3yE3tY2E5n0qIdzKDvNJsXlZTRNG/VQBIKHmliVKFQtwfzKAjevriKv71Cam4zVRUDVVM5fXOHqi9e4cek6S48/HOK70+niWx6T09NH+n5ZlkmnUqRTKVRZIjU5RSKRwOpZVHo9tpo7eLYFYUAuZZJNp0ilUpimSTKZvOd4vu/vC/M9Ue47Do5tI4UhCqABUrArygE5DAldj8BxcMO+MA/DEN/3CYAguFeoy6q6L8r3hLqiKPsV9YPEdvY+YjsIAprNJs1mk1qzTaPdQU7o6GYaI11kYfrcocvNtm1jdds4VgdDlckldXS98MCffTKZpFOpYdvZO25sBkkylUI3DFr1GlatRjadFmJihJhmEk3TabXaVHs9soUiiURi1MM6cyzbRldkEokETsSnYdu2kZNJ8TkRCCJArCree3iOy62rq2SyWYqzE7ES39CvfF954QqaprF0Mf7i+5WXXmYiOcGF1zx2rOcFQUBlfY2pYvHA94Dv+1g9C8vq4do9fNdBJiCXSpFOmfti/DABeT9R7u6Kcj
kM7zl3eNtzgzAkCAJ8PyAMA4IgxA+D3cdDfML9fysJleL8PNOzs/eIbcdx6Ha7NJtNKvUmrW4PVTdJ6EmMZJJUKvVAX6fv+/S6HaxOCwWfpKlhGMax/aBWr0fTCilNzRzreSfB6vVo12sYikImnY7dZ33csHo9Gr0uRiZLOpON9e+jWq2Qllxmp6fwtrbIpqNb8a61Wphzc6RSYn+EQDBqYim8ARzLZv36rdiKb9fxuPrCVTRDZ+nRxdiKb8dxufSHf8yf/sa/cOxqjdXr4TTq5O+zSfAg9sR4r9fFcyx8z0GVQjJmklTSQNd1EokEqqru/60oyn2P5fs+wB3vv5P823VddioVXEXBMAy6vR7dnk3PcelZFqEkk9B0FN0klc6QNM0jCeYwDLEti163jed0MTWVpGmcumJZrtQwc1MYpnmq4xyFIAhoNxs4nTbZZArDMM78nIL7EwQBrVYLG4lssRjbKuvazVUeW5hEkSSURoPUAatlUSAIAnYsi9mVldhdBwWCcSS2whv64nvt6ir5UoH8VCl2k87t4nv50UWkGIrvW6uryK7KV732q4/93EatghGGmObpLoj9SnAX27YJgwDfdyEI8D2PMPQhDNFVFV1PkFBVdE1F13QURUHTtDv+vp8Yvn3jp+u6dHo9bMuh53pYtg2SRKPdxZyYwDBSKAkVXTdIaBpaInHsqrTruvS6bexuG02BpKGhG8bAPiO2bVNrO0xMzw3tc2fbNq1aDZ2QdDp93xsiwXCwLItmt4uWTpPJ5mI3/65eeYk3vO4J2rUaScc5M2vVaWl3uwT5PIXiYNOGBALByYiVx/tuNENn/vwia1dvIskyuYlCrCb/hKayfHGZqy9c4/rLq7EU35XNHb7mq7/x2M8LwxCnZ+1HD54GRVFIZzKk75NaEAYBXuDjOi6e79N1HIKeReC7hL5PEAQEgQeej5qQSagaRqL/0dsX1oCi6kiqgqIkkFUF3cyRziUoqipBGKI2m0zOzZ/4dRy0UTJbOBuBqus6es+i224PLW5O13W06WnarSblRp2saZ76pktwcgzDQNM0Wu025W6PbLEYWXF6XPb83YZh0HCcSN/kdT2PiYgnrggEDxOxFt7QF9+zy/NsXF9DlmUyxXhVXnRD4/zjK1y7dI2bl29x7pGF2Ijv8k6ZlJE9kXh2HIeEJA0lr1aSZRKyTEJ9sD3D930c1yXw+haUPWEtPWCcTq+LZp7MQhEEAa1GFcfqHmuj5GnJpFOU63XMI/jLB4UkSWSyOQwzSbNWpVeviejBESLLMrlsFsO2aZS3sVNp0tnc2OdIdzttCrm+mPU8DyWidhrLslB2exsIBIJocOjsV96oDmscZ4qRNJldnqe8tU2r2hj1cAaObmqsPL5Ct93l5uVbhMF93UNjxc7mFosLyyd6rm310BPRE1uKomAaBql0ilQ6ha5pDxTdALbtohsnq966rotvtZkqZsnnzy5t5G5UVcVISHTaraGc73YSiQSlqWm0TI5Kq0m73eEQV53gjNF1nclCEdlxqG5tYlvWqId0KhzLIp/NEARBP2Y0osWcruuSFhYTgSBSHHrF37i5TnWrNqyxnClG0mR2cYGdzW1ateaohzNwbhffa1fXx158t1qtY0UI3o3d7aLr8dhkF4Yhtu+hnVAwa5qGHzIScZBJp7A6jf1NpsMmmU5TnJ7FkSUqtRqO44xkHIL++y+byZBPJulUyjRqFYIgGPWwTkTgORiGge/7RNVk4vs+rqJgDmGDs0AgODqHCu/FlXOs3ViLjfg200nmFhfYXt+Mr/i+uEyz0Rx78b25vsH8zLkTeSdd10UOg9jYCxzHQdH1Ey/PS5KErBq4rjvgkT0YRVFIGSrtZn3o5759DPnSBMliiVqnQ7PVEtXvEaJpGqVCAdX1qWxuYPV6ox7SsfEcp39D6/vIUa12WxbJQrz2NQkEceDQK3luMse55b74buyM7sI5SMxMktlzc2yvb9GutUc9nIGjJ3
UuPH6eZr3J+rXNsRQYjuPSrbc5t7RyoufbVg9Djabn8iTYjoN+yk2CmmGOrNqbTqVwrNZIhP/tGKZJaWaWQNPYqVW
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict at timestep 2 of the selected scene with the velocity-output map model: 500\n",
"# sampled futures plus one prediction with z_mode/gmm_mode enabled, drawn over the map ROI.\n",
"timestep = np.array([2])\n",
"os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories\n",
"with torch.no_grad():\n",
"    predictions = eval_stg_vel_map.predict(scene,\n",
"                                           timestep,\n",
"                                           ph,\n",
"                                           num_samples=500)\n",
"\n",
"    predictions_mm = eval_stg_vel_map.predict(scene,\n",
"                                              timestep,\n",
"                                              ph,\n",
"                                              num_samples=1,\n",
"                                              z_mode=True,\n",
"                                              gmm_mode=True)\n",
"\n",
"    # Plot the predictions for the selected scene on top of the map patch\n",
"    my_patch = (x_min, y_min, x_max, y_max)\n",
"    fig, ax = nusc_map.render_map_patch(my_patch, layers, figsize=(10, 10), alpha=0.1, render_egoposes_range=False)\n",
"\n",
"    plot_vehicle_nice(ax,\n",
"                      predictions,\n",
"                      scene.dt,\n",
"                      max_hl=10,\n",
"                      ph=ph,\n",
"                      map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    plot_vehicle_mm(ax,\n",
"                    predictions_mm,\n",
"                    scene.dt,\n",
"                    max_hl=10,\n",
"                    ph=ph,\n",
"                    map=None, x_min=x_min, y_min=y_min)\n",
"\n",
"    # Zoom in on the vehicles of interest. No in-figure legend is drawn:\n",
"    # the shared legend is saved separately by the last cell of the notebook.\n",
"    ax.set_ylim((1385, 1435))\n",
"    ax.set_xlim((850, 900))\n",
"    ax.axis('off')\n",
"    fig.show()\n",
"    fig.savefig('plots/qual_nuScenes_map_vel.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAHUCAYAAABYsLELAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3gU1frA8e+mFxIIJASkGyQgIL1LV0FKaMJVpIOggEYFLqCIWDBwEZRelS5FCEogRWqoASKkURJuIiWhJ6QTUn9/5LdzN2RnU0gIJO/neXyecefMmbOz7Jv3lJnVZGVlZSGEECIXo5JugBBCPK8kQAohhAoJkEIIoUICpBBCqJAAKYQQKiRACiGECgmQQgihQgKkEEKokAAphBAqJEAKIYQKCZBCCKFCAqQQQqiQACmEECokQAohhAoJkEIIoUICpBBCqJAAKYQQKiRACiGECgmQQgihQgKkEEKokAAphBAqTIr7BBkZGcTHx5OQkMCjR4/IzMws7lMKIYQqIyMjLC0tsbGxwdbWFmNjY9WymuL82dfU1FSuX7+OlZUVNjY2WFtbY2RkhEajKa5TCiGEqqysLDIzM0lKSiIhIYHk5GRq1aqFmZmZ3vLFFiAzMjKIiIjA3t4eOzu74jiFEEI8lYcPH/LgwQNefvllvZlksY1BxsfHY2VlJcFRCPHcsrOzw8rKivj4eL37iy1AJiQkYGNjU1zVCyFEkbCxsSEhIUHvvmILkI8ePcLa2rq4qhdCiCJhbW3No0eP9O4rtgCZmZmJkZGsIhJCPN+MjIxUV9cUawST2WohxPPOUJySFE8IIVRIgBRCCBUSIIUQQoUESCGEUCEBUrxwhg8fjrOzM926dSuxNixduhRnZ2ecnZ2JjIzMsa9bt244OzszfPjwIjvfmTNnlPO5u7sXqo4ZM2YodYj8KfaHVQhR1lhbW2NjY4OVlVWR1WlsbKzceGFqalpk9QrDJEAKUcQ8PDyKvM6WLVvi7+9f5PUKw6SLLYQQKiRACiGECgmQQgihQsYgSzntjOW///1v/vWvf/Hzzz+zf/9+4uPjmT9/Pn369MlR/saNG2zZsgU/Pz9u377No0ePsLGx4eWXX6ZTp04MHTrU4FOa4uPj2b59O8eOHSM8PJyEhASMjY1xcHCgadOmvPvuu7Rs2dJgmw8ePMi2bdu4ePEiSUlJ2Nvb06JFC0aNGkWjRo2e/qLkw5UrV1i3bh1nz54lJiaG8uXLU79+fYYMGUKPHj0MHtutWzeioqJo3bo1mzdvJjU1lQ4dOhAfH4+TkxOenp4Gj/f29sbV1RWAKVOmMH78eM6cOcOIESMAcHNzY+DAgbmOO3fuHBs2bCAgIIC4uDjs7Ox47bXXGDZsGO3atcvX+7579y6//fYbJ06c4MaNGzx69IgKFSpQp04d3njjDf71r39hYWGR6zh3d3dmzpwJwOnTp4mMjGThwoUEBAQAEBgYmK/zP28kQJYhU6dO5ciRI5iYmGBkZERqamqO/Vu3buWHH34gPT0dyL5H1cLCgpiYGGJiYvD392f9+vWsWLGC5s2b56r/0qVLfPDBBzx48EB5zdzcnIyMDG7evMnNmzfZt28fkydPZvLkyXrb6ObmxoYNG5T/NzEx4f79+3h4eODt7c38+fOL4EoY5uPjw+eff57jOiQmJnLixAlOnDjBsGHDqFChQr7rMzMzo2fPnuzcuZPw8HDCwsKoV6+eavn9+/cD2Q9RcHFxydc5NmzYwLx589A+/9rY2Ji4uDgOHjzIoUOHmDZtWp51eHp68uWXX5KcnKy8Zmpqyv3797l//z5nz55lw4YNrFmzhldeeUW1nsjISMaMGUNCQgIWFha5/p29SKSLXUacP3+eU6dOMW/ePAIDAwkODqZv377Kfh8fH7799lvS09OpXbs2q1evJigoiICAAM6dO8c333yDpaUlDx8+ZPz48dy5cydH/enp6b
i6uvLgwQNMTEyYMWMGp0+fJigoiODgYHbs2EGTJk3Iyspi6dKlnD59OlcbDxw4oATH2rVrs2XLFkJCQggJCcHT05Pu3bsza9Ysbt++XWzX6e7du0yfPp309HSsrKyYP38+gYGBBAYGcuLECT7++GN+++03Dhw4UKB6dQOdt7e3arnExER8fX0BaNeuHVWqVMmz7osXLyrB0cHBgZUrVxIUFERQUBCHDh1i6NChLFy4kJCQENU6Tp8+zdSpU0lOTqZhw4asX7+eCxcuEBwcjK+vL//+978pV64ct27dYvTo0cTFxanWtXr1aqpXr46Hh4dy7V5UEiDLiEOHDuHq6sqAAQMwMcnuOGjX02VkZDB37lwAbG1t2bx5M126dFF+p8PW1pZ3332XefPmAdkPQ162bFmO+k+fPs2NGzcAGDFiBKNHj6ZixYpAdgbWtGlTVq5cqZxz165dudq4atUqpV1r1qyhVatWypNWnJyc+Pnnn2nRogU3b94sugvzhA0bNijPBpw9ezb9+/fH3NwcAAcHByZPnsxnn31GaGhogept2bIl1apVAwwHyIMHD/L48WMA+vXrl6+6V69erWSOixYtolu3bspnXL16dWbPns0777zD1atX9R6fmZnJt99+S0ZGBq+88gpbtmyhffv2WFlZodFoqFKlCmPHjmXlypUA3L9/X/ms9Dl58iQrV65UsmS133t5ETwXXezkuBSS41NKuhnPlJWtBVblc4/lFBcTExPeffddvftOnDjB3bt3AXjvvfeoXLmy3nI9e/bk5ZdfJiIiAk9PT+bMmaN8Edu3b8/x48eJjo7GwcFB7/GVKlWibt26XL58mbCwsBz77ty5o2Q4nTt3platWrmO12g0TJw4kePHj+fvTRfCoUOHAKhQoYJq93bkyJGsWrWKpKSkfNer0Wjo06cPq1evNtjN1o5PWltb89Zbb+VZb2pqKseOHQOgQYMGtG7dWm+5SZMmsXPnTvT9BJWfnx8REREATJw4UXWBe+vWrWndujVnz57Fw8OD6dOn6y3XvXt3qlatmmfbXwSSQZYR9erVU33C+/nz55Xtjh07Gqynbdu2ACQlJSlfKsge86pcuTINGjTA3t5e9XjtBM+TweXy5cvKdosWLVSPb9asWZHeoaIrMTGR69evK+dR+zlQc3Nzg21Uoxtwvby8cu2PjY3l1KlTAPTo0QNLS8s867x27ZqS8Rpqk6OjIy+//LLefWfOnFG2GzdubPB8bdq0AbKzyFu3bukt07RpU4N1vEieiwzSqvyzzabKokqVKqnu0wYFgDp16hisp2bNmsr2jRs3cmVB4eHh7Nmzh6CgIKKjo4mJiSEtLU3ZrzsBoCsqKkrZfumll1TPr9FoqF69eq4MtCjofuENtQHQm+HmpW7durz66qtcunQpx0y11l9//aVcqwEDBuSrzvxeN8j+7MLDw3O9rh0aAejfv7/BB8jqfpaRkZF6z2no39qL5rkIkKL4GZp11f3BonLlyhmsRzcLTUxMzLFvyZIlrFixQm83Li+6deWVIebVxsIqSBsK+3tLLi4uXLp0iYiICEJDQ3M8OEI7e12tWjVatWqVr/qK4rrp/qLfk5+pIWpDDKXpl0wlQJYRat3FJ+UV3HR/u0P3N4d27drF8uXLgexg/NFHH9GxY0ccHBywsbFRspLhw4dz9uxZg+fI66c61H4/pCgVVxt69+7NggULyMjIwMvLSwmQDx484Ny5c0D25Exhfq6ksG3WPS4gICBfXXtDStNvUZWedyIKrXz58sp2XhmE7n5bW1tle+3atUB2IN64cSOjRo3CyckJW1vbHF9A3S6aLt0vpVo3XEvtN4yflm4GVlxtqFy5sjKOqzub7eXlRUZGBpDdzc2vorhuup9jbGxsvs9dFkiAFNSuXVvZ1p140eeff/5RtrXjlUlJSVy7dg3IniioX7++3mMzMjL0joEBOdb7GVrnmJ6eXmzLfBwdHfPVBsh5HQpKO1nzzz//KG
Op2u518+bNCzS+md/rpj2fPk5OTsr2lStX8n3uskACpMhx65+hJTRZWVnKAm97e3vli6ybVRoa69Te4qiPblC9cOG
"text/plain": [
"<Figure size 72x72 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Stand-alone legend figure for the qualitative plots (saved as its own PDF).\n",
"os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories\n",
"my_patch = (0, 0, 1, 1)\n",
"fig, ax = nusc_map.render_map_patch(my_patch, layers, figsize=(1, 1), alpha=0.1, render_egoposes_range=False)\n",
"# Empty proxy artists: they draw nothing and exist only to create legend entries\n",
"ax.plot([], [], 'ko',\n",
"        zorder=620,\n",
"        markersize=4,\n",
"        linewidth=2, alpha=0.7, label='Ours (ML)')\n",
"\n",
"ax.plot([],\n",
"        [],\n",
"        'w--o', label='Ground Truth',\n",
"        linewidth=3,\n",
"        path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()])\n",
"leg = ax.legend(loc='upper left', fontsize=30, frameon=True)\n",
"# `Legend.legendHandles` was renamed `legend_handles` in matplotlib 3.7 and the old name\n",
"# was later removed, so look up whichever attribute this matplotlib version provides.\n",
"legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles\n",
"for lh in legend_handles:\n",
"    lh.set_alpha(.5)\n",
"ax.axis('off')\n",
"# Pass a real boolean: the string 'off' is deprecated and is truthy in newer matplotlib\n",
"ax.grid(False)\n",
"fig.savefig('plots/qual_nuScenes_legend.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6 (GenTrajectron)",
"language": "python",
"name": "gentraj"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}