Trajectron as a module while supporting old models
This commit is contained in:
parent
dfa1d43f2e
commit
4fa3ce95ee
4 changed files with 30 additions and 19 deletions
|
@ -5,13 +5,12 @@ import torch
|
|||
import dill
|
||||
import random
|
||||
import pathlib
|
||||
import evaluation
|
||||
import numpy as np
|
||||
import visualization as vis
|
||||
from argument_parser import args
|
||||
from model.online.online_trajectron import OnlineTrajectron
|
||||
from model.model_registrar import ModelRegistrar
|
||||
from environment import Environment, Scene
|
||||
import trajectron.visualization as vis
|
||||
from trajectron.argument_parser import args
|
||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
||||
from trajectron.model.model_registrar import ModelRegistrar
|
||||
from trajectron.environment import Environment, Scene
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if not torch.cuda.is_available() or args.device == 'cpu':
|
|
@ -1,11 +1,23 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import pickle
|
||||
|
||||
def get_model_device(model):
|
||||
return next(model.parameters()).device
|
||||
|
||||
class PickleModuleCompatibility:
|
||||
'''
|
||||
Migrating Trajectron++ to a module structure
|
||||
while maintaining compatibility with models generated
|
||||
before the migration
|
||||
'''
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
if module == 'model' or module[:6] == 'model.':
|
||||
module = 'trajectron.' + module
|
||||
return super().find_class(module, name)
|
||||
|
||||
|
||||
class ModelRegistrar(nn.Module):
|
||||
def __init__(self, model_dir, device):
|
||||
|
@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module):
|
|||
|
||||
print('')
|
||||
print('Loading from ' + save_path)
|
||||
self.model_dict = torch.load(save_path, map_location=self.device)
|
||||
self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
|
||||
print('Loaded!')
|
||||
print('')
|
||||
|
||||
|
|
|
@ -4,13 +4,13 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from collections import defaultdict, Counter
|
||||
from model.components import *
|
||||
from model.model_utils import *
|
||||
from model.dataset import get_relative_robot_traj
|
||||
import model.dynamics as dynamic_module
|
||||
from model.mgcvae import MultimodalGenerativeCVAE
|
||||
from environment.scene_graph import DirectedEdge
|
||||
from environment.node_type import NodeType
|
||||
from trajectron.model.components import *
|
||||
from trajectron.model.model_utils import *
|
||||
from trajectron.model.dataset import get_relative_robot_traj
|
||||
import trajectron.model.dynamics as dynamic_module
|
||||
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
|
||||
from trajectron.environment.scene_graph import DirectedEdge
|
||||
from trajectron.environment.node_type import NodeType
|
||||
|
||||
|
||||
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from model.trajectron import Trajectron
|
||||
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||||
from model.model_utils import ModeKeys
|
||||
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||||
from trajectron.model.trajectron import Trajectron
|
||||
from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||||
from trajectron.model.model_utils import ModeKeys
|
||||
from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||||
|
||||
|
||||
class OnlineTrajectron(Trajectron):
|
||||
|
|
Loading…
Reference in a new issue