Trajectron as a module while supporting old models

This commit is contained in:
Ruben van de Ven 2023-10-09 21:06:36 +02:00
parent dfa1d43f2e
commit 4fa3ce95ee
4 changed files with 30 additions and 19 deletions

View File

@ -5,13 +5,12 @@ import torch
import dill
import random
import pathlib
import evaluation
import numpy as np
import visualization as vis
from argument_parser import args
from model.online.online_trajectron import OnlineTrajectron
from model.model_registrar import ModelRegistrar
from environment import Environment, Scene
import trajectron.visualization as vis
from trajectron.argument_parser import args
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt
if not torch.cuda.is_available() or args.device == 'cpu':

View File

@ -1,11 +1,23 @@
import os
import torch
import torch.nn as nn
import pickle
def get_model_device(model):
return next(model.parameters()).device
class PickleModuleCompatibility:
'''
Migrating Trajectron++ to a module structure
while maintaining compatibility with models generated
before the migration
'''
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == 'model' or module[:6] == 'model.':
module = 'trajectron.' + module
return super().find_class(module, name)
class ModelRegistrar(nn.Module):
def __init__(self, model_dir, device):
@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module):
print('')
print('Loading from ' + save_path)
self.model_dict = torch.load(save_path, map_location=self.device)
self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
print('Loaded!')
print('')

View File

@ -4,13 +4,13 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import defaultdict, Counter
from model.components import *
from model.model_utils import *
from model.dataset import get_relative_robot_traj
import model.dynamics as dynamic_module
from model.mgcvae import MultimodalGenerativeCVAE
from environment.scene_graph import DirectedEdge
from environment.node_type import NodeType
from trajectron.model.components import *
from trajectron.model.model_utils import *
from trajectron.model.dataset import get_relative_robot_traj
import trajectron.model.dynamics as dynamic_module
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
from trajectron.environment.scene_graph import DirectedEdge
from trajectron.environment.node_type import NodeType
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):

View File

@ -1,10 +1,10 @@
import torch
import numpy as np
from collections import Counter
from model.trajectron import Trajectron
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
from model.model_utils import ModeKeys
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
from trajectron.model.trajectron import Trajectron
from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
from trajectron.model.model_utils import ModeKeys
from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
class OnlineTrajectron(Trajectron):