separated cache management of episode from env

This commit is contained in:
NM512
2023-07-22 19:22:41 +09:00
parent 88514ec022
commit 9ca5082da3
3 changed files with 194 additions and 167 deletions

View File

@@ -36,7 +36,8 @@ class Dreamer(nn.Module):
self._should_reset = tools.Every(config.reset_every)
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
self._metrics = {}
self._step = count_steps(config.traindir)
# this is update step
self._step = logger.step // config.action_repeat
self._update_count = 0
# Schedules.
config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
@@ -226,82 +227,23 @@ def make_env(config, logger, mode, train_eps, eval_eps):
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)
env = wrappers.SelectAction(env, key="action")
if (mode == "train") or (mode == "eval"):
callbacks = [
functools.partial(
ProcessEpisodeWrap.process_episode,
config,
logger,
mode,
train_eps,
eval_eps,
)
]
env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
env = wrappers.UUID(env)
# if (mode == "train") or (mode == "eval"):
# callbacks = [
# functools.partial(
# ProcessEpisodeWrap.process_episode,
# config,
# logger,
# mode,
# train_eps,
# eval_eps,
# )
# ]
# env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
env = wrappers.RewardObs(env)
return env
class ProcessEpisodeWrap:
eval_scores = []
eval_lengths = []
last_step_at_eval = -1
eval_done = False
@classmethod
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
# this saved episodes is given as train_eps or eval_eps from next call
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode["reward"]) - 1
score = float(episode["reward"].astype(np.float64).sum())
video = episode["image"]
# add new episode
cache[str(filename)] = episode
if mode == "train":
step_in_dataset = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if (
not config.dataset_size
or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
):
step_in_dataset += len(ep["reward"]) - 1
else:
del cache[key]
logger.scalar("dataset_size", step_in_dataset)
elif mode == "eval":
# keep only last item for saving memory
while len(cache) > 1:
# FIFO
cache.popitem()
# start counting scores for evaluation
if cls.last_step_at_eval != logger.step:
cls.eval_scores = []
cls.eval_lengths = []
cls.eval_done = False
cls.last_step_at_eval = logger.step
cls.eval_scores.append(score)
cls.eval_lengths.append(length)
# ignore if number of eval episodes exceeds eval_episode_num
if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
return
score = sum(cls.eval_scores) / len(cls.eval_scores)
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
episode_num = len(cls.eval_scores)
logger.video(f"{mode}_policy", video[None])
cls.eval_done = True
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
logger.scalar(f"{mode}_return", score)
logger.scalar(f"{mode}_length", length)
logger.scalar(
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
)
logger.write(step=logger.step)
def main(config):
logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / "train_eps"
@@ -316,6 +258,7 @@ def main(config):
config.traindir.mkdir(parents=True, exist_ok=True)
config.evaldir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
# step in logger is environmental step
logger = tools.Logger(logdir, config.action_repeat * step)
print("Create envs.")
@@ -357,8 +300,9 @@ def main(config):
logprob = random_actor.log_prob(action)
return {"action": action, "logprob": logprob}, None
state = tools.simulate(random_agent, train_envs, prefill)
logger.step = config.action_repeat * count_steps(config.traindir)
state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
logger.step += prefill * config.action_repeat
print(f"Logger: ({logger.step} steps).")
print("Simulate agent.")
train_dataset = make_dataset(train_eps, config)
@@ -379,12 +323,12 @@ def main(config):
logger.write()
print("Start evaluation.")
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
print("Start training.")
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
torch.save(agent.state_dict(), logdir / "latest_model.pt")
for env in train_envs + eval_envs:
try: