From 38b31a5d726256c7e518fd341ddab8f2761b1445 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 2 May 2025 16:24:02 -0700 Subject: [PATCH] only instantiate termination pred head if episodic=true --- tdmpc2/common/world_model.py | 2 +- tdmpc2/tdmpc2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 1d7bd73..f748f0c 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -25,7 +25,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) + self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) if cfg.episodic else None self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 7357d48..bdb5225 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module): {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, - {'params': self.model._termination.parameters()}, + {'params': self.model._termination.parameters() if self.cfg.episodic else []}, {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] }