fix bug when using envs > 1
This commit is contained in:
59
dreamer.py
59
dreamer.py
@@ -39,6 +39,7 @@ class Dreamer(nn.Module):
|
||||
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
|
||||
self._metrics = {}
|
||||
self._step = count_steps(config.traindir)
|
||||
self._update_count = 0
|
||||
# Schedules.
|
||||
config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
|
||||
x, self._step
|
||||
@@ -75,14 +76,16 @@ class Dreamer(nn.Module):
|
||||
state[0][key][i] *= mask[i]
|
||||
for i in range(len(state[1])):
|
||||
state[1][i] *= mask[i]
|
||||
if training and self._should_train(step):
|
||||
if training:
|
||||
steps = (
|
||||
self._config.pretrain
|
||||
if self._should_pretrain()
|
||||
else self._config.train_steps
|
||||
else self._should_train(step)
|
||||
)
|
||||
for _ in range(steps):
|
||||
self._train(next(self._dataset))
|
||||
self._update_count += 1
|
||||
self._metrics["update_count"] = self._update_count
|
||||
if self._should_log(step):
|
||||
for name, values in self._metrics.items():
|
||||
self._logger.scalar(name, float(np.mean(values)))
|
||||
@@ -227,6 +230,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
||||
class ProcessEpisodeWrap:
|
||||
eval_scores = []
|
||||
eval_lengths = []
|
||||
last_step_at_eval = -1
|
||||
eval_done = False
|
||||
|
||||
@classmethod
|
||||
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
|
||||
@@ -238,20 +243,6 @@ class ProcessEpisodeWrap:
|
||||
score = float(episode["reward"].astype(np.float64).sum())
|
||||
video = episode["image"]
|
||||
cache[str(filename)] = episode
|
||||
if mode == "eval":
|
||||
cls.eval_scores.append(score)
|
||||
cls.eval_lengths.append(length)
|
||||
# save when enought number of episodes are stored
|
||||
if len(cls.eval_scores) < config.eval_episode_num:
|
||||
return
|
||||
else:
|
||||
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
||||
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
||||
episode_num = len(cls.eval_scores)
|
||||
cls.eval_scores = []
|
||||
cls.eval_lengths = []
|
||||
cache.clear()
|
||||
|
||||
if mode == "train":
|
||||
total = 0
|
||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||
@@ -260,16 +251,39 @@ class ProcessEpisodeWrap:
|
||||
else:
|
||||
del cache[key]
|
||||
logger.scalar("dataset_size", total)
|
||||
# use dataset_size as log step for a condition of envs > 1
|
||||
log_step = total * config.action_repeat
|
||||
elif mode == "eval":
|
||||
# start saving episodes for evaluation
|
||||
if cls.last_step_at_eval != logger.step:
|
||||
# keep only last item
|
||||
while len(cache) > 1:
|
||||
# FIFO
|
||||
cache.popitem()
|
||||
cls.eval_scores = []
|
||||
cls.eval_lengths = []
|
||||
cls.eval_done = False
|
||||
cls.last_step_at_eval = logger.step
|
||||
|
||||
cls.eval_scores.append(score)
|
||||
cls.eval_lengths.append(length)
|
||||
# ignore if number of eval episodes exceeds eval_episode_num
|
||||
if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
|
||||
return
|
||||
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
||||
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
||||
episode_num = len(cls.eval_scores)
|
||||
log_step = logger.step
|
||||
logger.video(f"{mode}_policy", video[None])
|
||||
cls.eval_done = True
|
||||
|
||||
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
|
||||
logger.scalar(f"{mode}_return", score)
|
||||
logger.scalar(f"{mode}_length", length)
|
||||
logger.scalar(
|
||||
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
||||
)
|
||||
if mode == "eval" or config.expl_gifs:
|
||||
# only last video in eval videos is preservad
|
||||
logger.video(f"{mode}_policy", video[None])
|
||||
logger.write()
|
||||
logger.write(step=log_step)
|
||||
|
||||
|
||||
def main(config):
|
||||
@@ -329,7 +343,6 @@ def main(config):
|
||||
return {"action": action, "logprob": logprob}, None
|
||||
|
||||
tools.simulate(random_agent, train_envs, prefill)
|
||||
tools.simulate(random_agent, eval_envs, episodes=1)
|
||||
logger.step = config.action_repeat * count_steps(config.traindir)
|
||||
|
||||
print("Simulate agent.")
|
||||
@@ -345,10 +358,10 @@ def main(config):
|
||||
while agent._step < config.steps:
|
||||
logger.write()
|
||||
print("Start evaluation.")
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
eval_policy = functools.partial(agent, training=False)
|
||||
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
print("Start training.")
|
||||
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
|
||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||
|
||||
Reference in New Issue
Block a user