Compare commits

...

1 commit

Author SHA1 Message Date
Ruben van de Ven
51d6157af9 Refactor Trajectron++ to work as module and expose training command 2023-12-06 12:28:56 +01:00
5 changed files with 574 additions and 393 deletions

931
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,17 @@
[tool.poetry]
name = "Trajectron-plus-plus"
version = "0.1.0"
description = "Predict trajectories for anomaly detection"
version = "0.1.1"
description = "This repository contains the code for Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)."
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
readme = "README.md"
license = "MIT"
packages = [
{ include = "trajectron" },
]
[tool.poetry.scripts]
trajectron_train = "trajectron.train:main"
[tool.poetry.dependencies]
python = "^3.9,<3.12"
@ -28,8 +36,6 @@ matplotlib = "^3.5"
#scikit-learn = "0.22.1"
#filterpy = "^1.4.5"
tqdm = "^4.65.0"
#ipywidgets = "^8.0.6"
#deep-sort-realtime = "^1.3.2"
scipy = "^1.11.3"
pandas = "^2.1.1"
orjson = "^3.9.7"
@ -41,7 +47,7 @@ notebook = "^7.0.4"
scikit-learn = "^1.3.1"
seaborn = "^0.13.0"
setuptools = "^68.2.2"
tensorboard = "1.14"
tensorboard = "^2.11"
tensorboardx = "^2.6.2.2"

View file

@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--conf",
help="path to json config file for hyperparameters",
type=str,
default='../config/config.json')
default='./config/config.json')
parser.add_argument("--debug",
help="disable all disk writing processes.",
@ -97,7 +97,7 @@ parser.add_argument('--no_edge_encoding',
parser.add_argument("--data_dir",
help="what dir to look in for data",
type=str,
default='../experiments/processed')
default='./experiments/processed')
parser.add_argument("--train_data_dict",
help="what file to load for training data",

View file

@ -3,7 +3,7 @@ from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation
from scipy.stats import gaussian_kde
from trajectron.utils import prediction_output_to_trajectories
import trajectron.visualization
import trajectron.visualization as visualization
from matplotlib import pyplot as plt

View file

@ -9,14 +9,14 @@ import random
import pathlib
import warnings
from tqdm import tqdm
import visualization
import evaluation
import trajectron.visualization as visualization
import trajectron.evaluation as evaluation
import matplotlib.pyplot as plt
from argument_parser import args
from model.trajectron import Trajectron
from model.model_registrar import ModelRegistrar
from model.model_utils import cyclical_lr
from model.dataset import EnvironmentDataset, collate
from trajectron.argument_parser import args
from trajectron.model.trajectron import Trajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.model.model_utils import cyclical_lr
from trajectron.model.dataset import EnvironmentDataset, collate
from tensorboardX import SummaryWriter
# torch.autograd.set_detect_anomaly(True)