Refactor Trajectron++ to work as module and expose training command

This commit is contained in:
Ruben van de Ven 2023-12-06 12:28:56 +01:00
parent 4fa3ce95ee
commit 51d6157af9
5 changed files with 574 additions and 393 deletions

931
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,17 @@
[tool.poetry] [tool.poetry]
name = "Trajectron-plus-plus" name = "Trajectron-plus-plus"
version = "0.1.0" version = "0.1.1"
description = "Predict trajectories for anomaly detection" description = "This repository contains the code for Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)."
authors = ["Ruben van de Ven <git@rubenvandeven.com>"] authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
readme = "README.md" readme = "README.md"
license = "MIT"
packages = [
{ include = "trajectron" },
]
[tool.poetry.scripts]
trajectron_train = "trajectron.train:main"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9,<3.12" python = "^3.9,<3.12"
@ -28,8 +36,6 @@ matplotlib = "^3.5"
#scikit-learn = "0.22.1" #scikit-learn = "0.22.1"
#filterpy = "^1.4.5" #filterpy = "^1.4.5"
tqdm = "^4.65.0" tqdm = "^4.65.0"
#ipywidgets = "^8.0.6"
#deep-sort-realtime = "^1.3.2"
scipy = "^1.11.3" scipy = "^1.11.3"
pandas = "^2.1.1" pandas = "^2.1.1"
orjson = "^3.9.7" orjson = "^3.9.7"
@ -41,7 +47,7 @@ notebook = "^7.0.4"
scikit-learn = "^1.3.1" scikit-learn = "^1.3.1"
seaborn = "^0.13.0" seaborn = "^0.13.0"
setuptools = "^68.2.2" setuptools = "^68.2.2"
tensorboard = "1.14" tensorboard = "^2.11"
tensorboardx = "^2.6.2.2" tensorboardx = "^2.6.2.2"

View file

@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--conf", parser.add_argument("--conf",
help="path to json config file for hyperparameters", help="path to json config file for hyperparameters",
type=str, type=str,
default='../config/config.json') default='./config/config.json')
parser.add_argument("--debug", parser.add_argument("--debug",
help="disable all disk writing processes.", help="disable all disk writing processes.",
@ -97,7 +97,7 @@ parser.add_argument('--no_edge_encoding',
parser.add_argument("--data_dir", parser.add_argument("--data_dir",
help="what dir to look in for data", help="what dir to look in for data",
type=str, type=str,
default='../experiments/processed') default='./experiments/processed')
parser.add_argument("--train_data_dict", parser.add_argument("--train_data_dict",
help="what file to load for training data", help="what file to load for training data",

View file

@ -3,7 +3,7 @@ from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation from scipy.ndimage import binary_dilation
from scipy.stats import gaussian_kde from scipy.stats import gaussian_kde
from trajectron.utils import prediction_output_to_trajectories from trajectron.utils import prediction_output_to_trajectories
import trajectron.visualization import trajectron.visualization as visualization
from matplotlib import pyplot as plt from matplotlib import pyplot as plt

View file

@ -9,14 +9,14 @@ import random
import pathlib import pathlib
import warnings import warnings
from tqdm import tqdm from tqdm import tqdm
import visualization import trajectron.visualization as visualization
import evaluation import trajectron.evaluation as evaluation
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from argument_parser import args from trajectron.argument_parser import args
from model.trajectron import Trajectron from trajectron.model.trajectron import Trajectron
from model.model_registrar import ModelRegistrar from trajectron.model.model_registrar import ModelRegistrar
from model.model_utils import cyclical_lr from trajectron.model.model_utils import cyclical_lr
from model.dataset import EnvironmentDataset, collate from trajectron.model.dataset import EnvironmentDataset, collate
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
# torch.autograd.set_detect_anomaly(True) # torch.autograd.set_detect_anomaly(True)