Work around bug with map encoding
This commit is contained in:
parent
e5151e1a32
commit
dd64170ac3
1 changed files with 1 additions and 1 deletions
|
@ -209,7 +209,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
if self.node not in maps:
|
if self.node not in maps:
|
||||||
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
||||||
me_params = self.hyperparams['map_encoder'][self.node_type]
|
me_params = self.hyperparams['map_encoder'][self.node_type]
|
||||||
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
|
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])).to(self.TD['node_history_encoded'].device)
|
||||||
else:
|
else:
|
||||||
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
||||||
(mode == ModeKeys.TRAIN))
|
(mode == ModeKeys.TRAIN))
|
||||||
|
|
Loading…
Reference in a new issue