update dependencies
This commit is contained in:
@@ -17,9 +17,6 @@ class Ensemble(nn.Module):
|
||||
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
||||
self._repr = str(modules)
|
||||
|
||||
def modules(self):
|
||||
return self.vmap.__wrapped__.stateless_model
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
||||
|
||||
|
||||
@@ -59,17 +59,13 @@ class WorldModel(nn.Module):
|
||||
"""
|
||||
Enables/disables gradient tracking of Q-networks.
|
||||
Avoids unnecessary computation during policy optimization.
|
||||
This method also enables/disables gradients for task embeddings,
|
||||
and sets the dropout probability to 0 if `mode` is False.
|
||||
This method also enables/disables gradients for task embeddings.
|
||||
"""
|
||||
for p in self._Qs.parameters():
|
||||
p.requires_grad_(mode)
|
||||
if self.cfg.multitask:
|
||||
for p in self._task_emb.parameters():
|
||||
p.requires_grad_(mode)
|
||||
for m in self._Qs.modules():
|
||||
if isinstance(m, nn.Dropout):
|
||||
m.p = self.cfg.dropout if mode else 0
|
||||
|
||||
def soft_update_target_Q(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user