update dependencies

This commit is contained in:
Nicklas Hansen
2023-11-25 19:20:52 -08:00
parent 58a95e431b
commit f3139291e2
3 changed files with 6 additions and 28 deletions

View File

@@ -17,9 +17,6 @@ class Ensemble(nn.Module):
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
def modules(self):
return self.vmap.__wrapped__.stateless_model
def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs)

View File

@@ -59,17 +59,13 @@ class WorldModel(nn.Module):
"""
Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings,
and sets the dropout probability to 0 if `mode` is False.
This method also enables/disables gradients for task embeddings.
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.cfg.multitask:
for p in self._task_emb.parameters():
p.requires_grad_(mode)
for m in self._Qs.modules():
if isinstance(m, nn.Dropout):
m.p = self.cfg.dropout if mode else 0
def soft_update_target_Q(self):
"""