Trajectron-plus-plus/experiments/pedestrians/Result Analysis.ipynb

1608 lines
320 KiB
Text
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import glob\n",
"from collections import OrderedDict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from IPython.display import display"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def pretty_dataset_name(dataset_name):\n",
"    \"\"\"Return the display label used in plots and tables for a dataset key.\n",
"\n",
"    Keys without a mapping (e.g. 'Average') are returned unchanged.\n",
"    \"\"\"\n",
"    pretty_names = {\n",
"        'eth': 'ETH - Univ',\n",
"        'hotel': 'ETH - Hotel',\n",
"        'univ': 'UCY - Univ',\n",
"        'zara1': 'UCY - Zara 1',\n",
"        'zara2': 'UCY - Zara 2',\n",
"    }\n",
"    return pretty_names.get(dataset_name, dataset_name)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Dataset keys in plotting order. 'Average' is not a real dataset: the\n",
"# analysis cells treat it as the pooled result over all the others.\n",
"dataset_names = ['eth', 'hotel', 'univ', 'zara1', 'zara2']\n",
"dataset_names.append('Average')\n",
"\n",
"# Method label attached to this notebook's own results.\n",
"alg_name = 'Ours'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Displacement Error Analysis"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"# Published prior-work errors (metres), keyed by the output of pretty_dataset_name().\n",
"# Each inner OrderedDict keeps the baselines in a fixed order, which also fixes\n",
"# the order in which their reference lines are drawn in the plots.\n",
"_baseline_methods = ('Linear', 'Vanilla LSTM', 'Social LSTM', 'Social Attention')\n",
"\n",
"prior_work_fse_results = {\n",
"    dataset: OrderedDict(zip(_baseline_methods, errors))\n",
"    for dataset, errors in [\n",
"        ('ETH - Univ', (2.94, 2.41, 2.35, 3.74)),\n",
"        ('ETH - Hotel', (0.72, 1.91, 1.76, 2.64)),\n",
"        ('UCY - Univ', (1.59, 1.31, 1.40, 0.52)),\n",
"        ('UCY - Zara 1', (1.21, 0.88, 1.00, 2.13)),\n",
"        ('UCY - Zara 2', (1.48, 1.11, 1.17, 3.92)),\n",
"        ('Average', (1.59, 1.52, 1.54, 2.59)),\n",
"    ]\n",
"}\n",
"\n",
"# These are for a prediction horizon of 12 timesteps.\n",
"prior_work_ade_results = {\n",
"    dataset: OrderedDict(zip(_baseline_methods, errors))\n",
"    for dataset, errors in [\n",
"        ('ETH - Univ', (1.33, 1.09, 1.09, 0.39)),\n",
"        ('ETH - Hotel', (0.39, 0.86, 0.79, 0.29)),\n",
"        ('UCY - Univ', (0.82, 0.61, 0.67, 0.20)),\n",
"        ('UCY - Zara 1', (0.62, 0.41, 0.47, 0.30)),\n",
"        ('UCY - Zara 2', (0.77, 0.52, 0.56, 0.33)),\n",
"        ('Average', (0.79, 0.70, 0.72, 0.30)),\n",
"    ]\n",
"}\n",
"\n",
"# Line style per baseline, in the same order as _baseline_methods.\n",
"linestyles = ['--', '-.', '-', ':']"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Marker drawn at each method's mean: its shape, and its edge length\n",
"# (scatter() receives the square of this as the marker area).\n",
"mean_markers = 'X'\n",
"marker_size = 7\n",
"\n",
"# One colour per prior-work reference line.\n",
"line_colors = ['#1f78b4', '#33a02c', '#fb9a99', '#e31a1c']\n",
"\n",
"# Box colours for the three compared methods, plus the same colours as 0-255 RGB triples.\n",
"area_colors = ['#80CBE5', '#ABCB51', '#F05F78']\n",
"area_rgbs = [[int(hex_code[pos:pos + 2], 16) for pos in (1, 3, 5)] for hex_code in area_colors]"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_attention_radius_3_fde_most_likely.csv\n",
"results/hotel_attention_radius_3_fde_most_likely.csv\n",
"results/univ_attention_radius_3_fde_most_likely.csv\n",
"results/zara1_attention_radius_3_fde_most_likely.csv\n",
"results/zara2_attention_radius_3_fde_most_likely.csv\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>error_value</th>\n",
" <th>error_type</th>\n",
" <th>type</th>\n",
" <th>dataset</th>\n",
" <th>method</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.242668</td>\n",
" <td>fde</td>\n",
" <td>ml</td>\n",
" <td>eth</td>\n",
" <td>Ours</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.158331</td>\n",
" <td>fde</td>\n",
" <td>ml</td>\n",
" <td>eth</td>\n",
" <td>Ours</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.095482</td>\n",
" <td>fde</td>\n",
" <td>ml</td>\n",
" <td>eth</td>\n",
" <td>Ours</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.069288</td>\n",
" <td>fde</td>\n",
" <td>ml</td>\n",
" <td>eth</td>\n",
" <td>Ours</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.734359</td>\n",
" <td>fde</td>\n",
" <td>ml</td>\n",
" <td>eth</td>\n",
" <td>Ours</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" error_value error_type type dataset method\n",
"0 0.242668 fde ml eth Ours\n",
"1 0.158331 fde ml eth Ours\n",
"2 0.095482 fde ml eth Ours\n",
"3 1.069288 fde ml eth Ours\n",
"4 1.734359 fde ml eth Ours"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load Ours: most-likely FDE results for each dataset.\n",
"frames = []\n",
"for dataset in dataset_names:\n",
"    # sorted() keeps the row order independent of the filesystem's listing order.\n",
"    for f in sorted(glob.glob(f'results/{dataset}*attention_radius_3*fde_most_likely.csv')):\n",
"        print(f)\n",
"        # 'Unnamed: 0' is presumably the index column saved with the CSV; drop it per file.\n",
"        dataset_df = pd.read_csv(f).drop(columns='Unnamed: 0', errors='ignore')\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        frames.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append is deprecated (removed in pandas 2.0) and\n",
"# appending inside the loop re-copies the accumulated frame for every file.\n",
"perf_df = pd.concat(frames, ignore_index=True, sort=False) if frames else pd.DataFrame()\n",
"del frames\n",
"perf_df = perf_df.rename(columns={'metric': 'error_type', 'value': 'error_value'})\n",
"perf_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"# Load Trajectron and Social GAN errors from the older result CSVs.\n",
"errors_df = pd.concat([pd.read_csv(f) for f in sorted(glob.glob('csv/old/curr_*_errors.csv'))], ignore_index=True)\n",
"del errors_df['data_precondition']\n",
"\n",
"# Drop method 'our_full' and the ADE rows (stored as 'mse'). .copy() makes the\n",
"# relabelling below operate on an independent frame, not a slice.\n",
"keep_rows = (errors_df['method'] != 'our_full') & (errors_df['error_type'] != 'mse')\n",
"errors_df = errors_df[keep_rows].copy()\n",
"\n",
"# Harmonise labels with our results: FDE is stored as 'fse' in these files.\n",
"errors_df['error_type'] = errors_df['error_type'].replace({'fse': 'fde'})\n",
"errors_df['method'] = errors_df['method'].replace({'our_most_likely': 'Trajectron'})"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/pandas/core/frame.py:7123: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
"of pandas will change to not sort by default.\n",
"\n",
"To accept the future behavior, pass 'sort=False'.\n",
"\n",
"To retain the current behavior and silence the warning, pass 'sort=True'.\n",
"\n",
" sort=sort,\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>dataset</th>\n",
" <th>method</th>\n",
" <th>run</th>\n",
" <th>node</th>\n",
" <th>sample</th>\n",
" <th>error_type</th>\n",
" <th>error_value</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>2186000</td>\n",
" <td>hotel</td>\n",
" <td>sgan</td>\n",
" <td>0</td>\n",
" <td>Pedestrian/0</td>\n",
" <td>0</td>\n",
" <td>fde</td>\n",
" <td>4.045972</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2186001</td>\n",
" <td>hotel</td>\n",
" <td>sgan</td>\n",
" <td>0</td>\n",
" <td>Pedestrian/0</td>\n",
" <td>1</td>\n",
" <td>fde</td>\n",
" <td>3.717624</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2186002</td>\n",
" <td>hotel</td>\n",
" <td>sgan</td>\n",
" <td>0</td>\n",
" <td>Pedestrian/0</td>\n",
" <td>2</td>\n",
" <td>fde</td>\n",
" <td>5.378286</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2186003</td>\n",
" <td>hotel</td>\n",
" <td>sgan</td>\n",
" <td>0</td>\n",
" <td>Pedestrian/0</td>\n",
" <td>3</td>\n",
" <td>fde</td>\n",
" <td>4.215567</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2186004</td>\n",
" <td>hotel</td>\n",
" <td>sgan</td>\n",
" <td>0</td>\n",
" <td>Pedestrian/0</td>\n",
" <td>4</td>\n",
" <td>fde</td>\n",
" <td>4.663851</td>\n",
" </tr>\n",
" <tr>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77099995</td>\n",
" <td>zara2</td>\n",
" <td>sgan</td>\n",
" <td>99</td>\n",
" <td>Pedestrian/35</td>\n",
" <td>1995</td>\n",
" <td>fde</td>\n",
" <td>0.620136</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77099996</td>\n",
" <td>zara2</td>\n",
" <td>sgan</td>\n",
" <td>99</td>\n",
" <td>Pedestrian/35</td>\n",
" <td>1996</td>\n",
" <td>fde</td>\n",
" <td>0.681608</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77099997</td>\n",
" <td>zara2</td>\n",
" <td>sgan</td>\n",
" <td>99</td>\n",
" <td>Pedestrian/35</td>\n",
" <td>1997</td>\n",
" <td>fde</td>\n",
" <td>0.860765</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77099998</td>\n",
" <td>zara2</td>\n",
" <td>sgan</td>\n",
" <td>99</td>\n",
" <td>Pedestrian/35</td>\n",
" <td>1998</td>\n",
" <td>fde</td>\n",
" <td>0.545317</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77099999</td>\n",
" <td>zara2</td>\n",
" <td>sgan</td>\n",
" <td>99</td>\n",
" <td>Pedestrian/35</td>\n",
" <td>1999</td>\n",
" <td>fde</td>\n",
" <td>1.027843</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>25700000 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" dataset method run node sample error_type error_value\n",
"2186000 hotel sgan 0 Pedestrian/0 0 fde 4.045972\n",
"2186001 hotel sgan 0 Pedestrian/0 1 fde 3.717624\n",
"2186002 hotel sgan 0 Pedestrian/0 2 fde 5.378286\n",
"2186003 hotel sgan 0 Pedestrian/0 3 fde 4.215567\n",
"2186004 hotel sgan 0 Pedestrian/0 4 fde 4.663851\n",
"... ... ... ... ... ... ... ...\n",
"77099995 zara2 sgan 99 Pedestrian/35 1995 fde 0.620136\n",
"77099996 zara2 sgan 99 Pedestrian/35 1996 fde 0.681608\n",
"77099997 zara2 sgan 99 Pedestrian/35 1997 fde 0.860765\n",
"77099998 zara2 sgan 99 Pedestrian/35 1998 fde 0.545317\n",
"77099999 zara2 sgan 99 Pedestrian/35 1999 fde 1.027843\n",
"\n",
"[25700000 rows x 7 columns]"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# pd.concat replaces the deprecated DataFrame.append (removed in pandas 2.0);\n",
"# sort=False avoids the column-sorting FutureWarning and ignore_index=True gives\n",
"# the combined frame a unique index.\n",
"perf_df = pd.concat([perf_df, errors_df], ignore_index=True, sort=False)\n",
"\n",
"# Show a preview instead of ending the cell with errors_df itself: a cell's result\n",
"# is kept in IPython's output history (Out[N]), which would keep this very large\n",
"# frame alive after the later `del errors_df`.\n",
"display(errors_df.head(), errors_df.shape)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB/IAAASvCAYAAAAaDLIdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd5hkVZn48e/LAIIkiaKgYMKACJINrGBGVFQQEJGgq+KurjkLK+KKcdc1/QyrDggqiIKKGZCoRBFJogIDwwgISA4CM+/vj3PLvnW7qruqunq6uuf7eZ5+Zu6tG05XV91z73nPeU9kJpIkSZIkSZIkSZIkaTQsN9MFkCRJkiRJkiRJkiRJYwzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkSxo5EbEgIrL6WTDT5ZEkaVkWEefV6uU7J9n2xbVtMyLeNaxjS5I0qiJi/0b9t79lkiTNZaNYz4ximTQ3NT5np8x0eTS3LT/TBZAkSZIk9Sci1gC2ATYCHgKsDNwF3AosAP6SmQtnrICS5qyIWBnYEngcsCawCnAPcDtwDXAFcGVmLpmxQmrOioi1KfXfIyj134qU+u8W4CpK/ffXmSuhpLnIuk+SNFMM5EsaWDVafqPaqp0y85SZKY0kSUVEfAg4tLbqvMzcZkjH3g04trbqRmCDzLx/GMfX7BQRLwZ+XFt1SWY+eRrOsyKwD/AGYFsgJtn+JuBc4CTgZ5l5aeP184Cthl3Ohrb3IiKeDFzUYbslwMaDdj6IiGOB3Tq89MXMfPMgx5Q0JiICeClwIPBcJm9PuiMizgdOBX4GnGtwY+6pRjp+s7bq1MzccRrO82Bgf+D1wBY9bH89cA6l/vtpZv6l8foC2tsypkPbexEROwK/7rDdvcD6mXnbICepRgE+q8NLh2Tmhwc5pqTCuk+DiIizgO0aq3fMzFNnojySZj9T60uSJGmuOQLI2vLWEfHEIR17v8bytw3ia2mIiK2A3wFfpzQMTRjEr6wD7Ax8GrgkIh4zfSWcsuUonRT6FhFrAS8ZbnEktUTERsCJwPHAC+ltUMhqwI7AfwJn4XdUA4qIf6F0APsiPQTxK+tTgm//C/
y5ymIzqlYC9hhkx4jYGPiXYRZGUmHdp0FExBMYH8SH8e0I/Rxz40Ya9/kDHOOU+jEGLctc4RQMmm0M5EuSJGlOycxrGD/iad+pHjci1qU04tQdPtXjSpOJiB2A04BNO7x8F3ApcDZwMXDdRIcafumG6jUD7rcXJbWypCGLiEcDZwLP7vDyfcDllJHPFwILKdk1Oh5qWgqoOa3KeHMi8OgOL99OqffOBi4B/jbRoYZfuqEa9D51X0b/d5NmHes+TUG3gP3uVXYZSeqbqfUljZzM3HimyyBJmvUOp73hZZ+I+OAUUxvuDaxQW/5DZl4whePNCpm59UyXYVkWEesBPwLqDT//AL4EzAcuysxs7LMOJfX+i4FXAA/tcvh/A1bvoRgr0T51AJQ09rf3sO+dk7y+hLEO5k+MiG0y89wejltXbzCrH0/SFETECpTv/ga11QkcBXwFOCszH2jssyplyo6dgd2BUc4EMjSZOZ9yTdaQVFlkvkf7vdddlFH2R2bmZR32WZ8yEvKlwMuAtboc/tXAyj0U46HAkY11z+thP4BbJnm9Xl89IyIelZlX9XjslnoHAOs/aQis+3pn3dcuIpaje8fk1SjPT99aeiXSdMpMO+poqTGQL0mSpLno+5QUrKtWyxtSAvsnTuGYzdFSjsbX0nAo8JDa8g3A8zPzD912yMybgJ8CP42ItwC7Au+gfcoJMvOcXgpQNU42nVadZ6ouAB4LtNIe7wv0HMiv0lduW1t1MmUOU0lTdyDwpNryvcBumfnTbjtk5p2UuYFPBd4XEc8C3g4sns6Cak76JKUjWctVwPMy84puO2Tm9cAPgR9GxIGUlPVv77Ddmb0UoEpd39x3KveSdWdSOh2sSBm1uy9wSK87R8QzaQ8WWv9Jw2Hdp0E9j/YOIGcAz6wt74+BfEkDsKemJEmS5pzMvIsyiqtuKvPSbQpsWVv1AONHaElDFRErUtLG1/3rREH8psxcnJk/yMxnThT8mEH3AsfUll9VjYTqVf17fRelE4+k4WjWm4dMFMjoJDNPzcyXZWYzq4fUVUSsSRlVX7d3P/VYZt6fmUdl5taZeetwSzgUfwd+UlvuN71+/ft5A/CLKZdIElj3aXDNz86HgfNryztFxCOXXnEkzRUG8iVJkjRXNUfMv7zLyOJeNB/Kf56ZE83FKg3DdrSnvl9Ee6P/XFH/rq4N7NLLTlX6yn1qq37A5Kn8JfUgItaipAluWQJ8bYaKo2XPjrRnEf1DZp41Q2WZTvX679HVKPtJRcTKlGwDLd+mdDKVNAXWfRpURKxBmdKlZRHwa9o7/7eyr0hSX0ytL2lOq0Z0PR14MiUt7e3AQuDUzJxszrpez/FIYGvK/HlrArcB1wNnVqn9pnLsFYDHU9J6rU+ZU+kuSu/9PwK/a87NNQwRsQHl4eXhlAb1W4HjMvOvwz6XJE2j0yhpWB9VLa9CmbNwfj8HiYh5lLlU63o+RkSsRqmHHk+Zq3VlynX1JuD8zPxLP+Xp8ZxBySCwJbAucA+lbjojMxcO+3wzofq7bAJsCjyMEvC+h1JH/gk4LzPvm7kSDsWGjeXLMzM7bjmLZeaZEfEXSop9KA1cx/ew63Nof48Op3wWJE3dBo3lmzLz5qVdiIjYgvIstB4lzfrfKM9zZ2TmPUM8z+OBpwDrUOrqe4EbgcuAC6e7PomI9Sj3Co+hPLcuT6nPrgfOnupz5SzUrP/+OCOlmH4/pdwPrlMt70tJxTyZl9He0e9wYKfhFk1aJln3WfcNag/Kc37LdzJzSUR8B/g0MK9avy/w0aVduOlSPZNvTXmOWw94EOUzdBWlXfwfQz7fesAOlDaeFSh16KXAWZk566ayqL6DT6W8d6tQfp+/Ur7rt03D+Tan/L3WA/5B+a79JjMXDPtcGi4D+ZJGTkQsADaqFq/OzI0n2PbDwH/WVu2UmadExIOA9wBvo9yMNi2OiGOB92bm1QOUcUXgTcAbaJ87qy4j4nzg0Mz8UR/HXo
dyA/hiys3JRKNH74qI7wIf7ycQFBH1IMCpmbljtX4Xyhy6OzI+a8siemtUl6SRkJkZEUfQXk/sS5+BfMp8ow+vLf8
"text/plain": [
"<Figure size 2400x1200 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# FDE box plots: one panel per dataset plus a pooled 'Average' panel, with\n",
"# prior-work results drawn as horizontal reference lines.\n",
"with sns.color_palette(\"muted\"):\n",
"    fig_fse, ax_fses = plt.subplots(nrows=1, ncols=6, figsize=(8, 4), dpi=300, sharey=True)\n",
"    for idx, ax_fse in enumerate(ax_fses):\n",
"        dataset_name = dataset_names[idx]\n",
"        if dataset_name != 'Average':\n",
"            # .copy() so the label assignment below writes to an independent frame\n",
"            # rather than a slice of perf_df (avoids SettingWithCopyWarning).\n",
"            specific_df = perf_df[(perf_df['dataset'] == dataset_name) & (perf_df['error_type'] == 'fde')].copy()\n",
"            specific_df['dataset'] = pretty_dataset_name(dataset_name)\n",
"        else:\n",
"            # The 'Average' panel pools the FDE rows of every dataset.\n",
"            specific_df = perf_df[(perf_df['error_type'] == 'fde')].copy()\n",
"            specific_df['dataset'] = 'Average'\n",
"\n",
"        sns.boxplot(x='dataset', y='error_value', hue='method',\n",
"                    data=specific_df, ax=ax_fse, showfliers=False,\n",
"                    palette=area_colors, hue_order=['sgan', 'Trajectron', alg_name], width=2.)\n",
"\n",
"        ax_fse.get_legend().remove()\n",
"        ax_fse.set_xlabel('')\n",
"        ax_fse.set_ylabel('' if idx > 0 else 'Final Displacement Error (m)')\n",
"\n",
"        # Mark each method's mean. NOTE(review): the x offsets presumably match the\n",
"        # dodged box centres for width=2. split over three hues (about +/-2/3).\n",
"        ax_fse.scatter([-0.665, 0, 0.665],\n",
"                       [np.mean(specific_df[specific_df['method'] == 'sgan']['error_value']),\n",
"                        np.mean(specific_df[specific_df['method'] == 'Trajectron']['error_value']),\n",
"                        np.mean(specific_df[specific_df['method'] == alg_name]['error_value'])],\n",
"                       s=marker_size*marker_size, c=np.asarray(area_rgbs)/255.0, marker=mean_markers,\n",
"                       edgecolors='#545454', zorder=10)\n",
"\n",
"        baselines = prior_work_fse_results[pretty_dataset_name(dataset_name)]\n",
"        for baseline_idx, (baseline, fse_val) in enumerate(baselines.items()):\n",
"            ax_fse.axhline(y=fse_val, label=baseline, color=line_colors[baseline_idx], linestyle=linestyles[baseline_idx])\n",
"\n",
"        if idx == 0:\n",
"            # One figure-level legend built from the first panel. Handles are looked up\n",
"            # by label rather than list position, because the order in which lines and\n",
"            # patches are returned may differ between matplotlib versions.\n",
"            handles, labels = ax_fse.get_legend_handles_labels()\n",
"            handle_by_label = dict(zip(labels, handles))\n",
"            baseline_names = list(baselines)\n",
"            # (artist label, legend text): baselines interleaved with the box-plot methods.\n",
"            legend_entries = [(baseline_names[0], baseline_names[0]), ('sgan', 'Social GAN'),\n",
"                              (baseline_names[1], baseline_names[1]), ('Trajectron', 'Trajectron'),\n",
"                              (baseline_names[2], baseline_names[2]), (alg_name, alg_name),\n",
"                              (baseline_names[3], baseline_names[3])]\n",
"            ax_fse.legend([handle_by_label[key] for key, _ in legend_entries],\n",
"                          [text for _, text in legend_entries],\n",
"                          loc='lower center', bbox_to_anchor=(0.5, 0.9),\n",
"                          ncol=4, borderaxespad=0, frameon=False,\n",
"                          bbox_transform=fig_fse.transFigure)\n",
"\n",
"fig_fse.savefig('plots/fde_boxplots.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"# Drop the FDE frames; the ADE section below rebuilds both names.\n",
"del perf_df, errors_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Average Displacement Error"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_attention_radius_3_ade_most_likely.csv\n",
"results/hotel_attention_radius_3_ade_most_likely.csv\n",
"results/univ_attention_radius_3_ade_most_likely.csv\n",
"results/zara1_attention_radius_3_ade_most_likely.csv\n",
"results/zara2_attention_radius_3_ade_most_likely.csv\n"
]
}
],
"source": [
"# Load Ours: most-likely ADE results for each dataset.\n",
"frames = []\n",
"for dataset in dataset_names:\n",
"    # sorted() keeps the row order independent of the filesystem's listing order.\n",
"    for f in sorted(glob.glob(f'results/{dataset}*attention_radius_3*ade_most_likely.csv')):\n",
"        print(f)\n",
"        # 'Unnamed: 0' is presumably the index column saved with the CSV; drop it per file.\n",
"        dataset_df = pd.read_csv(f).drop(columns='Unnamed: 0', errors='ignore')\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        frames.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append is deprecated (removed in pandas 2.0) and\n",
"# appending inside the loop re-copies the accumulated frame for every file.\n",
"perf_df = pd.concat(frames, ignore_index=True, sort=False) if frames else pd.DataFrame()\n",
"del frames\n",
"perf_df = perf_df.rename(columns={'metric': 'error_type', 'value': 'error_value'})"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"# Load Trajectron and Social GAN errors from the older result CSVs\n",
"# (same 'csv/old/' location as the FDE loader).\n",
"errors_df = pd.concat([pd.read_csv(f) for f in sorted(glob.glob('csv/old/curr_*_errors.csv'))], ignore_index=True)\n",
"del errors_df['data_precondition']\n",
"\n",
"# Drop method 'our_full' and the FDE rows (stored as 'fse'). .copy() makes the\n",
"# relabelling below operate on an independent frame, not a slice.\n",
"keep_rows = (errors_df['method'] != 'our_full') & (errors_df['error_type'] != 'fse')\n",
"errors_df = errors_df[keep_rows].copy()\n",
"\n",
"# Harmonise labels with our results: ADE is stored as 'mse' in these files.\n",
"errors_df['error_type'] = errors_df['error_type'].replace({'mse': 'ade'})\n",
"errors_df['method'] = errors_df['method'].replace({'our_most_likely': 'Trajectron'})"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/pandas/core/frame.py:7123: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
"of pandas will change to not sort by default.\n",
"\n",
"To accept the future behavior, pass 'sort=False'.\n",
"\n",
"To retain the current behavior and silence the warning, pass 'sort=True'.\n",
"\n",
" sort=sort,\n"
]
}
],
"source": [
"# pd.concat replaces the deprecated DataFrame.append (removed in pandas 2.0);\n",
"# sort=False avoids the column-sorting FutureWarning and ignore_index=True gives\n",
"# the combined frame a unique index.\n",
"perf_df = pd.concat([perf_df, errors_df], ignore_index=True, sort=False)\n",
"del errors_df"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n",
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" import sys\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB/IAAASvCAYAAAAaDLIdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd5hkVZn48e/LAIIkiaKgYMKACJINrGBGVFQQEJGgq+KurjkLK+KKcdc1/QyrDggqiIKKGZCoRBFJogIDwwgISA4CM+/vj3PLvnW7qruqunq6uuf7eZ5+Zu6tG05XV91z73nPeU9kJpIkSZIkSZIkSZIkaTQsN9MFkCRJkiRJkiRJkiRJYwzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkSxo5EbEgIrL6WTDT5ZEkaVkWEefV6uU7J9n2xbVtMyLeNaxjS5I0qiJi/0b9t79lkiTNZaNYz4ximTQ3NT5np8x0eTS3LT/TBZAkSZIk9Sci1gC2ATYCHgKsDNwF3AosAP6SmQtnrICS5qyIWBnYEngcsCawCnAPcDtwDXAFcGVmLpmxQmrOioi1KfXfIyj134qU+u8W4CpK/ffXmSuhpLnIuk+SNFMM5EsaWDVafqPaqp0y85SZKY0kSUVEfAg4tLbqvMzcZkjH3g04trbqRmCDzLx/GMfX7BQRLwZ+XFt1SWY+eRrOsyKwD/AGYFsgJtn+JuBc4CTgZ5l5aeP184Cthl3Ohrb3IiKeDFzUYbslwMaDdj6IiGOB3Tq89MXMfPMgx5Q0JiICeClwIPBcJm9PuiMizgdOBX4GnGtwY+6pRjp+s7bq1MzccRrO82Bgf+D1wBY9bH89cA6l/vtpZv6l8foC2tsypkPbexEROwK/7rDdvcD6mXnbICepRgE+q8NLh2Tmhwc5pqTCuk+DiIizgO0aq3fMzFNnojySZj9T60uSJGmuOQLI2vLWEfHEIR17v8bytw3ia2mIiK2A3wFfpzQMTRjEr6wD7Ax8GrgkIh4zfSWcsuUonRT6FhFrAS8ZbnEktUTERsCJwPHAC+ltUMhqwI7AfwJn4XdUA4qIf6F0APsiPQTxK+tTgm//C/
y5ymIzqlYC9hhkx4jYGPiXYRZGUmHdp0FExBMYH8SH8e0I/Rxz40Ya9/kDHOOU+jEGLctc4RQMmm0M5EuSJGlOycxrGD/iad+pHjci1qU04tQdPtXjSpOJiB2A04BNO7x8F3ApcDZwMXDdRIcafumG6jUD7rcXJbWypCGLiEcDZwLP7vDyfcDllJHPFwILKdk1Oh5qWgqoOa3KeHMi8OgOL99OqffOBi4B/jbRoYZfuqEa9D51X0b/d5NmHes+TUG3gP3uVXYZSeqbqfUljZzM3HimyyBJmvUOp73hZZ+I+OAUUxvuDaxQW/5DZl4whePNCpm59UyXYVkWEesBPwLqDT//AL4EzAcuysxs7LMOJfX+i4FXAA/tcvh/A1bvoRgr0T51AJQ09rf3sO+dk7y+hLEO5k+MiG0y89wejltXbzCrH0/SFETECpTv/ga11QkcBXwFOCszH2jssyplyo6dgd2BUc4EMjSZOZ9yTdaQVFlkvkf7vdddlFH2R2bmZR32WZ8yEvKlwMuAtboc/tXAyj0U46HAkY11z+thP4BbJnm9Xl89IyIelZlX9XjslnoHAOs/aQis+3pn3dcuIpaje8fk1SjPT99aeiXSdMpMO+poqTGQL0mSpLno+5QUrKtWyxtSAvsnTuGYzdFSjsbX0nAo8JDa8g3A8zPzD912yMybgJ8CP42ItwC7Au+gfcoJMvOcXgpQNU42nVadZ6ouAB4LtNIe7wv0HMiv0lduW1t1MmUOU0lTdyDwpNryvcBumfnTbjtk5p2UuYFPBd4XEc8C3g4sns6Cak76JKUjWctVwPMy84puO2Tm9cAPgR9GxIGUlPVv77Ddmb0UoEpd39x3KveSdWdSOh2sSBm1uy9wSK87R8QzaQ8WWv9Jw2Hdp0E9j/YOIGcAz6wt74+BfEkDsKemJEmS5pzMvIsyiqtuKvPSbQpsWVv1AONHaElDFRErUtLG1/3rREH8psxcnJk/yMxnThT8mEH3AsfUll9VjYTqVf17fRelE4+k4WjWm4dMFMjoJDNPzcyXZWYzq4fUVUSsSRlVX7d3P/VYZt6fmUdl5taZeetwSzgUfwd+UlvuN71+/ft5A/CLKZdIElj3aXDNz86HgfNryztFxCOXXnEkzRUG8iVJkjRXNUfMv7zLyOJeNB/Kf56ZE83FKg3DdrSnvl9Ee6P/XFH/rq4N7NLLTlX6yn1qq37A5Kn8JfUgItaipAluWQJ8bYaKo2XPjrRnEf1DZp41Q2WZTvX679HVKPtJRcTKlGwDLd+mdDKVNAXWfRpURKxBmdKlZRHwa9o7/7eyr0hSX0ytL2lOq0Z0PR14MiUt7e3AQuDUzJxszrpez/FIYGvK/HlrArcB1wNnVqn9pnLsFYDHU9J6rU+ZU+kuSu/9PwK/a87NNQwRsQHl4eXhlAb1W4HjMvOvwz6XJE2j0yhpWB9VLa9CmbNwfj8HiYh5lLlU63o+RkSsRqmHHk+Zq3VlynX1JuD8zPxLP+Xp8ZxBySCwJbAucA+lbjojMxcO+3wzofq7bAJsCjyMEvC+h1JH/gk4LzPvm7kSDsWGjeXLMzM7bjmLZeaZEfEXSop9KA1cx/ew63Nof48Op3wWJE3dBo3lmzLz5qVdiIjYgvIstB4lzfrfKM9zZ2TmPUM8z+OBpwDrUOrqe4EbgcuAC6e7PomI9Sj3Co+hPLcuT6nPrgfOnupz5SzUrP/+OCOlmH4/pdwPrlMt70tJxTyZl9He0e9wYKfhFk1aJln3WfcNag/Kc37LdzJzSUR8B/g0MK9avy/w0aVduOlSPZNvTXmOWw94EOUzdBWlXfwfQz7fesAOlDaeFSh16KXAWZk566ayqL6DT6W8d6tQfp+/Ur7rt03D+Tan/L3WA/5B+a79JjMXDPtcGi4D+ZJGTkQsADaqFq/OzI0n2PbDwH/WVu2UmadExIOA9wBvo9yMNi2OiGOB92bm1QOUcUXgTcAbaJ87qy4j4nzg0Mz8UR/HXo
dyA/hiys3JRKNH74qI7wIf7ycQFBH1IMCpmbljtX4Xyhy6OzI+a8siemtUl6SRkJkZEUfQXk/sS5+BfMp8ow+vLf8
"text/plain": [
"<Figure size 2400x1200 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# ADE box plots: one panel per dataset plus a pooled 'Average' panel, with\n",
"# prior-work results drawn as horizontal reference lines.\n",
"with sns.color_palette(\"muted\"):\n",
"    fig_ade, ax_ades = plt.subplots(nrows=1, ncols=6, figsize=(8, 4), dpi=300, sharey=True)\n",
"    for idx, ax_ade in enumerate(ax_ades):\n",
"        dataset_name = dataset_names[idx]\n",
"        if dataset_name != 'Average':\n",
"            # .copy() so the label assignment below writes to an independent frame\n",
"            # rather than a slice of perf_df (avoids SettingWithCopyWarning).\n",
"            specific_df = perf_df[(perf_df['dataset'] == dataset_name) & (perf_df['error_type'] == 'ade')].copy()\n",
"            specific_df['dataset'] = pretty_dataset_name(dataset_name)\n",
"        else:\n",
"            # The 'Average' panel pools the ADE rows of every dataset.\n",
"            specific_df = perf_df[(perf_df['error_type'] == 'ade')].copy()\n",
"            specific_df['dataset'] = 'Average'\n",
"\n",
"        sns.boxplot(x='dataset', y='error_value', hue='method',\n",
"                    data=specific_df, ax=ax_ade, showfliers=False,\n",
"                    palette=area_colors, hue_order=['sgan', 'Trajectron', alg_name], width=2.)\n",
"\n",
"        ax_ade.get_legend().remove()\n",
"        ax_ade.set_xlabel('')\n",
"        ax_ade.set_ylabel('' if idx > 0 else 'Average Displacement Error (m)')\n",
"\n",
"        # Mark each method's mean. NOTE(review): the x offsets presumably match the\n",
"        # dodged box centres for width=2. split over three hues (about +/-2/3).\n",
"        ax_ade.scatter([-0.665, 0, 0.665],\n",
"                       [np.mean(specific_df[specific_df['method'] == 'sgan']['error_value']),\n",
"                        np.mean(specific_df[specific_df['method'] == 'Trajectron']['error_value']),\n",
"                        np.mean(specific_df[specific_df['method'] == alg_name]['error_value'])],\n",
"                       s=marker_size*marker_size, c=np.asarray(area_rgbs)/255.0, marker=mean_markers,\n",
"                       edgecolors='#545454', zorder=10)\n",
"\n",
"        baselines = prior_work_ade_results[pretty_dataset_name(dataset_name)]\n",
"        for baseline_idx, (baseline, ade_val) in enumerate(baselines.items()):\n",
"            ax_ade.axhline(y=ade_val, label=baseline, color=line_colors[baseline_idx], linestyle=linestyles[baseline_idx])\n",
"\n",
"        if idx == 0:\n",
"            # One figure-level legend built from the first panel. Handles are looked up\n",
"            # by label rather than list position, because the order in which lines and\n",
"            # patches are returned may differ between matplotlib versions.\n",
"            handles, labels = ax_ade.get_legend_handles_labels()\n",
"            handle_by_label = dict(zip(labels, handles))\n",
"            baseline_names = list(baselines)\n",
"            # (artist label, legend text): baselines interleaved with the box-plot methods.\n",
"            legend_entries = [(baseline_names[0], baseline_names[0]), ('sgan', 'Social GAN'),\n",
"                              (baseline_names[1], baseline_names[1]), ('Trajectron', 'Trajectron'),\n",
"                              (baseline_names[2], baseline_names[2]), (alg_name, alg_name),\n",
"                              (baseline_names[3], baseline_names[3])]\n",
"            ax_ade.legend([handle_by_label[key] for key, _ in legend_entries],\n",
"                          [text for _, text in legend_entries],\n",
"                          loc='lower center', bbox_to_anchor=(0.5, 0.9),\n",
"                          ncol=4, borderaxespad=0, frameon=False,\n",
"                          bbox_transform=fig_ade.transFigure)\n",
"\n",
"fig_ade.savefig('plots/ade_boxplots.pdf', dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Drop the ADE results; the KDE section below rebuilds perf_df from scratch.\n",
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# KDE Negative Log-Likelihood (Attention Radius 3 m)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_kde_full.csv\n",
"results/hotel_12_kde_full.csv\n",
"results/univ_12_kde_full.csv\n",
"results/zara1_12_kde_full.csv\n",
"results/zara2_12_kde_full.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/borisi/anaconda3/envs/gentraj/lib/python3.6/site-packages/pandas/core/frame.py:7123: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
"of pandas will change to not sort by default.\n",
"\n",
"To accept the future behavior, pass 'sort=False'.\n",
"\n",
"To retain the current behavior and silence the warning, pass 'sort=True'.\n",
"\n",
" sort=sort,\n"
]
}
],
"source": [
"# Load Ours: per-dataset KDE negative log-likelihood results.\n",
"frames = []\n",
"for dataset in dataset_names:\n",
"    # sorted() keeps the row order independent of the filesystem's listing order.\n",
"    for f in sorted(glob.glob(f'results/{dataset}_12*kde_full.csv')):\n",
"        print(f)\n",
"        # 'Unnamed: 0' is presumably the index column saved with the CSV; drop it per file.\n",
"        dataset_df = pd.read_csv(f).drop(columns='Unnamed: 0', errors='ignore')\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        frames.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append is deprecated (removed in pandas 2.0);\n",
"# sort=False also avoids the column-sorting FutureWarning.\n",
"perf_df = pd.concat(frames, ignore_index=True, sort=False) if frames else pd.DataFrame()\n",
"del frames"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Optional: Trajectron and Social GAN KDE log-likelihoods for comparison.\n",
"# Disabled by default; set LOAD_BASELINE_LLS = True to load them into lls_df.\n",
"LOAD_BASELINE_LLS = False\n",
"\n",
"if LOAD_BASELINE_LLS:\n",
"    lls_df = pd.concat([pd.read_csv(f) for f in sorted(glob.glob('csv/old/curr_*_lls.csv'))], ignore_index=True)\n",
"    lls_df.loc[lls_df['method'] == 'our_full', 'method'] = 'Trajectron'\n",
"    lls_df['error_type'] = 'KDE'\n",
"else:\n",
"    lls_df = None"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KDE NLL for ETH - Univ\n",
"Ours: 1.3068431681049446\n",
"KDE NLL for ETH - Hotel\n",
"Ours: -1.939345347471224\n",
"KDE NLL for UCY - Univ\n",
"Ours: -1.1288059163920086\n",
"KDE NLL for UCY - Zara 1\n",
"Ours: -1.4119791272274707\n",
"KDE NLL for UCY - Zara 2\n",
"Ours: -2.525154634369542\n",
"KDE NLL for Average\n",
"Ours: -1.392358401395975\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" if dataset != 'Average':\n",
" print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
" #print(f\"SGAN: {-lls_df[(lls_df['method'] == 'sgan') & (lls_df['dataset'] == dataset)]['log-likelihood'].mean()}\")\n",
" #print(f\"Trajectron: {-lls_df[(lls_df['method'] == 'Trajectron') & (lls_df['dataset'] == dataset)]['log-likelihood'].mean()}\")\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
" #print(f\"SGAN: {-lls_df[(lls_df['method'] == 'sgan')]['log-likelihood'].mean()}\")\n",
" #print(f\"Trajectron: {-lls_df[(lls_df['method'] == 'Trajectron')]['log-likelihood'].mean()}\")\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely FDE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_fde_most_likely.csv\n",
"results/hotel_12_fde_most_likely.csv\n",
"results/univ_12_fde_most_likely.csv\n",
"results/zara1_12_fde_most_likely.csv\n",
"results/zara2_12_fde_most_likely.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_12*fde_most_likely.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = 'Trajectron++'\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Most Likely for ETH - Univ\n",
"Ours: 1.6763703267322265\n",
"FDE Most Likely for ETH - Hotel\n",
"Ours: 0.4614785391399073\n",
"FDE Most Likely for UCY - Univ\n",
"Ours: 1.0747924515297491\n",
"FDE Most Likely for UCY - Zara 1\n",
"Ours: 0.7704666189102252\n",
"FDE Most Likely for UCY - Zara 2\n",
"Ours: 0.5865659029486421\n",
"FDE Most Likely for Average\n",
"Ours: 0.9542581296327649\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('FDE Most Likely for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++') & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++')]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely ADE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_ade_most_likely.csv\n",
"results/hotel_12_ade_most_likely.csv\n",
"results/univ_12_ade_most_likely.csv\n",
"results/zara1_12_ade_most_likely.csv\n",
"results/zara2_12_ade_most_likely.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_12*ade_most_likely.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = 'Trajectron++'\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Most Likely for ETH - Univ\n",
"Ours: 0.7052341524825474\n",
"ADE Most Likely for ETH - Hotel\n",
"Ours: 0.21620184785291033\n",
"ADE Most Likely for UCY - Univ\n",
"Ours: 0.40926643885853664\n",
"ADE Most Likely for UCY - Zara 1\n",
"Ours: 0.2972134362490682\n",
"ADE Most Likely for UCY - Zara 2\n",
"Ours: 0.22585898118058487\n",
"ADE Most Likely for Average\n",
"Ours: 0.3661968268243691\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('ADE Most Likely for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++') & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++')]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation FDE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_fde_best_of.csv\n",
"results/hotel_12_fde_best_of.csv\n",
"results/univ_12_fde_best_of.csv\n",
"results/zara1_12_fde_best_of.csv\n",
"results/zara2_12_fde_best_of.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_12*fde_best_of.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = alg_name\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Best of 20 for ETH - Univ\n",
"Trajectron++: 0.8574431969131285\n",
"FDE Best of 20 for ETH - Hotel\n",
"Trajectron++: 0.19084707198210932\n",
"FDE Best of 20 for UCY - Univ\n",
"Trajectron++: 0.4287221576716801\n",
"FDE Best of 20 for UCY - Zara 1\n",
"Trajectron++: 0.3159092794417088\n",
"FDE Best of 20 for UCY - Zara 2\n",
"Trajectron++: 0.25292433731989494\n",
"FDE Best of 20 for Average\n",
"Trajectron++: 0.38676102425417497\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('FDE Best of 20 for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation ADE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_ade_best_of.csv\n",
"results/hotel_12_ade_best_of.csv\n",
"results/univ_12_ade_best_of.csv\n",
"results/zara1_12_ade_best_of.csv\n",
"results/zara2_12_ade_best_of.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_12*ade_best_of.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = alg_name\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Best of 20 for ETH - Univ\n",
"Trajectron++: 0.4334745882075259\n",
"ADE Best of 20 for ETH - Hotel\n",
"Trajectron++: 0.12103208674836993\n",
"ADE Best of 20 for UCY - Univ\n",
"Trajectron++: 0.2200696369591878\n",
"ADE Best of 20 for UCY - Zara 1\n",
"Trajectron++: 0.16745747931164434\n",
"ADE Best of 20 for UCY - Zara 2\n",
"Trajectron++: 0.12485689651978099\n",
"ADE Best of 20 for Average\n",
"Trajectron++: 0.1987725413014945\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('ADE Best of 20 for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# KDE Negative Log Likelihood Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_kde_full.csv\n",
"results/hotel_vel_12_kde_full.csv\n",
"results/univ_vel_12_kde_full.csv\n",
"results/zara1_vel_12_kde_full.csv\n",
"results/zara2_vel_12_kde_full.csv\n"
]
}
],
"source": [
"# Load Ours\n",
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_vel_12*kde_full.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = alg_name\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True)\n",
" del perf_df['Unnamed: 0']\n",
"#perf_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# # Load Trajectron and SGAN\n",
"# lls_df = pd.concat([pd.read_csv(f) for f in glob.glob('csv/old/curr_*_lls.csv')], ignore_index=True)\n",
"# lls_df.loc[lls_df['method'] == 'our_full', 'method'] = 'Trajectron'\n",
"# lls_df['error_type'] = 'KDE'\n",
"# #lls_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KDE NLL for ETH - Univ\n",
"Ours: 1.7987703559884305\n",
"KDE NLL for ETH - Hotel\n",
"Ours: -1.2864991518790894\n",
"KDE NLL for UCY - Univ\n",
"Ours: -0.8897570611371921\n",
"KDE NLL for UCY - Zara 1\n",
"Ours: -1.1275849983603234\n",
"KDE NLL for UCY - Zara 2\n",
"Ours: -2.1946898640072234\n",
"KDE NLL for Average\n",
"Ours: -1.1171728799903844\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" if dataset != 'Average':\n",
" print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
" #print(f\"SGAN: {-lls_df[(lls_df['method'] == 'sgan') & (lls_df['dataset'] == dataset)]['log-likelihood'].mean()}\")\n",
" #print(f\"Trajectron: {-lls_df[(lls_df['method'] == 'Trajectron') & (lls_df['dataset'] == dataset)]['log-likelihood'].mean()}\")\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
" #print(f\"SGAN: {-lls_df[(lls_df['method'] == 'sgan')]['log-likelihood'].mean()}\")\n",
" #print(f\"Trajectron: {-lls_df[(lls_df['method'] == 'Trajectron')]['log-likelihood'].mean()}\")\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely FDE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_fde_most_likely.csv\n",
"results/hotel_vel_12_fde_most_likely.csv\n",
"results/univ_vel_12_fde_most_likely.csv\n",
"results/zara1_vel_12_fde_most_likely.csv\n",
"results/zara2_vel_12_fde_most_likely.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_vel_12*fde_most_likely.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = 'Trajectron++'\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Most Likely for ETH - Univ\n",
"Ours: 1.6595803163079856\n",
"FDE Most Likely for ETH - Hotel\n",
"Ours: 0.4577661486145857\n",
"FDE Most Likely for UCY - Univ\n",
"Ours: 1.1657061102834114\n",
"FDE Most Likely for UCY - Zara 1\n",
"Ours: 0.792450220754066\n",
"FDE Most Likely for UCY - Zara 2\n",
"Ours: 0.5878910318410495\n",
"FDE Most Likely for Average\n",
"Ours: 1.0204553297895693\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('FDE Most Likely for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++') & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++')]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely Evaluation ADE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_ade_most_likely.csv\n",
"results/hotel_vel_12_ade_most_likely.csv\n",
"results/univ_vel_12_ade_most_likely.csv\n",
"results/zara1_vel_12_ade_most_likely.csv\n",
"results/zara2_vel_12_ade_most_likely.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_vel_12*ade_most_likely.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = 'Trajectron++'\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Most Likely for ETH - Univ\n",
"Ours: 0.7081645037230024\n",
"ADE Most Likely for ETH - Hotel\n",
"Ours: 0.21800273126936145\n",
"ADE Most Likely for UCY - Univ\n",
"Ours: 0.4429943056454118\n",
"ADE Most Likely for UCY - Zara 1\n",
"Ours: 0.30200377722385563\n",
"ADE Most Likely for UCY - Zara 2\n",
"Ours: 0.22635933788153614\n",
"ADE Most Likely for Average\n",
"Ours: 0.3907335607353219\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('ADE Most Likely for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++') & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"{alg_name}: {perf_df[(perf_df['method'] == 'Trajectron++')]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation FDE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_fde_best_of.csv\n",
"results/hotel_vel_12_fde_best_of.csv\n",
"results/univ_vel_12_fde_best_of.csv\n",
"results/zara1_vel_12_fde_best_of.csv\n",
"results/zara2_vel_12_fde_best_of.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_vel_12*fde_best_of.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = alg_name\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Best of 20 for ETH - Univ\n",
"Trajectron++: 0.8281783496158873\n",
"FDE Best of 20 for ETH - Hotel\n",
"Trajectron++: 0.20567153121359805\n",
"FDE Best of 20 for UCY - Univ\n",
"Trajectron++: 0.44246075609459684\n",
"FDE Best of 20 for UCY - Zara 1\n",
"Trajectron++: 0.3307553133179889\n",
"FDE Best of 20 for UCY - Zara 2\n",
"Trajectron++: 0.24909920986736078\n",
"FDE Best of 20 for Average\n",
"Trajectron++: 0.39711722810872235\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('FDE Best of 20 for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation ADE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_ade_best_of.csv\n",
"results/hotel_vel_12_ade_best_of.csv\n",
"results/univ_vel_12_ade_best_of.csv\n",
"results/zara1_vel_12_ade_best_of.csv\n",
"results/zara2_vel_12_ade_best_of.csv\n"
]
}
],
"source": [
"perf_df = pd.DataFrame()\n",
"for dataset in dataset_names:\n",
" for f in glob.glob(f\"results/{dataset}_vel_12*ade_best_of.csv\"):\n",
" print(f)\n",
" dataset_df = pd.read_csv(f)\n",
" dataset_df['dataset'] = dataset\n",
" dataset_df['method'] = alg_name\n",
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
" del perf_df['Unnamed: 0']"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Best of 20 for ETH - Univ\n",
"Trajectron++: 0.39238413417612106\n",
"ADE Best of 20 for ETH - Hotel\n",
"Trajectron++: 0.11769507047457346\n",
"ADE Best of 20 for UCY - Univ\n",
"Trajectron++: 0.1990357831124613\n",
"ADE Best of 20 for UCY - Zara 1\n",
"Trajectron++: 0.15218487860383714\n",
"ADE Best of 20 for UCY - Zara 2\n",
"Trajectron++: 0.11350252815738102\n",
"ADE Best of 20 for Average\n",
"Trajectron++: 0.1802170043575296\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
" print('ADE Best of 20 for ' + pretty_dataset_name(dataset))\n",
" if dataset != 'Average':\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name) & (perf_df['dataset'] == dataset)]['value'].mean()}\")\n",
" else:\n",
" print(f\"Trajectron++: {perf_df[(perf_df['method'] == alg_name)]['value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6 (GenTrajectron)",
"language": "python",
"name": "gentraj"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}