Some comments on parameters
This commit is contained in:
parent
531d61b69a
commit
53c18d9a7b
1 changed files with 8 additions and 3 deletions
|
@ -326,13 +326,18 @@ class PredictionServer:
|
|||
start = time.time()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
|
||||
|
||||
# in the OnlineMultimodalGenerativeCVAE (see trajectron.model.online_mgcvae.py) each node's distribution
|
||||
# is put stored in self.latent.p_dist by OnlineMultimodalGenerativeCVAE.p_z_x(). Type: torch.distributions.OneHotCategorical
|
||||
# Later sampling in discrete_latent.py: DiscreteLatent.sample_p()
|
||||
dists, preds = trajectron.incremental_forward(input_dict,
|
||||
maps,
|
||||
prediction_horizon=self.config.prediction_horizon, # TODO: make variable
|
||||
num_samples=self.config.num_samples, # TODO: make variable
|
||||
full_dist=self.config.full_dist,
|
||||
gmm_mode=self.config.gmm_mode,
|
||||
z_mode=self.config.z_mode)
|
||||
full_dist=self.config.full_dist, # "The model’s full sampled output, where z and y are sampled sequentially"
|
||||
gmm_mode=self.config.gmm_mode, # "If True: The mode of the Gaussian Mixture Model (GMM) is sampled (see trajectron.model.mgcvae.py)"
|
||||
z_mode=self.config.z_mode # "Predictions from the model’s most-likely high-level latent behavior mode" (see trajecton.models.components.discrete_latent:sample_p(most_likely_z=z_mode))
|
||||
)
|
||||
end = time.time()
|
||||
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
||||
1. / (end - start), len(trajectron.nodes),
|
||||
|
|
Loading…
Reference in a new issue