only instantiate termination pred head if episodic=true
This commit is contained in:
@@ -25,7 +25,7 @@ class WorldModel(nn.Module):
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) if cfg.episodic else None
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
|
||||
@@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module):
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
{'params': self.model._reward.parameters()},
|
||||
{'params': self.model._termination.parameters()},
|
||||
{'params': self.model._termination.parameters() if self.cfg.episodic else []},
|
||||
{'params': self.model._Qs.parameters()},
|
||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user