added save and load for optimizers
This commit is contained in:
12
dreamer.py
12
dreamer.py
@@ -319,8 +319,10 @@ def main(config):
|
||||
train_dataset,
|
||||
).to(config.device)
|
||||
agent.requires_grad_(requires_grad=False)
|
||||
if (logdir / "latest_model.pt").exists():
|
||||
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
||||
if (logdir / "latest.pt").exists():
|
||||
checkpoint = torch.load(logdir / "latest.pt")
|
||||
agent.load_state_dict(checkpoint["agent_state_dict"])
|
||||
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
|
||||
agent._should_pretrain._once = False
|
||||
|
||||
# make sure eval will be executed once after config.steps
|
||||
@@ -352,7 +354,11 @@ def main(config):
|
||||
steps=config.eval_every,
|
||||
state=state,
|
||||
)
|
||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||
items_to_save = {
|
||||
"agent_state_dict": agent.state_dict(),
|
||||
"optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
|
||||
}
|
||||
torch.save(items_to_save, logdir / "latest.pt")
|
||||
for env in train_envs + eval_envs:
|
||||
try:
|
||||
env.close()
|
||||
|
||||
Reference in New Issue
Block a user