Refactor Trajectron++ to work as module and expose training command
This commit is contained in:
parent
4fa3ce95ee
commit
51d6157af9
5 changed files with 574 additions and 393 deletions
931
poetry.lock
generated
931
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,9 +1,17 @@
|
|||
[tool.poetry]
|
||||
name = "Trajectron-plus-plus"
|
||||
version = "0.1.0"
|
||||
description = "Predict trajectories for anomaly detection"
|
||||
version = "0.1.1"
|
||||
description = "This repository contains the code for Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)."
|
||||
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
packages = [
|
||||
{ include = "trajectron" },
|
||||
]
|
||||
|
||||
[tool.poetry.scripts]
|
||||
trajectron_train = "trajectron.train:main"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9,<3.12"
|
||||
|
@ -28,8 +36,6 @@ matplotlib = "^3.5"
|
|||
#scikit-learn = "0.22.1"
|
||||
#filterpy = "^1.4.5"
|
||||
tqdm = "^4.65.0"
|
||||
#ipywidgets = "^8.0.6"
|
||||
#deep-sort-realtime = "^1.3.2"
|
||||
scipy = "^1.11.3"
|
||||
pandas = "^2.1.1"
|
||||
orjson = "^3.9.7"
|
||||
|
@ -41,7 +47,7 @@ notebook = "^7.0.4"
|
|||
scikit-learn = "^1.3.1"
|
||||
seaborn = "^0.13.0"
|
||||
setuptools = "^68.2.2"
|
||||
tensorboard = "1.14"
|
||||
tensorboard = "^2.11"
|
||||
tensorboardx = "^2.6.2.2"
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
|
|||
parser.add_argument("--conf",
|
||||
help="path to json config file for hyperparameters",
|
||||
type=str,
|
||||
default='../config/config.json')
|
||||
default='./config/config.json')
|
||||
|
||||
parser.add_argument("--debug",
|
||||
help="disable all disk writing processes.",
|
||||
|
@ -97,7 +97,7 @@ parser.add_argument('--no_edge_encoding',
|
|||
parser.add_argument("--data_dir",
|
||||
help="what dir to look in for data",
|
||||
type=str,
|
||||
default='../experiments/processed')
|
||||
default='./experiments/processed')
|
||||
|
||||
parser.add_argument("--train_data_dict",
|
||||
help="what file to load for training data",
|
||||
|
|
|
@ -3,7 +3,7 @@ from scipy.interpolate import RectBivariateSpline
|
|||
from scipy.ndimage import binary_dilation
|
||||
from scipy.stats import gaussian_kde
|
||||
from trajectron.utils import prediction_output_to_trajectories
|
||||
import trajectron.visualization
|
||||
import trajectron.visualization as visualization
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
|
|
|
@ -9,14 +9,14 @@ import random
|
|||
import pathlib
|
||||
import warnings
|
||||
from tqdm import tqdm
|
||||
import visualization
|
||||
import evaluation
|
||||
import trajectron.visualization as visualization
|
||||
import trajectron.evaluation as evaluation
|
||||
import matplotlib.pyplot as plt
|
||||
from argument_parser import args
|
||||
from model.trajectron import Trajectron
|
||||
from model.model_registrar import ModelRegistrar
|
||||
from model.model_utils import cyclical_lr
|
||||
from model.dataset import EnvironmentDataset, collate
|
||||
from trajectron.argument_parser import args
|
||||
from trajectron.model.trajectron import Trajectron
|
||||
from trajectron.model.model_registrar import ModelRegistrar
|
||||
from trajectron.model.model_utils import cyclical_lr
|
||||
from trajectron.model.dataset import EnvironmentDataset, collate
|
||||
from tensorboardX import SummaryWriter
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
|
|
Loading…
Reference in a new issue