modified training step display

This commit is contained in:
NM512
2023-06-24 23:05:45 +09:00
parent f3fe3a872e
commit 34a44916f7
2 changed files with 22 additions and 16 deletions

View File

@@ -213,10 +213,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
env = wrappers.OneHotAction(env)
elif suite == "MemoryMaze":
from envs.memorymaze import MemoryMaze
env = MemoryMaze(task)
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size)
env = wrappers.OneHotAction(env)
else:
@@ -254,17 +256,19 @@ class ProcessEpisodeWrap:
length = len(episode["reward"]) - 1
score = float(episode["reward"].astype(np.float64).sum())
video = episode["image"]
# add new episode
cache[str(filename)] = episode
if mode == "train":
total = 0
step_in_dataset = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if not config.dataset_size or total <= config.dataset_size - length:
total += len(ep["reward"]) - 1
if (
not config.dataset_size
or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
):
step_in_dataset += len(ep["reward"]) - 1
else:
del cache[key]
logger.scalar("dataset_size", total)
# use dataset_size as log step for a condition of envs > 1
log_step = total * config.action_repeat
logger.scalar("dataset_size", step_in_dataset)
elif mode == "eval":
# keep only last item for saving memory
while len(cache) > 1:
@@ -285,7 +289,6 @@ class ProcessEpisodeWrap:
score = sum(cls.eval_scores) / len(cls.eval_scores)
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
episode_num = len(cls.eval_scores)
log_step = logger.step
logger.video(f"{mode}_policy", video[None])
cls.eval_done = True
@@ -295,7 +298,7 @@ class ProcessEpisodeWrap:
logger.scalar(
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
)
logger.write(step=log_step)
logger.write(step=logger.step)
def main(config):