Some comments on parameters

This commit is contained in:
Ruben van de Ven 2024-06-20 19:21:44 +02:00
parent 531d61b69a
commit 53c18d9a7b

View file

@ -326,13 +326,18 @@ class PredictionServer:
start = time.time()
with warnings.catch_warnings():
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
# in the OnlineMultimodalGenerativeCVAE (see trajectron.model.online_mgcvae.py) each node's distribution
# is put stored in self.latent.p_dist by OnlineMultimodalGenerativeCVAE.p_z_x(). Type: torch.distributions.OneHotCategorical
# Later sampling in discrete_latent.py: DiscreteLatent.sample_p()
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=self.config.prediction_horizon, # TODO: make variable
num_samples=self.config.num_samples, # TODO: make variable
full_dist=self.config.full_dist,
gmm_mode=self.config.gmm_mode,
z_mode=self.config.z_mode)
full_dist=self.config.full_dist, # "The models full sampled output, where z and y are sampled sequentially"
gmm_mode=self.config.gmm_mode, # "If True: The mode of the Gaussian Mixture Model (GMM) is sampled (see trajectron.model.mgcvae.py)"
z_mode=self.config.z_mode # "Predictions from the models most-likely high-level latent behavior mode" (see trajecton.models.components.discrete_latent:sample_p(most_likely_z=z_mode))
)
end = time.time()
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
1. / (end - start), len(trajectron.nodes),