modifications for minecraft

This commit is contained in:
NM512
2023-08-05 21:13:57 +09:00
parent 8c471e12d6
commit 8571cf656a
3 changed files with 21 additions and 16 deletions

View File

@@ -230,7 +230,8 @@ def make_env(config, mode):
env = wrappers.TimeLimit(env, config.time_limit)
env = wrappers.SelectAction(env, key="action")
env = wrappers.UUID(env)
env = wrappers.RewardObs(env)
if suite == "minecraft":
env = wrappers.RewardObs(env)
return env
@@ -326,20 +327,21 @@ def main(config):
# make sure eval will be executed once after config.steps
while agent._step < config.steps + config.eval_every:
logger.write()
print("Start evaluation.")
eval_policy = functools.partial(agent, training=False)
tools.simulate(
eval_policy,
eval_envs,
eval_eps,
config.evaldir,
logger,
is_eval=True,
episodes=config.eval_episode_num,
)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
if config.eval_episode_num > 0:
print("Start evaluation.")
eval_policy = functools.partial(agent, training=False)
tools.simulate(
eval_policy,
eval_envs,
eval_eps,
config.evaldir,
logger,
is_eval=True,
episodes=config.eval_episode_num,
)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
print("Start training.")
state = tools.simulate(
agent,