{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import numpy as np\n", "import seaborn as sns\n", "import pandas as pd\n", "\n", "from collections import OrderedDict" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def pretty_dataset_name(dataset_name):\n", " if dataset_name == 'eth':\n", " return 'ETH - Univ'\n", " elif dataset_name == 'hotel':\n", " return 'ETH - Hotel'\n", " elif dataset_name == 'univ':\n", " return 'UCY - Univ'\n", " elif dataset_name == 'zara1':\n", " return 'UCY - Zara 1'\n", " elif dataset_name == 'zara2':\n", " return 'UCY - Zara 2'\n", " else:\n", " return dataset_name" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset_names = ['eth', 'hotel', 'univ', 'zara1', 'zara2', 'Average']\n", "alg_name = \"Ours\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Displacement Error Analysis" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "prior_work_fse_results = {\n", " 'ETH - Univ': OrderedDict([('Linear', 2.94), ('Vanilla LSTM', 2.41), ('Social LSTM', 2.35), ('Social Attention', 3.74)]),\n", " 'ETH - Hotel': OrderedDict([('Linear', 0.72), ('Vanilla LSTM', 1.91), ('Social LSTM', 1.76), ('Social Attention', 2.64)]),\n", " 'UCY - Univ': OrderedDict([('Linear', 1.59), ('Vanilla LSTM', 1.31), ('Social LSTM', 1.40), ('Social Attention', 0.52)]),\n", " 'UCY - Zara 1': OrderedDict([('Linear', 1.21), ('Vanilla LSTM', 0.88), ('Social LSTM', 1.00), ('Social Attention', 2.13)]),\n", " 'UCY - Zara 2': OrderedDict([('Linear', 1.48), ('Vanilla LSTM', 1.11), ('Social LSTM', 1.17), ('Social Attention', 3.92)]),\n", " 'Average': OrderedDict([('Linear', 1.59), ('Vanilla LSTM', 1.52), ('Social LSTM', 1.54), ('Social Attention', 2.59)])\n", "}\n", "\n", "\n", "# These are for a prediction horizon of 12 timesteps.\n", "prior_work_ade_results = {\n", " 'ETH - Univ': OrderedDict([('Linear', 1.33), ('Vanilla LSTM', 1.09), ('Social LSTM', 1.09), ('Social Attention', 0.39)]),\n", " 'ETH - Hotel': OrderedDict([('Linear', 0.39), ('Vanilla LSTM', 0.86), ('Social LSTM', 0.79), ('Social Attention', 0.29)]),\n", " 'UCY - Univ': OrderedDict([('Linear', 0.82), ('Vanilla LSTM', 0.61), ('Social LSTM', 0.67), ('Social Attention', 0.20)]),\n", " 'UCY - Zara 1': OrderedDict([('Linear', 0.62), ('Vanilla LSTM', 0.41), ('Social LSTM', 0.47), ('Social Attention', 0.30)]),\n", " 'UCY - Zara 2': OrderedDict([('Linear', 0.77), ('Vanilla LSTM', 0.52), ('Social LSTM', 0.56), ('Social Attention', 0.33)]),\n", " 'Average': OrderedDict([('Linear', 0.79), ('Vanilla LSTM', 0.70), ('Social LSTM', 0.72), ('Social Attention', 0.30)])\n", "}\n", "\n", "linestyles = ['--', '-.', '-', ':']" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "mean_markers = 'X'\n", "marker_size = 7\n", "line_colors = ['#1f78b4','#33a02c','#fb9a99','#e31a1c']\n", "area_colors = ['#80CBE5','#ABCB51', '#F05F78']\n", "area_rgbs = list()\n", "for c in area_colors:\n", " area_rgbs.append([int(c[i:i+2], 16) for i in (1, 3, 5)])" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "results/eth_attention_radius_3_fde_most_likely.csv\n", "results/hotel_attention_radius_3_fde_most_likely.csv\n", "results/univ_attention_radius_3_fde_most_likely.csv\n", "results/zara1_attention_radius_3_fde_most_likely.csv\n", "results/zara2_attention_radius_3_fde_most_likely.csv\n" ] }, { "data": { "text/html": [ "
\n", " | error_value | \n", "error_type | \n", "type | \n", "dataset | \n", "method | \n", "
---|---|---|---|---|---|
0 | \n", "0.242668 | \n", "fde | \n", "ml | \n", "eth | \n", "Ours | \n", "
1 | \n", "0.158331 | \n", "fde | \n", "ml | \n", "eth | \n", "Ours | \n", "
2 | \n", "0.095482 | \n", "fde | \n", "ml | \n", "eth | \n", "Ours | \n", "
3 | \n", "1.069288 | \n", "fde | \n", "ml | \n", "eth | \n", "Ours | \n", "
4 | \n", "1.734359 | \n", "fde | \n", "ml | \n", "eth | \n", "Ours | \n", "
\n", " | dataset | \n", "method | \n", "run | \n", "node | \n", "sample | \n", "error_type | \n", "error_value | \n", "
---|---|---|---|---|---|---|---|
2186000 | \n", "hotel | \n", "sgan | \n", "0 | \n", "Pedestrian/0 | \n", "0 | \n", "fde | \n", "4.045972 | \n", "
2186001 | \n", "hotel | \n", "sgan | \n", "0 | \n", "Pedestrian/0 | \n", "1 | \n", "fde | \n", "3.717624 | \n", "
2186002 | \n", "hotel | \n", "sgan | \n", "0 | \n", "Pedestrian/0 | \n", "2 | \n", "fde | \n", "5.378286 | \n", "
2186003 | \n", "hotel | \n", "sgan | \n", "0 | \n", "Pedestrian/0 | \n", "3 | \n", "fde | \n", "4.215567 | \n", "
2186004 | \n", "hotel | \n", "sgan | \n", "0 | \n", "Pedestrian/0 | \n", "4 | \n", "fde | \n", "4.663851 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
77099995 | \n", "zara2 | \n", "sgan | \n", "99 | \n", "Pedestrian/35 | \n", "1995 | \n", "fde | \n", "0.620136 | \n", "
77099996 | \n", "zara2 | \n", "sgan | \n", "99 | \n", "Pedestrian/35 | \n", "1996 | \n", "fde | \n", "0.681608 | \n", "
77099997 | \n", "zara2 | \n", "sgan | \n", "99 | \n", "Pedestrian/35 | \n", "1997 | \n", "fde | \n", "0.860765 | \n", "
77099998 | \n", "zara2 | \n", "sgan | \n", "99 | \n", "Pedestrian/35 | \n", "1998 | \n", "fde | \n", "0.545317 | \n", "
77099999 | \n", "zara2 | \n", "sgan | \n", "99 | \n", "Pedestrian/35 | \n", "1999 | \n", "fde | \n", "1.027843 | \n", "
25700000 rows × 7 columns
\n", "