1608 lines
320 KiB
Text
1608 lines
320 KiB
Text
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"%matplotlib inline\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import glob\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"from collections import OrderedDict"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def pretty_dataset_name(dataset_name):\n",
|
|||
|
" if dataset_name == 'eth':\n",
|
|||
|
" return 'ETH - Univ'\n",
|
|||
|
" elif dataset_name == 'hotel':\n",
|
|||
|
" return 'ETH - Hotel'\n",
|
|||
|
" elif dataset_name == 'univ':\n",
|
|||
|
" return 'UCY - Univ'\n",
|
|||
|
" elif dataset_name == 'zara1':\n",
|
|||
|
" return 'UCY - Zara 1'\n",
|
|||
|
" elif dataset_name == 'zara2':\n",
|
|||
|
" return 'UCY - Zara 2'\n",
|
|||
|
" else:\n",
|
|||
|
" return dataset_name"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"dataset_names = ['eth', 'hotel', 'univ', 'zara1', 'zara2', 'Average']\n",
|
|||
|
"alg_name = \"Ours\""
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Displacement Error Analysis"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 79,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"prior_work_fse_results = {\n",
|
|||
|
" 'ETH - Univ': OrderedDict([('Linear', 2.94), ('Vanilla LSTM', 2.41), ('Social LSTM', 2.35), ('Social Attention', 3.74)]),\n",
|
|||
|
" 'ETH - Hotel': OrderedDict([('Linear', 0.72), ('Vanilla LSTM', 1.91), ('Social LSTM', 1.76), ('Social Attention', 2.64)]),\n",
|
|||
|
" 'UCY - Univ': OrderedDict([('Linear', 1.59), ('Vanilla LSTM', 1.31), ('Social LSTM', 1.40), ('Social Attention', 0.52)]),\n",
|
|||
|
" 'UCY - Zara 1': OrderedDict([('Linear', 1.21), ('Vanilla LSTM', 0.88), ('Social LSTM', 1.00), ('Social Attention', 2.13)]),\n",
|
|||
|
" 'UCY - Zara 2': OrderedDict([('Linear', 1.48), ('Vanilla LSTM', 1.11), ('Social LSTM', 1.17), ('Social Attention', 3.92)]),\n",
|
|||
|
" 'Average': OrderedDict([('Linear', 1.59), ('Vanilla LSTM', 1.52), ('Social LSTM', 1.54), ('Social Attention', 2.59)])\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# These are for a prediction horizon of 12 timesteps.\n",
|
|||
|
"prior_work_ade_results = {\n",
|
|||
|
" 'ETH - Univ': OrderedDict([('Linear', 1.33), ('Vanilla LSTM', 1.09), ('Social LSTM', 1.09), ('Social Attention', 0.39)]),\n",
|
|||
|
" 'ETH - Hotel': OrderedDict([('Linear', 0.39), ('Vanilla LSTM', 0.86), ('Social LSTM', 0.79), ('Social Attention', 0.29)]),\n",
|
|||
|
" 'UCY - Univ': OrderedDict([('Linear', 0.82), ('Vanilla LSTM', 0.61), ('Social LSTM', 0.67), ('Social Attention', 0.20)]),\n",
|
|||
|
" 'UCY - Zara 1': OrderedDict([('Linear', 0.62), ('Vanilla LSTM', 0.41), ('Social LSTM', 0.47), ('Social Attention', 0.30)]),\n",
|
|||
|
" 'UCY - Zara 2': OrderedDict([('Linear', 0.77), ('Vanilla LSTM', 0.52), ('Social LSTM', 0.56), ('Social Attention', 0.33)]),\n",
|
|||
|
" 'Average': OrderedDict([('Linear', 0.79), ('Vanilla LSTM', 0.70), ('Social LSTM', 0.72), ('Social Attention', 0.30)])\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"linestyles = ['--', '-.', '-', ':']"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 81,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"mean_markers = 'X'\n",
|
|||
|
"marker_size = 7\n",
|
|||
|
"line_colors = ['#1f78b4','#33a02c','#fb9a99','#e31a1c']\n",
|
|||
|
"area_colors = ['#80CBE5','#ABCB51', '#F05F78']\n",
|
|||
|
"area_rgbs = list()\n",
|
|||
|
"for c in area_colors:\n",
|
|||
|
" area_rgbs.append([int(c[i:i+2], 16) for i in (1, 3, 5)])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 94,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"results/eth_attention_radius_3_fde_most_likely.csv\n",
|
|||
|
"results/hotel_attention_radius_3_fde_most_likely.csv\n",
|
|||
|
"results/univ_attention_radius_3_fde_most_likely.csv\n",
|
|||
|
"results/zara1_attention_radius_3_fde_most_likely.csv\n",
|
|||
|
"results/zara2_attention_radius_3_fde_most_likely.csv\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>error_value</th>\n",
|
|||
|
" <th>error_type</th>\n",
|
|||
|
" <th>type</th>\n",
|
|||
|
" <th>dataset</th>\n",
|
|||
|
" <th>method</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.242668</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>ml</td>\n",
|
|||
|
" <td>eth</td>\n",
|
|||
|
" <td>Ours</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.158331</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>ml</td>\n",
|
|||
|
" <td>eth</td>\n",
|
|||
|
" <td>Ours</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0.095482</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>ml</td>\n",
|
|||
|
" <td>eth</td>\n",
|
|||
|
" <td>Ours</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>1.069288</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>ml</td>\n",
|
|||
|
" <td>eth</td>\n",
|
|||
|
" <td>Ours</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1.734359</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>ml</td>\n",
|
|||
|
" <td>eth</td>\n",
|
|||
|
" <td>Ours</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" error_value error_type type dataset method\n",
|
|||
|
"0 0.242668 fde ml eth Ours\n",
|
|||
|
"1 0.158331 fde ml eth Ours\n",
|
|||
|
"2 0.095482 fde ml eth Ours\n",
|
|||
|
"3 1.069288 fde ml eth Ours\n",
|
|||
|
"4 1.734359 fde ml eth Ours"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 94,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Load Ours\n",
|
|||
|
"perf_df = pd.DataFrame()\n",
|
|||
|
"for dataset in dataset_names:\n",
|
|||
|
" for f in glob.glob(f\"results/{dataset}*attention_radius_3*fde_most_likely.csv\"):\n",
|
|||
|
" print(f)\n",
|
|||
|
" dataset_df = pd.read_csv(f)\n",
|
|||
|
" dataset_df['dataset'] = dataset\n",
|
|||
|
" dataset_df['method'] = alg_name\n",
|
|||
|
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
|
|||
|
" del perf_df['Unnamed: 0']\n",
|
|||
|
"perf_df = perf_df.rename(columns={\"metric\": \"error_type\", \"value\": \"error_value\"})"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 95,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Load Trajectron and GAN\n",
|
|||
|
"errors_df = pd.concat([pd.read_csv(f) for f in glob.glob('csv/old/curr_*_errors.csv')], ignore_index=True)\n",
|
|||
|
"del errors_df['data_precondition']\n",
|
|||
|
"errors_df = errors_df[~(errors_df['method'] == 'our_full')]\n",
|
|||
|
"errors_df = errors_df[~(errors_df['error_type'] == 'mse')]\n",
|
|||
|
"errors_df.loc[errors_df['error_type'] =='fse', 'error_type'] = 'fde'\n",
|
|||
|
"#errors_df.loc[errors_df['error_type'] =='mse', 'error_type'] = 'ade'\n",
|
|||
|
"errors_df.loc[errors_df['method'] == 'our_most_likely', 'method'] = 'Trajectron'"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 96,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/pandas/core/frame.py:7123: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
|
|||
|
"of pandas will change to not sort by default.\n",
|
|||
|
"\n",
|
|||
|
"To accept the future behavior, pass 'sort=False'.\n",
|
|||
|
"\n",
|
|||
|
"To retain the current behavior and silence the warning, pass 'sort=True'.\n",
|
|||
|
"\n",
|
|||
|
" sort=sort,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>dataset</th>\n",
|
|||
|
" <th>method</th>\n",
|
|||
|
" <th>run</th>\n",
|
|||
|
" <th>node</th>\n",
|
|||
|
" <th>sample</th>\n",
|
|||
|
" <th>error_type</th>\n",
|
|||
|
" <th>error_value</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2186000</td>\n",
|
|||
|
" <td>hotel</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Pedestrian/0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>4.045972</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2186001</td>\n",
|
|||
|
" <td>hotel</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Pedestrian/0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>3.717624</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2186002</td>\n",
|
|||
|
" <td>hotel</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Pedestrian/0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>5.378286</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2186003</td>\n",
|
|||
|
" <td>hotel</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Pedestrian/0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>4.215567</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>2186004</td>\n",
|
|||
|
" <td>hotel</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Pedestrian/0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>4.663851</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>77099995</td>\n",
|
|||
|
" <td>zara2</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>Pedestrian/35</td>\n",
|
|||
|
" <td>1995</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>0.620136</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>77099996</td>\n",
|
|||
|
" <td>zara2</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>Pedestrian/35</td>\n",
|
|||
|
" <td>1996</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>0.681608</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>77099997</td>\n",
|
|||
|
" <td>zara2</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>Pedestrian/35</td>\n",
|
|||
|
" <td>1997</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>0.860765</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>77099998</td>\n",
|
|||
|
" <td>zara2</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>Pedestrian/35</td>\n",
|
|||
|
" <td>1998</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>0.545317</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <td>77099999</td>\n",
|
|||
|
" <td>zara2</td>\n",
|
|||
|
" <td>sgan</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>Pedestrian/35</td>\n",
|
|||
|
" <td>1999</td>\n",
|
|||
|
" <td>fde</td>\n",
|
|||
|
" <td>1.027843</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>25700000 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" dataset method run node sample error_type error_value\n",
|
|||
|
"2186000 hotel sgan 0 Pedestrian/0 0 fde 4.045972\n",
|
|||
|
"2186001 hotel sgan 0 Pedestrian/0 1 fde 3.717624\n",
|
|||
|
"2186002 hotel sgan 0 Pedestrian/0 2 fde 5.378286\n",
|
|||
|
"2186003 hotel sgan 0 Pedestrian/0 3 fde 4.215567\n",
|
|||
|
"2186004 hotel sgan 0 Pedestrian/0 4 fde 4.663851\n",
|
|||
|
"... ... ... ... ... ... ... ...\n",
|
|||
|
"77099995 zara2 sgan 99 Pedestrian/35 1995 fde 0.620136\n",
|
|||
|
"77099996 zara2 sgan 99 Pedestrian/35 1996 fde 0.681608\n",
|
|||
|
"77099997 zara2 sgan 99 Pedestrian/35 1997 fde 0.860765\n",
|
|||
|
"77099998 zara2 sgan 99 Pedestrian/35 1998 fde 0.545317\n",
|
|||
|
"77099999 zara2 sgan 99 Pedestrian/35 1999 fde 1.027843\n",
|
|||
|
"\n",
|
|||
|
"[25700000 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 96,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"perf_df = perf_df.append(errors_df)\n",
|
|||
|
"errors_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 97,
|
|||
|
"metadata": {
|
|||
|
"scrolled": false
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB/IAAASvCAYAAAAaDLIdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd5hkVZn48e/LAIIkiaKgYMKACJINrGBGVFQQEJGgq+KurjkLK+KKcdc1/QyrDggqiIKKGZCoRBFJogIDwwgISA4CM+/vj3PLvnW7qruqunq6uuf7eZ5+Zu6tG05XV91z73nPeU9kJpIkSZIkSZIkSZIkaTQsN9MFkCRJkiRJkiRJkiRJYwzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkSxo5EbEgIrL6WTDT5ZEkaVkWEefV6uU7J9n2xbVtMyLeNaxjS5I0qiJi/0b9t79lkiTNZaNYz4ximTQ3NT5np8x0eTS3LT/TBZAkSZIk9Sci1gC2ATYCHgKsDNwF3AosAP6SmQtnrICS5qyIWBnYEngcsCawCnAPcDtwDXAFcGVmLpmxQmrOioi1KfXfIyj134qU+u8W4CpK/ffXmSuhpLnIuk+SNFMM5EsaWDVafqPaqp0y85SZKY0kSUVEfAg4tLbqvMzcZkjH3g04trbqRmCDzLx/GMfX7BQRLwZ+XFt1SWY+eRrOsyKwD/AGYFsgJtn+JuBc4CTgZ5l5aeP184Cthl3Ohrb3IiKeDFzUYbslwMaDdj6IiGOB3Tq89MXMfPMgx5Q0JiICeClwIPBcJm9PuiMizgdOBX4GnGtwY+6pRjp+s7bq1MzccRrO82Bgf+D1wBY9bH89cA6l/vtpZv6l8foC2tsypkPbexEROwK/7rDdvcD6mXnbICepRgE+q8NLh2Tmhwc5pqTCuk+DiIizgO0aq3fMzFNnojySZj9T60uSJGmuOQLI2vLWEfHEIR17v8bytw3ia2mIiK2A3wFfpzQMTRjEr6wD7Ax8GrgkIh4zfSWcsuUonRT6FhFrAS8ZbnEktUTERsCJwPHAC+ltUMhqwI7AfwJn4XdUA4qIf6F0APsiPQTxK+tTgm//C/
y5ymIzqlYC9hhkx4jYGPiXYRZGUmHdp0FExBMYH8SH8e0I/Rxz40Ya9/kDHOOU+jEGLctc4RQMmm0M5EuSJGlOycxrGD/iad+pHjci1qU04tQdPtXjSpOJiB2A04BNO7x8F3ApcDZwMXDdRIcafumG6jUD7rcXJbWypCGLiEcDZwLP7vDyfcDllJHPFwILKdk1Oh5qWgqoOa3KeHMi8OgOL99OqffOBi4B/jbRoYZfuqEa9D51X0b/d5NmHes+TUG3gP3uVXYZSeqbqfUljZzM3HimyyBJmvUOp73hZZ+I+OAUUxvuDaxQW/5DZl4whePNCpm59UyXYVkWEesBPwLqDT//AL4EzAcuysxs7LMOJfX+i4FXAA/tcvh/A1bvoRgr0T51AJQ09rf3sO+dk7y+hLEO5k+MiG0y89wejltXbzCrH0/SFETECpTv/ga11QkcBXwFOCszH2jssyplyo6dgd2BUc4EMjSZOZ9yTdaQVFlkvkf7vdddlFH2R2bmZR32WZ8yEvKlwMuAtboc/tXAyj0U46HAkY11z+thP4BbJnm9Xl89IyIelZlX9XjslnoHAOs/aQis+3pn3dcuIpaje8fk1SjPT99aeiXSdMpMO+poqTGQL0mSpLno+5QUrKtWyxtSAvsnTuGYzdFSjsbX0nAo8JDa8g3A8zPzD912yMybgJ8CP42ItwC7Au+gfcoJMvOcXgpQNU42nVadZ6ouAB4LtNIe7wv0HMiv0lduW1t1MmUOU0lTdyDwpNryvcBumfnTbjtk5p2UuYFPBd4XEc8C3g4sns6Cak76JKUjWctVwPMy84puO2Tm9cAPgR9GxIGUlPVv77Ddmb0UoEpd39x3KveSdWdSOh2sSBm1uy9wSK87R8QzaQ8WWv9Jw2Hdp0E9j/YOIGcAz6wt74+BfEkDsKemJEmS5pzMvIsyiqtuKvPSbQpsWVv1AONHaElDFRErUtLG1/3rREH8psxcnJk/yMxnThT8mEH3AsfUll9VjYTqVf17fRelE4+k4WjWm4dMFMjoJDNPzcyXZWYzq4fUVUSsSRlVX7d3P/VYZt6fmUdl5taZeetwSzgUfwd+UlvuN71+/ft5A/CLKZdIElj3aXDNz86HgfNryztFxCOXXnEkzRUG8iVJkjRXNUfMv7zLyOJeNB/Kf56ZE83FKg3DdrSnvl9Ee6P/XFH/rq4N7NLLTlX6yn1qq37A5Kn8JfUgItaipAluWQJ8bYaKo2XPjrRnEf1DZp41Q2WZTvX679HVKPtJRcTKlGwDLd+mdDKVNAXWfRpURKxBmdKlZRHwa9o7/7eyr0hSX0ytL2lOq0Z0PR14MiUt7e3AQuDUzJxszrpez/FIYGvK/HlrArcB1wNnVqn9pnLsFYDHU9J6rU+ZU+kuSu/9PwK/a87NNQwRsQHl4eXhlAb1W4HjMvOvwz6XJE2j0yhpWB9VLa9CmbNwfj8HiYh5lLlU63o+RkSsRqmHHk+Zq3VlynX1JuD8zPxLP+Xp8ZxBySCwJbAucA+lbjojMxcO+3wzofq7bAJsCjyMEvC+h1JH/gk4LzPvm7kSDsWGjeXLMzM7bjmLZeaZEfEXSop9KA1cx/ew63Nof48Op3wWJE3dBo3lmzLz5qVdiIjYgvIstB4lzfrfKM9zZ2TmPUM8z+OBpwDrUOrqe4EbgcuAC6e7PomI9Sj3Co+hPLcuT6nPrgfOnupz5SzUrP/+OCOlmH4/pdwPrlMt70tJxTyZl9He0e9wYKfhFk1aJln3WfcNag/Kc37LdzJzSUR8B/g0MK9avy/w0aVduOlSPZNvTXmOWw94EOUzdBWlXfwfQz7fesAOlDaeFSh16KXAWZk566ayqL6DT6W8d6tQfp+/Ur7rt03D+Tan/L3WA/5B+a79JjMXDPtcGi4D+ZJGTkQsADaqFq/OzI0n2PbDwH/WVu2UmadExIOA9wBvo9yMNi2OiGOB92bm1QOUcUXgTcAbaJ87qy4j4nzg0Mz8UR/HXo
dyA/hiys3JRKNH74qI7wIf7ycQFBH1IMCpmbljtX4Xyhy6OzI+a8siemtUl6SRkJkZEUfQXk/sS5+BfMp8ow+vLf8
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 2400x1200 with 6 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"with sns.color_palette(\"muted\"):\n",
|
|||
|
" fig_fse, ax_fses = plt.subplots(nrows=1, ncols=6, figsize=(8, 4), dpi=300, sharey=True)\n",
|
|||
|
" for idx, ax_fse in enumerate(ax_fses):\n",
|
|||
|
" dataset_name = dataset_names[idx]\n",
|
|||
|
" if dataset_name != 'Average':\n",
|
|||
|
" specific_df = perf_df[(perf_df['dataset'] == dataset_name) & (perf_df['error_type'] == 'fde')]\n",
|
|||
|
" specific_df['dataset'] = pretty_dataset_name(dataset_name)\n",
|
|||
|
" else:\n",
|
|||
|
" specific_df = perf_df[(perf_df['error_type'] == 'fde')].copy()\n",
|
|||
|
" specific_df['dataset'] = 'Average'\n",
|
|||
|
"\n",
|
|||
|
" sns.boxplot(x='dataset', y='error_value', hue='method',\n",
|
|||
|
" data=specific_df, ax=ax_fse, showfliers=False,\n",
|
|||
|
" palette=area_colors, hue_order=['sgan', 'Trajectron', alg_name], width=2.)\n",
|
|||
|
" \n",
|
|||
|
" ax_fse.get_legend().remove()\n",
|
|||
|
" ax_fse.set_xlabel('')\n",
|
|||
|
" ax_fse.set_ylabel('' if idx > 0 else 'Final Displacement Error (m)')\n",
|
|||
|
"\n",
|
|||
|
" ax_fse.scatter([-0.665, 0, 0.665],\n",
|
|||
|
" [np.mean(specific_df[specific_df['method'] == 'sgan']['error_value']),\n",
|
|||
|
" np.mean(specific_df[specific_df['method'] == 'Trajectron']['error_value']),\n",
|
|||
|
" np.mean(specific_df[specific_df['method'] == alg_name]['error_value'])],\n",
|
|||
|
" s=marker_size*marker_size, c=np.asarray(area_rgbs)/255.0, marker=mean_markers,\n",
|
|||
|
" edgecolors='#545454', zorder=10)\n",
|
|||
|
" \n",
|
|||
|
" for baseline_idx, (baseline, fse_val) in enumerate(prior_work_fse_results[pretty_dataset_name(dataset_name)].items()):\n",
|
|||
|
" ax_fse.axhline(y=fse_val, label=baseline, color=line_colors[baseline_idx], linestyle=linestyles[baseline_idx])\n",
|
|||
|
" \n",
|
|||
|
" if idx == 0:\n",
|
|||
|
" handles, labels = ax_fse.get_legend_handles_labels()\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
" handles = [handles[0], handles[4], handles[1], handles[5], handles[2], handles[6], handles[3]]\n",
|
|||
|
" labels = [labels[0], 'Social GAN', labels[1], 'Trajectron', labels[2], alg_name, labels[3]]\n",
|
|||
|
"\n",
|
|||
|
" ax_fse.legend(handles, labels, \n",
|
|||
|
" loc='lower center', bbox_to_anchor=(0.5, 0.9),\n",
|
|||
|
" ncol=4, borderaxespad=0, frameon=False,\n",
|
|||
|
" bbox_transform=fig_fse.transFigure)\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# fig_fse.text(0.51, 0.03, 'Dataset', ha='center')\n",
|
|||
|
"\n",
|
|||
|
"plt.savefig('plots/fde_boxplots.pdf', dpi=300, bbox_inches='tight')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 98,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"del perf_df\n",
|
|||
|
"del errors_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Average Displacement Error"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 99,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"results/eth_attention_radius_3_ade_most_likely.csv\n",
|
|||
|
"results/hotel_attention_radius_3_ade_most_likely.csv\n",
|
|||
|
"results/univ_attention_radius_3_ade_most_likely.csv\n",
|
|||
|
"results/zara1_attention_radius_3_ade_most_likely.csv\n",
|
|||
|
"results/zara2_attention_radius_3_ade_most_likely.csv\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Load Ours\n",
|
|||
|
"perf_df = pd.DataFrame()\n",
|
|||
|
"for dataset in dataset_names:\n",
|
|||
|
" for f in glob.glob(f\"results/{dataset}*attention_radius_3*ade_most_likely.csv\"):\n",
|
|||
|
" print(f)\n",
|
|||
|
" dataset_df = pd.read_csv(f)\n",
|
|||
|
" dataset_df['dataset'] = dataset\n",
|
|||
|
" dataset_df['method'] = alg_name\n",
|
|||
|
" perf_df = perf_df.append(dataset_df, ignore_index=True, sort=False)\n",
|
|||
|
" del perf_df['Unnamed: 0']\n",
|
|||
|
"perf_df = perf_df.rename(columns={\"metric\": \"error_type\", \"value\": \"error_value\"})\n",
|
|||
|
"#perf_df.head()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 100,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Load Trajectron and GAN\n",
|
|||
|
"errors_df = pd.concat([pd.read_csv(f) for f in glob.glob('old/curr_*_errors.csv')], ignore_index=True)\n",
|
|||
|
"del errors_df['data_precondition']\n",
|
|||
|
"errors_df = errors_df[~(errors_df['method'] == 'our_full')]\n",
|
|||
|
"errors_df = errors_df[~(errors_df['error_type'] == 'fse')]\n",
|
|||
|
"#errors_df.loc[errors_df['error_type'] =='fse', 'error_type'] = 'fde'\n",
|
|||
|
"errors_df.loc[errors_df['error_type'] =='mse', 'error_type'] = 'ade'\n",
|
|||
|
"errors_df.loc[errors_df['method'] == 'our_most_likely', 'method'] = 'Trajectron'"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 101,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/pandas/core/frame.py:7123: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
|
|||
|
"of pandas will change to not sort by default.\n",
|
|||
|
"\n",
|
|||
|
"To accept the future behavior, pass 'sort=False'.\n",
|
|||
|
"\n",
|
|||
|
"To retain the current behavior and silence the warning, pass 'sort=True'.\n",
|
|||
|
"\n",
|
|||
|
" sort=sort,\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"perf_df = perf_df.append(errors_df)\n",
|
|||
|
"del errors_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 102,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n",
|
|||
|
"/home/timsal/anaconda3/envs/trajectron/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
|
|||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
|||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
|||
|
"\n",
|
|||
|
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
|||
|
" import sys\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB/IAAASvCAYAAAAaDLIdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd5hkVZn48e/LAIIkiaKgYMKACJINrGBGVFQQEJGgq+KurjkLK+KKcdc1/QyrDggqiIKKGZCoRBFJogIDwwgISA4CM+/vj3PLvnW7qruqunq6uuf7eZ5+Zu6tG05XV91z73nPeU9kJpIkSZIkSZIkSZIkaTQsN9MFkCRJkiRJkiRJkiRJYwzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkS5IkSZIkSZIkSZI0QgzkSxo5EbEgIrL6WTDT5ZEkaVkWEefV6uU7J9n2xbVtMyLeNaxjS5I0qiJi/0b9t79lkiTNZaNYz4ximTQ3NT5np8x0eTS3LT/TBZAkSZIk9Sci1gC2ATYCHgKsDNwF3AosAP6SmQtnrICS5qyIWBnYEngcsCawCnAPcDtwDXAFcGVmLpmxQmrOioi1KfXfIyj134qU+u8W4CpK/ffXmSuhpLnIuk+SNFMM5EsaWDVafqPaqp0y85SZKY0kSUVEfAg4tLbqvMzcZkjH3g04trbqRmCDzLx/GMfX7BQRLwZ+XFt1SWY+eRrOsyKwD/AGYFsgJtn+JuBc4CTgZ5l5aeP184Cthl3Ohrb3IiKeDFzUYbslwMaDdj6IiGOB3Tq89MXMfPMgx5Q0JiICeClwIPBcJm9PuiMizgdOBX4GnGtwY+6pRjp+s7bq1MzccRrO82Bgf+D1wBY9bH89cA6l/vtpZv6l8foC2tsypkPbexEROwK/7rDdvcD6mXnbICepRgE+q8NLh2Tmhwc5pqTCuk+DiIizgO0aq3fMzFNnojySZj9T60uSJGmuOQLI2vLWEfHEIR17v8bytw3ia2mIiK2A3wFfpzQMTRjEr6wD7Ax8GrgkIh4zfSWcsuUonRT6FhFrAS8ZbnEktUTERsCJwPHAC+ltUMhqwI7AfwJn4XdUA4qIf6F0APsiPQTxK+tTgm//C/
y5ymIzqlYC9hhkx4jYGPiXYRZGUmHdp0FExBMYH8SH8e0I/Rxz40Ya9/kDHOOU+jEGLctc4RQMmm0M5EuSJGlOycxrGD/iad+pHjci1qU04tQdPtXjSpOJiB2A04BNO7x8F3ApcDZwMXDdRIcafumG6jUD7rcXJbWypCGLiEcDZwLP7vDyfcDllJHPFwILKdk1Oh5qWgqoOa3KeHMi8OgOL99OqffOBi4B/jbRoYZfuqEa9D51X0b/d5NmHes+TUG3gP3uVXYZSeqbqfUljZzM3HimyyBJmvUOp73hZZ+I+OAUUxvuDaxQW/5DZl4whePNCpm59UyXYVkWEesBPwLqDT//AL4EzAcuysxs7LMOJfX+i4FXAA/tcvh/A1bvoRgr0T51AJQ09rf3sO+dk7y+hLEO5k+MiG0y89wejltXbzCrH0/SFETECpTv/ga11QkcBXwFOCszH2jssyplyo6dgd2BUc4EMjSZOZ9yTdaQVFlkvkf7vdddlFH2R2bmZR32WZ8yEvKlwMuAtboc/tXAyj0U46HAkY11z+thP4BbJnm9Xl89IyIelZlX9XjslnoHAOs/aQis+3pn3dcuIpaje8fk1SjPT99aeiXSdMpMO+poqTGQL0mSpLno+5QUrKtWyxtSAvsnTuGYzdFSjsbX0nAo8JDa8g3A8zPzD912yMybgJ8CP42ItwC7Au+gfcoJMvOcXgpQNU42nVadZ6ouAB4LtNIe7wv0HMiv0lduW1t1MmUOU0lTdyDwpNryvcBumfnTbjtk5p2UuYFPBd4XEc8C3g4sns6Cak76JKUjWctVwPMy84puO2Tm9cAPgR9GxIGUlPVv77Ddmb0UoEpd39x3KveSdWdSOh2sSBm1uy9wSK87R8QzaQ8WWv9Jw2Hdp0E9j/YOIGcAz6wt74+BfEkDsKemJEmS5pzMvIsyiqtuKvPSbQpsWVv1AONHaElDFRErUtLG1/3rREH8psxcnJk/yMxnThT8mEH3AsfUll9VjYTqVf17fRelE4+k4WjWm4dMFMjoJDNPzcyXZWYzq4fUVUSsSRlVX7d3P/VYZt6fmUdl5taZeetwSzgUfwd+UlvuN71+/ft5A/CLKZdIElj3aXDNz86HgfNryztFxCOXXnEkzRUG8iVJkjRXNUfMv7zLyOJeNB/Kf56ZE83FKg3DdrSnvl9Ee6P/XFH/rq4N7NLLTlX6yn1qq37A5Kn8JfUgItaipAluWQJ8bYaKo2XPjrRnEf1DZp41Q2WZTvX679HVKPtJRcTKlGwDLd+mdDKVNAXWfRpURKxBmdKlZRHwa9o7/7eyr0hSX0ytL2lOq0Z0PR14MiUt7e3AQuDUzJxszrpez/FIYGvK/HlrArcB1wNnVqn9pnLsFYDHU9J6rU+ZU+kuSu/9PwK/a87NNQwRsQHl4eXhlAb1W4HjMvOvwz6XJE2j0yhpWB9VLa9CmbNwfj8HiYh5lLlU63o+RkSsRqmHHk+Zq3VlynX1JuD8zPxLP+Xp8ZxBySCwJbAucA+lbjojMxcO+3wzofq7bAJsCjyMEvC+h1JH/gk4LzPvm7kSDsWGjeXLMzM7bjmLZeaZEfEXSop9KA1cx/ew63Nof48Op3wWJE3dBo3lmzLz5qVdiIjYgvIstB4lzfrfKM9zZ2TmPUM8z+OBpwDrUOrqe4EbgcuAC6e7PomI9Sj3Co+hPLcuT6nPrgfOnupz5SzUrP/+OCOlmH4/pdwPrlMt70tJxTyZl9He0e9wYKfhFk1aJln3WfcNag/Kc37LdzJzSUR8B/g0MK9avy/w0aVduOlSPZNvTXmOWw94EOUzdBWlXfwfQz7fesAOlDaeFSh16KXAWZk566ayqL6DT6W8d6tQfp+/Ur7rt03D+Tan/L3WA/5B+a79JjMXDPtcGi4D+ZJGTkQsADaqFq/OzI0n2PbDwH/WVu2UmadExIOA9wBvo9yMNi2OiGOB92bm1QOUcUXgTcAbaJ87qy4j4nzg0Mz8UR/HXo
dyA/hiys3JRKNH74qI7wIf7ycQFBH1IMCpmbljtX4Xyhy6OzI+a8siemtUl6SRkJkZEUfQXk/sS5+BfMp8ow+vLf8
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 2400x1200 with 6 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"# ADE box plots per dataset plus a pooled 'Average' panel: per-method mean markers on top of each box,\n",
"# and prior-work ADE results drawn as horizontal reference lines.\n",
"with sns.color_palette(\"muted\"):\n",
"    fig_ade, ax_ades = plt.subplots(nrows=1, ncols=6, figsize=(8, 4), dpi=300, sharey=True)\n",
"    for idx, (ax_ade, dataset_name) in enumerate(zip(ax_ades, dataset_names)):\n",
"        ade_df = perf_df[perf_df['error_type'] == 'ade']\n",
"        if dataset_name != 'Average':\n",
"            ade_df = ade_df[ade_df['dataset'] == dataset_name]\n",
"        # Work on an explicit copy: relabelling 'dataset' on a masked slice of perf_df\n",
"        # raises SettingWithCopyWarning and may not write where intended.\n",
"        specific_df = ade_df.copy()\n",
"        specific_df['dataset'] = pretty_dataset_name(dataset_name)  # 'Average' maps to itself\n",
"\n",
"        sns.boxplot(x='dataset', y='error_value', hue='method',\n",
"                    data=specific_df, ax=ax_ade, showfliers=False,\n",
"                    palette=area_colors, hue_order=['sgan', 'Trajectron', alg_name], width=2.)\n",
"\n",
"        ax_ade.get_legend().remove()\n",
"        ax_ade.set_xlabel('')\n",
"        ax_ade.set_ylabel('' if idx > 0 else 'Average Displacement Error (m)')\n",
"\n",
"        # Mean marker per method; the x positions are hard-coded to sit over the three hue boxes.\n",
"        method_means = [np.mean(specific_df[specific_df['method'] == method]['error_value'])\n",
"                        for method in ['sgan', 'Trajectron', alg_name]]\n",
"        ax_ade.scatter([-0.665, 0, 0.665], method_means,\n",
"                       s=marker_size*marker_size, c=np.asarray(area_rgbs)/255.0, marker=mean_markers,\n",
"                       edgecolors='#545454', zorder=10)\n",
"\n",
"        for baseline_idx, (baseline, ade_val) in enumerate(prior_work_ade_results[pretty_dataset_name(dataset_name)].items()):\n",
"            ax_ade.axhline(y=ade_val, label=baseline, color=line_colors[baseline_idx], linestyle=linestyles[baseline_idx])\n",
"\n",
"        if idx == 0:\n",
"            # Build one figure-level legend from the first panel. The index pattern assumes\n",
"            # 4 prior-work lines followed by the 3 method boxes, and interleaves them.\n",
"            handles, labels = ax_ade.get_legend_handles_labels()\n",
"            handles = [handles[0], handles[4], handles[1], handles[5], handles[2], handles[6], handles[3]]\n",
"            labels = [labels[0], 'Social GAN', labels[1], 'Trajectron', labels[2], alg_name, labels[3]]\n",
"            ax_ade.legend(handles, labels,\n",
"                          loc='lower center', bbox_to_anchor=(0.5, 0.9),\n",
"                          ncol=4, borderaxespad=0, frameon=False,\n",
"                          bbox_transform=fig_ade.transFigure)\n",
"\n",
"# Save the figure built above explicitly rather than whatever pyplot considers 'current'.\n",
"fig_ade.savefig('plots/ade_boxplots.pdf', dpi=300, bbox_inches='tight')"
]
|
|||
|
},
|
|||
|
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# KDE Negative Log Likelihood Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_kde_full.csv\n",
"results/hotel_12_kde_full.csv\n",
"results/univ_12_kde_full.csv\n",
"results/zara1_12_kde_full.csv\n",
"results/zara2_12_kde_full.csv\n"
]
}
],
"source": [
"# Load Ours: one KDE NLL CSV per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_12*kde_full.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file (presumably written by DataFrame.to_csv)\n",
"        # so every frame has identical columns and concat never has to align/sort them.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KDE NLL for ETH - Univ\n",
"Ours: 2.922618333137345\n",
"KDE NLL for ETH - Hotel\n",
"Ours: -1.4172003110091507\n",
"KDE NLL for UCY - Univ\n",
"Ours: -1.0312765478327917\n",
"KDE NLL for UCY - Zara 1\n",
"Ours: -1.320554662223125\n",
"KDE NLL for UCY - Zara 2\n",
"Ours: -2.265794496180711\n",
"KDE NLL for Average\n",
"Ours: -1.1628978588526018\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely FDE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_fde_most_likely.csv\n",
"results/hotel_12_fde_most_likely.csv\n",
"results/univ_12_fde_most_likely.csv\n",
"results/zara1_12_fde_most_likely.csv\n",
"results/zara2_12_fde_most_likely.csv\n"
]
}
],
"source": [
"# Load Ours: most-likely FDE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_12*fde_most_likely.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        # Tag with alg_name so the filter below matches the label that gets printed.\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Most Likely for ETH - Univ\n",
"Ours: 1.8771982385234873\n",
"FDE Most Likely for ETH - Hotel\n",
"Ours: 0.617328603355503\n",
"FDE Most Likely for UCY - Univ\n",
"Ours: 1.1245144943461012\n",
"FDE Most Likely for UCY - Zara 1\n",
"Ours: 0.7944446572286218\n",
"FDE Most Likely for UCY - Zara 2\n",
"Ours: 0.660121689972624\n",
"FDE Most Likely for Average\n",
"Ours: 1.0173143626632004\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('FDE Most Likely for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely Evaluation ADE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_ade_most_likely.csv\n",
"results/hotel_12_ade_most_likely.csv\n",
"results/univ_12_ade_most_likely.csv\n",
"results/zara1_12_ade_most_likely.csv\n",
"results/zara2_12_ade_most_likely.csv\n"
]
}
],
"source": [
"# Load Ours: most-likely ADE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_12*ade_most_likely.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        # Tag with alg_name so the filter below matches the label that gets printed.\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Most Likely for ETH - Univ\n",
"Ours: 0.7914500193014897\n",
"ADE Most Likely for ETH - Hotel\n",
"Ours: 0.27776006532871594\n",
"ADE Most Likely for UCY - Univ\n",
"Ours: 0.4266315625403817\n",
"ADE Most Likely for UCY - Zara 1\n",
"Ours: 0.3099620513771519\n",
"ADE Most Likely for UCY - Zara 2\n",
"Ours: 0.25468354869237164\n",
"ADE Most Likely for Average\n",
"Ours: 0.3919718759746782\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('ADE Most Likely for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation FDE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_fde_best_of.csv\n",
"results/hotel_12_fde_best_of.csv\n",
"results/univ_12_fde_best_of.csv\n",
"results/zara1_12_fde_best_of.csv\n",
"results/zara2_12_fde_best_of.csv\n"
]
}
],
"source": [
"# Load Ours: best-of-20 FDE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_12*fde_best_of.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Best of 20 for ETH - Univ\n",
"Ours: 0.941833113317412\n",
"FDE Best of 20 for ETH - Hotel\n",
"Ours: 0.25054278903414373\n",
"FDE Best of 20 for UCY - Univ\n",
"Ours: 0.44114303527050003\n",
"FDE Best of 20 for UCY - Zara 1\n",
"Ours: 0.3371937683488606\n",
"FDE Best of 20 for UCY - Zara 2\n",
"Ours: 0.27854739859128275\n",
"FDE Best of 20 for Average\n",
"Ours: 0.4106989096039851\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('FDE Best of 20 for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    # Print under alg_name, like every other metric section.\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation ADE Attention Radius 3m"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_12_ade_best_of.csv\n",
"results/hotel_12_ade_best_of.csv\n",
"results/univ_12_ade_best_of.csv\n",
"results/zara1_12_ade_best_of.csv\n",
"results/zara2_12_ade_best_of.csv\n"
]
}
],
"source": [
"# Load Ours: best-of-20 ADE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_12*ade_best_of.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Best of 20 for ETH - Univ\n",
"Ours: 0.4934326048690871\n",
"ADE Best of 20 for ETH - Hotel\n",
"Ours: 0.15716624347036173\n",
"ADE Best of 20 for UCY - Univ\n",
"Ours: 0.2269283937135511\n",
"ADE Best of 20 for UCY - Zara 1\n",
"Ours: 0.1771332219260789\n",
"ADE Best of 20 for UCY - Zara 2\n",
"Ours: 0.13879135677920432\n",
"ADE Best of 20 for Average\n",
"Ours: 0.21259031594000022\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    # This section reports ADE (the *_ade_best_of.csv files), so label it as such.\n",
"    print('ADE Best of 20 for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# KDE Negative Log Likelihood Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_kde_full.csv\n",
"results/hotel_vel_12_kde_full.csv\n",
"results/univ_vel_12_kde_full.csv\n",
"results/zara1_vel_12_kde_full.csv\n",
"results/zara2_vel_12_kde_full.csv\n"
]
}
],
"source": [
"# Load Ours (velocity model): one KDE NLL CSV per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_vel_12*kde_full.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file (presumably written by DataFrame.to_csv)\n",
"        # so every frame has identical columns and concat never has to align/sort them.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KDE NLL for ETH - Univ\n",
"Ours: 3.927605528941531\n",
"KDE NLL for ETH - Hotel\n",
"Ours: -0.6236621180717697\n",
"KDE NLL for UCY - Univ\n",
"Ours: -0.7887800746973215\n",
"KDE NLL for UCY - Zara 1\n",
"Ours: -1.0283870848449408\n",
"KDE NLL for UCY - Zara 2\n",
"Ours: -1.9260379497727524\n",
"KDE NLL for Average\n",
"Ours: -0.8485584209581651\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('KDE NLL for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely FDE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_fde_most_likely.csv\n",
"results/hotel_vel_12_fde_most_likely.csv\n",
"results/univ_vel_12_fde_most_likely.csv\n",
"results/zara1_vel_12_fde_most_likely.csv\n",
"results/zara2_vel_12_fde_most_likely.csv\n"
]
}
],
"source": [
"# Load Ours (velocity model): most-likely FDE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_vel_12*fde_most_likely.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        # Tag with alg_name so the filter below matches the label that gets printed.\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Most Likely for ETH - Univ\n",
"Ours: 1.8566004278950379\n",
"FDE Most Likely for ETH - Hotel\n",
"Ours: 0.6021892704369326\n",
"FDE Most Likely for UCY - Univ\n",
"Ours: 1.2171244891615123\n",
"FDE Most Likely for UCY - Zara 1\n",
"Ours: 0.8159963144196981\n",
"FDE Most Likely for UCY - Zara 2\n",
"Ours: 0.6650947374043866\n",
"FDE Most Likely for Average\n",
"Ours: 1.080938659044269\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('FDE Most Likely for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Most Likely Evaluation ADE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_ade_most_likely.csv\n",
"results/hotel_vel_12_ade_most_likely.csv\n",
"results/univ_vel_12_ade_most_likely.csv\n",
"results/zara1_vel_12_ade_most_likely.csv\n",
"results/zara2_vel_12_ade_most_likely.csv\n"
]
}
],
"source": [
"# Load Ours (velocity model): most-likely ADE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_vel_12*ade_most_likely.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        # Tag with alg_name so the filter below matches the label that gets printed.\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Most Likely for ETH - Univ\n",
"Ours: 0.7925476110083421\n",
"ADE Most Likely for ETH - Hotel\n",
"Ours: 0.27995871067462225\n",
"ADE Most Likely for UCY - Univ\n",
"Ours: 0.46060802655655825\n",
"ADE Most Likely for UCY - Zara 1\n",
"Ours: 0.3145776262029875\n",
"ADE Most Likely for UCY - Zara 2\n",
"Ours: 0.25594473472397583\n",
"ADE Most Likely for Average\n",
"Ours: 0.41564361286805335\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('ADE Most Likely for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation FDE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_fde_best_of.csv\n",
"results/hotel_vel_12_fde_best_of.csv\n",
"results/univ_vel_12_fde_best_of.csv\n",
"results/zara1_vel_12_fde_best_of.csv\n",
"results/zara2_vel_12_fde_best_of.csv\n"
]
}
],
"source": [
"# Load Ours (velocity model): best-of-20 FDE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_vel_12*fde_best_of.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FDE Best of 20 for ETH - Univ\n",
"Ours: 0.9436506373753673\n",
"FDE Best of 20 for ETH - Hotel\n",
"Ours: 0.26871119754090905\n",
"FDE Best of 20 for UCY - Univ\n",
"Ours: 0.45632249026126165\n",
"FDE Best of 20 for UCY - Zara 1\n",
"Ours: 0.33819336402725053\n",
"FDE Best of 20 for UCY - Zara 2\n",
"Ours: 0.278890779270787\n",
"FDE Best of 20 for Average\n",
"Ours: 0.42212505975114584\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    print('FDE Best of 20 for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    # Print under alg_name, like every other metric section.\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
},
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best of 20 Evaluation ADE Attention Radius 3m Velocity"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results/eth_vel_12_ade_best_of.csv\n",
"results/hotel_vel_12_ade_best_of.csv\n",
"results/univ_vel_12_ade_best_of.csv\n",
"results/zara1_vel_12_ade_best_of.csv\n",
"results/zara2_vel_12_ade_best_of.csv\n"
]
}
],
"source": [
"# Load Ours (velocity model): best-of-20 ADE per dataset ('Average' matches no file; it is pooled when printing).\n",
"dataset_dfs = []\n",
"for dataset in dataset_names:\n",
"    for f in sorted(glob.glob(f\"results/{dataset}_vel_12*ade_best_of.csv\")):\n",
"        print(f)\n",
"        # Drop the saved row-index column per file so every frame has identical columns.\n",
"        dataset_df = pd.read_csv(f).drop(columns=['Unnamed: 0'])\n",
"        dataset_df['dataset'] = dataset\n",
"        dataset_df['method'] = alg_name\n",
"        dataset_dfs.append(dataset_df)\n",
"\n",
"# Concatenate once: DataFrame.append was removed in pandas 2.0 and is quadratic when grown in a loop.\n",
"perf_df = pd.concat(dataset_dfs, ignore_index=True, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ADE Best of 20 for ETH - Univ\n",
"Ours: 0.46389447635823006\n",
"ADE Best of 20 for ETH - Hotel\n",
"Ours: 0.14917384309686665\n",
"ADE Best of 20 for UCY - Univ\n",
"Ours: 0.20508554382763963\n",
"ADE Best of 20 for UCY - Zara 1\n",
"Ours: 0.15763210664999946\n",
"ADE Best of 20 for UCY - Zara 2\n",
"Ours: 0.12714387128413132\n",
"ADE Best of 20 for Average\n",
"Ours: 0.19313473176031068\n"
]
}
],
"source": [
"for dataset in dataset_names:\n",
"    # This section reports ADE (the *_vel_12*ade_best_of.csv files), so label it as such.\n",
"    print('ADE Best of 20 for ' + pretty_dataset_name(dataset))\n",
"    # 'Average' pools every row (a sample-weighted mean, not a mean of the per-dataset means).\n",
"    mask = perf_df['method'] == alg_name\n",
"    if dataset != 'Average':\n",
"        mask &= perf_df['dataset'] == dataset\n",
"    print(f\"{alg_name}: {perf_df.loc[mask, 'value'].mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"del perf_df"
]
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3.6 (GenTrajectron)",
|
|||
|
"language": "python",
|
|||
|
"name": "gentraj"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.6.9"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|