Trajectron as a module while supporting old models
This commit is contained in:
parent
dfa1d43f2e
commit
4fa3ce95ee
4 changed files with 30 additions and 19 deletions
|
@ -5,13 +5,12 @@ import torch
|
||||||
import dill
|
import dill
|
||||||
import random
|
import random
|
||||||
import pathlib
|
import pathlib
|
||||||
import evaluation
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import visualization as vis
|
import trajectron.visualization as vis
|
||||||
from argument_parser import args
|
from trajectron.argument_parser import args
|
||||||
from model.online.online_trajectron import OnlineTrajectron
|
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
||||||
from model.model_registrar import ModelRegistrar
|
from trajectron.model.model_registrar import ModelRegistrar
|
||||||
from environment import Environment, Scene
|
from trajectron.environment import Environment, Scene
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
if not torch.cuda.is_available() or args.device == 'cpu':
|
if not torch.cuda.is_available() or args.device == 'cpu':
|
|
@ -1,11 +1,23 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import pickle
|
||||||
|
|
||||||
def get_model_device(model):
|
def get_model_device(model):
|
||||||
return next(model.parameters()).device
|
return next(model.parameters()).device
|
||||||
|
|
||||||
|
class PickleModuleCompatibility:
|
||||||
|
'''
|
||||||
|
Migrating Trajectron++ to a module structure
|
||||||
|
while maintaining compatibility with models generated
|
||||||
|
before the migration
|
||||||
|
'''
|
||||||
|
class Unpickler(pickle.Unpickler):
|
||||||
|
def find_class(self, module, name):
|
||||||
|
if module == 'model' or module[:6] == 'model.':
|
||||||
|
module = 'trajectron.' + module
|
||||||
|
return super().find_class(module, name)
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistrar(nn.Module):
|
class ModelRegistrar(nn.Module):
|
||||||
def __init__(self, model_dir, device):
|
def __init__(self, model_dir, device):
|
||||||
|
@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module):
|
||||||
|
|
||||||
print('')
|
print('')
|
||||||
print('Loading from ' + save_path)
|
print('Loading from ' + save_path)
|
||||||
self.model_dict = torch.load(save_path, map_location=self.device)
|
self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
|
||||||
print('Loaded!')
|
print('Loaded!')
|
||||||
print('')
|
print('')
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,13 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
from model.components import *
|
from trajectron.model.components import *
|
||||||
from model.model_utils import *
|
from trajectron.model.model_utils import *
|
||||||
from model.dataset import get_relative_robot_traj
|
from trajectron.model.dataset import get_relative_robot_traj
|
||||||
import model.dynamics as dynamic_module
|
import trajectron.model.dynamics as dynamic_module
|
||||||
from model.mgcvae import MultimodalGenerativeCVAE
|
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
|
||||||
from environment.scene_graph import DirectedEdge
|
from trajectron.environment.scene_graph import DirectedEdge
|
||||||
from environment.node_type import NodeType
|
from trajectron.environment.node_type import NodeType
|
||||||
|
|
||||||
|
|
||||||
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from model.trajectron import Trajectron
|
from trajectron.model.trajectron import Trajectron
|
||||||
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||||||
from model.model_utils import ModeKeys
|
from trajectron.model.model_utils import ModeKeys
|
||||||
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||||||
|
|
||||||
|
|
||||||
class OnlineTrajectron(Trajectron):
|
class OnlineTrajectron(Trajectron):
|
||||||
|
|
Loading…
Reference in a new issue