diff --git a/trajectron/model/online/online_mgcvae.py b/trajectron/model/online/online_mgcvae.py index 1617115..91b13a5 100644 --- a/trajectron/model/online/online_mgcvae.py +++ b/trajectron/model/online/online_mgcvae.py @@ -209,7 +209,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): if self.node not in maps: # This means the node was removed (it is only being kept around because of the edge removal filter). me_params = self.hyperparams['map_encoder'][self.node_type] - self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])) + self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])).to(self.TD['node_history_encoded'].device) else: encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1., (mode == ModeKeys.TRAIN))