modified based on author's implementation
This commit is contained in:
88
dreamer.py
88
dreamer.py
@@ -22,6 +22,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch import distributions as torchd
|
||||
|
||||
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
|
||||
@@ -31,7 +32,8 @@ class Dreamer(nn.Module):
|
||||
self._config = config
|
||||
self._logger = logger
|
||||
self._should_log = tools.Every(config.log_every)
|
||||
self._should_train = tools.Every(config.train_every)
|
||||
batch_steps = config.batch_size * config.batch_length
|
||||
self._should_train = tools.Every(batch_steps / config.train_ratio)
|
||||
self._should_pretrain = tools.Once()
|
||||
self._should_reset = tools.Every(config.reset_every)
|
||||
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
|
||||
@@ -146,16 +148,17 @@ class Dreamer(nn.Module):
|
||||
post, context, mets = self._wm._train(data)
|
||||
metrics.update(mets)
|
||||
start = post
|
||||
# start['deter'] (16, 64, 512)
|
||||
if self._config.pred_discount: # Last step could be terminal.
|
||||
start = {k: v[:, :-1] for k, v in post.items()}
|
||||
context = {k: v[:, :-1] for k, v in context.items()}
|
||||
start = {k: v[:-1] for k, v in post.items()}
|
||||
context = {k: v[:-1] for k, v in context.items()}
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](
|
||||
self._wm.dynamics.get_feat(s)
|
||||
).mode()
|
||||
metrics.update(self._task_behavior._train(start, reward)[-1])
|
||||
if self._config.expl_behavior != "greedy":
|
||||
if self._config.pred_discount:
|
||||
data = {k: v[:, :-1] for k, v in data.items()}
|
||||
data = {k: v[:-1] for k, v in data.items()}
|
||||
mets = self._expl_behavior.train(start, context, data)[-1]
|
||||
metrics.update({"expl_" + key: value for key, value in mets.items()})
|
||||
for name, value in metrics.items():
|
||||
@@ -205,7 +208,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
||||
if (mode == "train") or (mode == "eval"):
|
||||
callbacks = [
|
||||
functools.partial(
|
||||
process_episode, config, logger, mode, train_eps, eval_eps
|
||||
ProcessEpisodeWrap.process_episode,
|
||||
config,
|
||||
logger,
|
||||
mode,
|
||||
train_eps,
|
||||
eval_eps,
|
||||
)
|
||||
]
|
||||
env = wrappers.CollectDataset(env, callbacks)
|
||||
@@ -213,31 +221,51 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
||||
return env
|
||||
|
||||
|
||||
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
|
||||
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
|
||||
cache = dict(train=train_eps, eval=eval_eps)[mode]
|
||||
filename = tools.save_episodes(directory, [episode])[0]
|
||||
length = len(episode["reward"]) - 1
|
||||
score = float(episode["reward"].astype(np.float64).sum())
|
||||
video = episode["image"]
|
||||
if mode == "eval":
|
||||
cache.clear()
|
||||
if mode == "train" and config.dataset_size:
|
||||
total = 0
|
||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||
if total <= config.dataset_size - length:
|
||||
total += len(ep["reward"]) - 1
|
||||
class ProcessEpisodeWrap:
|
||||
eval_scores = []
|
||||
eval_lengths = []
|
||||
|
||||
@classmethod
|
||||
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
|
||||
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
|
||||
cache = dict(train=train_eps, eval=eval_eps)[mode]
|
||||
# this saved episodes is given as train_eps or eval_eps from next call
|
||||
filename = tools.save_episodes(directory, [episode])[0]
|
||||
length = len(episode["reward"]) - 1
|
||||
score = float(episode["reward"].astype(np.float64).sum())
|
||||
video = episode["image"]
|
||||
cache[str(filename)] = episode
|
||||
if mode == "eval":
|
||||
cls.eval_scores.append(score)
|
||||
cls.eval_lengths.append(length)
|
||||
# save when enought number of episodes are stored
|
||||
if len(cls.eval_scores) < config.eval_episode_num:
|
||||
return
|
||||
else:
|
||||
del cache[key]
|
||||
logger.scalar("dataset_size", total + length)
|
||||
cache[str(filename)] = episode
|
||||
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
|
||||
logger.scalar(f"{mode}_return", score)
|
||||
logger.scalar(f"{mode}_length", length)
|
||||
logger.scalar(f"{mode}_episodes", len(cache))
|
||||
if mode == "eval" or config.expl_gifs:
|
||||
logger.video(f"{mode}_policy", video[None])
|
||||
logger.write()
|
||||
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
||||
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
||||
episode_num = len(cls.eval_scores)
|
||||
cls.eval_scores = []
|
||||
cls.eval_lengths = []
|
||||
|
||||
if mode == "train" and config.dataset_size:
|
||||
total = 0
|
||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||
if total <= config.dataset_size - length:
|
||||
total += len(ep["reward"]) - 1
|
||||
else:
|
||||
del cache[key]
|
||||
logger.scalar("dataset_size", total + length)
|
||||
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
|
||||
logger.scalar(f"{mode}_return", score)
|
||||
logger.scalar(f"{mode}_length", length)
|
||||
logger.scalar(
|
||||
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
||||
)
|
||||
if mode == "eval" or config.expl_gifs:
|
||||
# only last video in eval videos is preservad
|
||||
logger.video(f"{mode}_policy", video[None])
|
||||
logger.write()
|
||||
|
||||
|
||||
def main(config):
|
||||
@@ -315,7 +343,7 @@ def main(config):
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
eval_policy = functools.partial(agent, training=False)
|
||||
tools.simulate(eval_policy, eval_envs, episodes=1)
|
||||
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
|
||||
print("Start training.")
|
||||
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
|
||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||
|
||||
Reference in New Issue
Block a user