applied formatter
This commit is contained in:
33
dreamer.py
33
dreamer.py
@@ -217,10 +217,12 @@ def make_env(config, mode):
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "crafter":
|
||||
import envs.crafter as crafter
|
||||
|
||||
env = crafter.Crafter(task, config.size)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "minecraft":
|
||||
import envs.minecraft as minecraft
|
||||
|
||||
env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
|
||||
env = wrappers.OneHotAction(env)
|
||||
else:
|
||||
@@ -294,7 +296,15 @@ def main(config):
|
||||
logprob = random_actor.log_prob(action)
|
||||
return {"action": action, "logprob": logprob}, None
|
||||
|
||||
state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
|
||||
state = tools.simulate(
|
||||
random_agent,
|
||||
train_envs,
|
||||
train_eps,
|
||||
config.traindir,
|
||||
logger,
|
||||
limit=config.dataset_size,
|
||||
steps=prefill,
|
||||
)
|
||||
logger.step += prefill * config.action_repeat
|
||||
print(f"Logger: ({logger.step} steps).")
|
||||
|
||||
@@ -317,12 +327,29 @@ def main(config):
|
||||
logger.write()
|
||||
print("Start evaluation.")
|
||||
eval_policy = functools.partial(agent, training=False)
|
||||
tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
|
||||
tools.simulate(
|
||||
eval_policy,
|
||||
eval_envs,
|
||||
eval_eps,
|
||||
config.evaldir,
|
||||
logger,
|
||||
is_eval=True,
|
||||
episodes=config.eval_episode_num,
|
||||
)
|
||||
if config.video_pred_log:
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
print("Start training.")
|
||||
state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
|
||||
state = tools.simulate(
|
||||
agent,
|
||||
train_envs,
|
||||
train_eps,
|
||||
config.traindir,
|
||||
logger,
|
||||
limit=config.dataset_size,
|
||||
steps=config.eval_every,
|
||||
state=state,
|
||||
)
|
||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||
for env in train_envs + eval_envs:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user