diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 1d7bd73..f748f0c 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -25,7 +25,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) + self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) if cfg.episodic else None self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 7357d48..bdb5225 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module): {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, - {'params': self.model._termination.parameters()}, + {'params': self.model._termination.parameters() if self.cfg.episodic else []}, {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] }