applied formatter

This commit is contained in:
NM512
2023-07-23 22:02:06 +09:00
parent afa5ab988d
commit 12ed21e06d
10 changed files with 506 additions and 440 deletions

View File

@@ -217,10 +217,12 @@ def make_env(config, mode):
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size)
env = wrappers.OneHotAction(env)
elif suite == "minecraft":
import envs.minecraft as minecraft
env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
env = wrappers.OneHotAction(env)
else:
@@ -294,7 +296,15 @@ def main(config):
logprob = random_actor.log_prob(action)
return {"action": action, "logprob": logprob}, None
state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
state = tools.simulate(
random_agent,
train_envs,
train_eps,
config.traindir,
logger,
limit=config.dataset_size,
steps=prefill,
)
logger.step += prefill * config.action_repeat
print(f"Logger: ({logger.step} steps).")
@@ -317,12 +327,29 @@ def main(config):
logger.write()
print("Start evaluation.")
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
tools.simulate(
eval_policy,
eval_envs,
eval_eps,
config.evaldir,
logger,
is_eval=True,
episodes=config.eval_episode_num,
)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
print("Start training.")
state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
state = tools.simulate(
agent,
train_envs,
train_eps,
config.traindir,
logger,
limit=config.dataset_size,
steps=config.eval_every,
state=state,
)
torch.save(agent.state_dict(), logdir / "latest_model.pt")
for env in train_envs + eval_envs:
try: