added save and load for optimizers

This commit is contained in:
NM512
2023-09-27 09:15:37 +09:00
parent 16635df3e4
commit d3576c5a98
3 changed files with 43 additions and 5 deletions

View File

@@ -70,7 +70,7 @@ class Plan2Explore(nn.Module):
[networks.MLP(**kw) for _ in range(config.disag_models)]
)
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
self._model_opt = tools.Optimizer(
self._expl_opt = tools.Optimizer(
"explorer",
self.parameters(),
config.model_lr,
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
[torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
)
loss = -torch.mean(likes)
metrics = self._model_opt(loss, self.parameters())
metrics = self._expl_opt(loss, self.parameters())
return metrics