From fb5c21557a64f1c8e633e115e163bedd46b417d7 Mon Sep 17 00:00:00 2001 From: NM512 Date: Sun, 12 Feb 2023 22:35:25 +0900 Subject: [PATCH] Initial Commit --- .gitignore | 134 +++++++++ LICENSE | 21 ++ README.md | 33 +++ configs.yaml | 136 +++++++++ dreamer.py | 343 +++++++++++++++++++++++ exploration.py | 108 ++++++++ models.py | 509 ++++++++++++++++++++++++++++++++++ networks.py | 631 ++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 12 + tools.py | 700 +++++++++++++++++++++++++++++++++++++++++++++++ wrappers.py | 419 ++++++++++++++++++++++++++++ 11 files changed, 3046 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 configs.yaml create mode 100644 dreamer.py create mode 100644 exploration.py create mode 100644 models.py create mode 100644 networks.py create mode 100644 requirements.txt create mode 100644 tools.py create mode 100644 wrappers.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3cbc5d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# +*.sh +logdir* +vis_* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..65c5fe5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 NM512 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..1810177 --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Dreamer-v3 Pytorch +Pytorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1) + +![image_walker_walk](https://user-images.githubusercontent.com/70328564/218313056-c1158a7d-10f3-4052-b19d-6d642ee4850b.gif) + +## Instructions +Get dependencies: +``` +pip install -r requirements.txt +``` +Train the agent: +``` +python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG +``` +Monitor results: +``` +tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG +``` +## Evaluation Results +work-in-progress + +![Fig](https://user-images.githubusercontent.com/70328564/218313252-3d42193a-a7c4-4fd1-bd0a-df4f4f5787d5.png) + +## Awesome Environments used for testing: +- Deepmind control suite: https://github.com/deepmind/dm_control +- will be added soon + +## Acknowledgments +This code is heavily inspired by the following works: +- danijar's Dreamer-v2 tensorflow implementation: https://github.com/danijar/dreamerv2 +- jsikyoon's Dreamer-v2 pytorch implementation: https://github.com/jsikyoon/dreamer-torch +- RajGhugare19's Dreamer-v2 pytorch implementation: https://github.com/RajGhugare19/dreamerv2 +- denisyarats's DrQ-v2 original implementation: https://github.com/facebookresearch/drqv2 diff --git a/configs.yaml b/configs.yaml new file mode 100644 index 0000000..3ce6351 --- /dev/null +++ b/configs.yaml @@ -0,0 +1,136 @@ +defaults: + + logdir: null + traindir: null + evaldir: null + offline_traindir: '' + offline_evaldir: '' + seed: 0 + steps: 5e5 + eval_every: 1e4 + log_every: 1e4 + reset_every: 0 + #gpu_growth: True + device: 'cuda:0' + precision: 16 + debug: False + expl_gifs: False + + # Environment + task: 'dmc_walker_walk' + size: [64, 64] + envs: 1 + action_repeat: 2 + time_limit: 1000 + grayscale: False + prefill: 2500 + eval_noise: 0.0 + reward_trans: 'symlog' + obs_trans: 'normalize' + critic_trans: 'symlog' + reward_EMA: True + + # Model + dyn_cell: 'gru_layer_norm' + dyn_hidden: 512 + dyn_deter: 512 + dyn_stoch: 32 + dyn_discrete: 32 + dyn_input_layers: 1 + dyn_output_layers: 1 + dyn_rec_depth: 1 + dyn_shared: False + dyn_mean_act: 'none' + dyn_std_act: 'sigmoid2' + dyn_min_std: 0.1 + dyn_temp_post: True + grad_heads: ['image', 'reward', 'discount'] + units: 256 + reward_layers: 2 + discount_layers: 2 + value_layers: 2 + actor_layers: 2 + act: 'SiLU' + norm: 'LayerNorm' + cnn_depth: 32 + encoder_kernels: [3, 3, 3, 3] + decoder_kernels: [3, 3, 3, 3] + # changed here + value_head: 'twohot' + reward_head: 'twohot' + kl_lscale: '0.1' + kl_rscale: '0.5' + kl_free: '1.0' + kl_forward: False + pred_discount: True + discount_scale: 1.0 + reward_scale: 1.0 + weight_decay: 0.0 + unimix_ratio: 0.01 + + # Training + batch_size: 16 + batch_length: 64 + train_every: 5 + train_steps: 1 + pretrain: 100 + model_lr: 1e-4 + opt_eps: 1e-8 + grad_clip: 1000 + value_lr: 3e-5 + actor_lr: 3e-5 + ac_opt_eps: 1e-5 + value_grad_clip: 100 + actor_grad_clip: 100 + dataset_size: 0 + oversample_ends: False + slow_value_target: True + slow_actor_target: True + slow_target_update: 50 + slow_target_fraction: 0.01 + opt: 'adam' + + # Behavior. + discount: 0.997 + discount_lambda: 0.95 + imag_horizon: 15 + imag_gradient: 'dynamics' + imag_gradient_mix: '0.1' + imag_sample: True + actor_dist: 'trunc_normal' + actor_entropy: '3e-4' + actor_state_entropy: 0.0 + actor_init_std: 1.0 + actor_min_std: 0.1 + actor_disc: 5 + actor_temp: 0.1 + actor_outscale: 0.0 + expl_amount: 0.0 + eval_state_mean: False + collect_dyn_sample: True + behavior_stop_grad: True + value_decay: 0.0 + future_entropy: False + + # Exploration + expl_behavior: 'greedy' + expl_until: 0 + expl_extr_scale: 0.0 + expl_intr_scale: 1.0 + disag_target: 'stoch' + disag_log: True + disag_models: 10 + disag_offset: 1 + disag_layers: 4 + disag_units: 400 + disag_action_cond: False + +debug: + + debug: True + pretrain: 1 + prefill: 1 + train_steps: 1 + batch_size: 10 + batch_length: 20 + diff --git a/dreamer.py b/dreamer.py new file mode 100644 index 0000000..b855e64 --- /dev/null +++ b/dreamer.py @@ -0,0 +1,343 @@ +import argparse +import collections +import functools +import os +import pathlib +import sys +import warnings + +os.environ["MUJOCO_GL"] = "egl" + +import numpy as np +import ruamel.yaml as yaml + +sys.path.append(str(pathlib.Path(__file__).parent)) + +import exploration as expl +import models +import tools +import wrappers + +import torch +from torch import nn +from torch import distributions as torchd + +to_np = lambda x: x.detach().cpu().numpy() + + +class Dreamer(nn.Module): + def __init__(self, config, logger, dataset): + super(Dreamer, self).__init__() + self._config = config + self._logger = logger + self._should_log = tools.Every(config.log_every) + self._should_train = tools.Every(config.train_every) + self._should_pretrain = tools.Once() + self._should_reset = tools.Every(config.reset_every) + self._should_expl = tools.Until(int(config.expl_until / config.action_repeat)) + self._metrics = {} + self._step = count_steps(config.traindir) + # Schedules. + config.actor_entropy = lambda x=config.actor_entropy: tools.schedule( + x, self._step + ) + config.actor_state_entropy = ( + lambda x=config.actor_state_entropy: tools.schedule(x, self._step) + ) + config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule( + x, self._step + ) + self._dataset = dataset + self._wm = models.WorldModel(self._step, config) + self._task_behavior = models.ImagBehavior( + config, self._wm, config.behavior_stop_grad + ) + reward = lambda f, s, a: self._wm.heads["reward"](f).mean + self._expl_behavior = dict( + greedy=lambda: self._task_behavior, + random=lambda: expl.Random(config), + plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward), + )[config.expl_behavior]() + + def __call__(self, obs, reset, state=None, reward=None, training=True): + step = self._step + if self._should_reset(step): + state = None + if state is not None and reset.any(): + mask = 1 - reset + for key in state[0].keys(): + for i in range(state[0][key].shape[0]): + state[0][key][i] *= mask[i] + for i in range(len(state[1])): + state[1][i] *= mask[i] + if training and self._should_train(step): + steps = ( + self._config.pretrain + if self._should_pretrain() + else self._config.train_steps + ) + for _ in range(steps): + self._train(next(self._dataset)) + if self._should_log(step): + for name, values in self._metrics.items(): + self._logger.scalar(name, float(np.mean(values))) + self._metrics[name] = [] + openl = self._wm.video_pred(next(self._dataset)) + self._logger.video("train_openl", to_np(openl)) + self._logger.write(fps=True) + + policy_output, state = self._policy(obs, state, training) + + if training: + self._step += len(reset) + self._logger.step = self._config.action_repeat * self._step + return policy_output, state + + def _policy(self, obs, state, training): + if state is None: + batch_size = len(obs["image"]) + latent = self._wm.dynamics.initial(len(obs["image"])) + action = torch.zeros((batch_size, self._config.num_actions)).to( + self._config.device + ) + else: + latent, action = state + embed = self._wm.encoder(self._wm.preprocess(obs)) + latent, _ = self._wm.dynamics.obs_step( + latent, action, embed, self._config.collect_dyn_sample + ) + if self._config.eval_state_mean: + latent["stoch"] = latent["mean"] + feat = self._wm.dynamics.get_feat(latent) + if not training: + actor = self._task_behavior.actor(feat) + action = actor.mode() + elif self._should_expl(self._step): + actor = self._expl_behavior.actor(feat) + action = actor.sample() + else: + actor = self._task_behavior.actor(feat) + action = actor.sample() + logprob = actor.log_prob(action) + latent = {k: v.detach() for k, v in latent.items()} + action = action.detach() + if self._config.actor_dist == "onehot_gumble": + action = torch.one_hot( + torch.argmax(action, dim=-1), self._config.num_actions + ) + action = self._exploration(action, training) + policy_output = {"action": action, "logprob": logprob} + state = (latent, action) + return policy_output, state + + def _exploration(self, action, training): + amount = self._config.expl_amount if training else self._config.eval_noise + if amount == 0: + return action + if "onehot" in self._config.actor_dist: + probs = amount / self._config.num_actions + (1 - amount) * action + return tools.OneHotDist(probs=probs).sample() + else: + return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1) + raise NotImplementedError(self._config.action_noise) + + def _train(self, data): + metrics = {} + post, context, mets = self._wm._train(data) + metrics.update(mets) + start = post + if self._config.pred_discount: # Last step could be terminal. + start = {k: v[:, :-1] for k, v in post.items()} + context = {k: v[:, :-1] for k, v in context.items()} + reward = lambda f, s, a: self._wm.heads["reward"]( + self._wm.dynamics.get_feat(s) + ).mode() + metrics.update(self._task_behavior._train(start, reward)[-1]) + if self._config.expl_behavior != "greedy": + if self._config.pred_discount: + data = {k: v[:, :-1] for k, v in data.items()} + mets = self._expl_behavior.train(start, context, data)[-1] + metrics.update({"expl_" + key: value for key, value in mets.items()}) + for name, value in metrics.items(): + if not name in self._metrics.keys(): + self._metrics[name] = [value] + else: + self._metrics[name].append(value) + + +def count_steps(folder): + return sum(int(str(n).split("-")[-1][:-4]) - 1 for n in folder.glob("*.npz")) + + +def make_dataset(episodes, config): + generator = tools.sample_episodes( + episodes, config.batch_length, config.oversample_ends + ) + dataset = tools.from_generator(generator, config.batch_size) + return dataset + + +def make_env(config, logger, mode, train_eps, eval_eps): + suite, task = config.task.split("_", 1) + if suite == "dmc": + env = wrappers.DeepMindControl(task, config.action_repeat, config.size) + env = wrappers.NormalizeActions(env) + elif suite == "atari": + env = wrappers.Atari( + task, + config.action_repeat, + config.size, + grayscale=config.grayscale, + life_done=False and ("train" in mode), + sticky_actions=True, + all_actions=True, + ) + env = wrappers.OneHotAction(env) + elif suite == "dmlab": + env = wrappers.DeepMindLabyrinth( + task, mode if "train" in mode else "test", config.action_repeat + ) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit) + env = wrappers.SelectAction(env, key="action") + if (mode == "train") or (mode == "eval"): + callbacks = [ + functools.partial( + process_episode, config, logger, mode, train_eps, eval_eps + ) + ] + env = wrappers.CollectDataset(env, callbacks) + env = wrappers.RewardObs(env) + return env + + +def process_episode(config, logger, mode, train_eps, eval_eps, episode): + directory = dict(train=config.traindir, eval=config.evaldir)[mode] + cache = dict(train=train_eps, eval=eval_eps)[mode] + filename = tools.save_episodes(directory, [episode])[0] + length = len(episode["reward"]) - 1 + score = float(episode["reward"].astype(np.float64).sum()) + video = episode["image"] + if mode == "eval": + cache.clear() + if mode == "train" and config.dataset_size: + total = 0 + for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): + if total <= config.dataset_size - length: + total += len(ep["reward"]) - 1 + else: + del cache[key] + logger.scalar("dataset_size", total + length) + cache[str(filename)] = episode + print(f"{mode.title()} episode has {length} steps and return {score:.1f}.") + logger.scalar(f"{mode}_return", score) + logger.scalar(f"{mode}_length", length) + logger.scalar(f"{mode}_episodes", len(cache)) + if mode == "eval" or config.expl_gifs: + logger.video(f"{mode}_policy", video[None]) + logger.write() + + +def main(config): + logdir = pathlib.Path(config.logdir).expanduser() + config.traindir = config.traindir or logdir / "train_eps" + config.evaldir = config.evaldir or logdir / "eval_eps" + config.steps //= config.action_repeat + config.eval_every //= config.action_repeat + config.log_every //= config.action_repeat + config.time_limit //= config.action_repeat + config.act = getattr(torch.nn, config.act) + config.norm = getattr(torch.nn, config.norm) + + print("Logdir", logdir) + logdir.mkdir(parents=True, exist_ok=True) + config.traindir.mkdir(parents=True, exist_ok=True) + config.evaldir.mkdir(parents=True, exist_ok=True) + step = count_steps(config.traindir) + logger = tools.Logger(logdir, config.action_repeat * step) + + print("Create envs.") + if config.offline_traindir: + directory = config.offline_traindir.format(**vars(config)) + else: + directory = config.traindir + train_eps = tools.load_episodes(directory, limit=config.dataset_size) + if config.offline_evaldir: + directory = config.offline_evaldir.format(**vars(config)) + else: + directory = config.evaldir + eval_eps = tools.load_episodes(directory, limit=1) + make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps) + train_envs = [make("train") for _ in range(config.envs)] + eval_envs = [make("eval") for _ in range(config.envs)] + acts = train_envs[0].action_space + config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] + + if not config.offline_traindir: + prefill = max(0, config.prefill - count_steps(config.traindir)) + print(f"Prefill dataset ({prefill} steps).") + if hasattr(acts, "discrete"): + random_actor = tools.OneHotDist( + torch.zeros_like(torch.Tensor(acts.low))[None] + ) + else: + random_actor = torchd.independent.Independent( + torchd.uniform.Uniform( + torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None] + ), + 1, + ) + + def random_agent(o, d, s, r): + action = random_actor.sample() + logprob = random_actor.log_prob(action) + return {"action": action, "logprob": logprob}, None + + tools.simulate(random_agent, train_envs, prefill) + tools.simulate(random_agent, eval_envs, episodes=1) + logger.step = config.action_repeat * count_steps(config.traindir) + + print("Simulate agent.") + train_dataset = make_dataset(train_eps, config) + eval_dataset = make_dataset(eval_eps, config) + agent = Dreamer(config, logger, train_dataset).to(config.device) + agent.requires_grad_(requires_grad=False) + if (logdir / "latest_model.pt").exists(): + agent.load_state_dict(torch.load(logdir / "latest_model.pt")) + agent._should_pretrain._once = False + + state = None + while agent._step < config.steps: + logger.write() + print("Start evaluation.") + video_pred = agent._wm.video_pred(next(eval_dataset)) + logger.video("eval_openl", to_np(video_pred)) + eval_policy = functools.partial(agent, training=False) + tools.simulate(eval_policy, eval_envs, episodes=1) + print("Start training.") + state = tools.simulate(agent, train_envs, config.eval_every, state=state) + torch.save(agent.state_dict(), logdir / "latest_model.pt") + for env in train_envs + eval_envs: + try: + env.close() + except Exception: + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--configs", nargs="+", required=True) + args, remaining = parser.parse_known_args() + configs = yaml.safe_load( + (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text() + ) + defaults = {} + for name in args.configs: + defaults.update(configs[name]) + parser = argparse.ArgumentParser() + for key, value in sorted(defaults.items(), key=lambda x: x[0]): + arg_type = tools.args_type(value) + parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) + main(parser.parse_args(remaining)) diff --git a/exploration.py b/exploration.py new file mode 100644 index 0000000..77cbc3b --- /dev/null +++ b/exploration.py @@ -0,0 +1,108 @@ +import torch +from torch import nn +from torch import distributions as torchd + +import models +import networks +import tools + + +class Random(nn.Module): + def __init__(self, config): + self._config = config + + def actor(self, feat): + shape = feat.shape[:-1] + [self._config.num_actions] + if self._config.actor_dist == "onehot": + return tools.OneHotDist(torch.zeros(shape)) + else: + ones = torch.ones(shape) + return tools.ContDist(torchd.uniform.Uniform(-ones, ones)) + + def train(self, start, context): + return None, {} + + +# class Plan2Explore(tools.Module): +class Plan2Explore(nn.Module): + def __init__(self, config, world_model, reward=None): + self._config = config + self._reward = reward + self._behavior = models.ImagBehavior(config, world_model) + self.actor = self._behavior.actor + stoch_size = config.dyn_stoch + if config.dyn_discrete: + stoch_size *= config.dyn_discrete + size = { + "embed": 32 * config.cnn_depth, + "stoch": stoch_size, + "deter": config.dyn_deter, + "feat": config.dyn_stoch + config.dyn_deter, + }[self._config.disag_target] + kw = dict( + inp_dim=config.dyn_stoch, # pytorch version + shape=size, + layers=config.disag_layers, + units=config.disag_units, + act=config.act, + ) + self._networks = [networks.DenseHead(**kw) for _ in range(config.disag_models)] + self._opt = tools.optimizer( + config.opt, + self.parameters(), + config.model_lr, + config.opt_eps, + config.weight_decay, + ) + # self._opt = tools.Optimizer( + # 'ensemble', config.model_lr, config.opt_eps, config.grad_clip, + # config.weight_decay, opt=config.opt) + + def train(self, start, context, data): + metrics = {} + stoch = start["stoch"] + if self._config.dyn_discrete: + stoch = tf.reshape( + stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1]) + ) + target = { + "embed": context["embed"], + "stoch": stoch, + "deter": start["deter"], + "feat": context["feat"], + }[self._config.disag_target] + inputs = context["feat"] + if self._config.disag_action_cond: + inputs = tf.concat([inputs, data["action"]], -1) + metrics.update(self._train_ensemble(inputs, target)) + metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1]) + return None, metrics + + def _intrinsic_reward(self, feat, state, action): + inputs = feat + if self._config.disag_action_cond: + inputs = tf.concat([inputs, action], -1) + preds = [head(inputs, tf.float32).mean() for head in self._networks] + disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1) + if self._config.disag_log: + disag = tf.math.log(disag) + reward = self._config.expl_intr_scale * disag + if self._config.expl_extr_scale: + reward += tf.cast( + self._config.expl_extr_scale * self._reward(feat, state, action), + tf.float32, + ) + return reward + + def _train_ensemble(self, inputs, targets): + if self._config.disag_offset: + targets = targets[:, self._config.disag_offset :] + inputs = inputs[:, : -self._config.disag_offset] + targets = tf.stop_gradient(targets) + inputs = tf.stop_gradient(inputs) + with tf.GradientTape() as tape: + preds = [head(inputs) for head in self._networks] + likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds] + loss = -tf.cast(tf.reduce_sum(likes), tf.float32) + metrics = self._opt(tape, loss, self._networks) + return metrics diff --git a/models.py b/models.py new file mode 100644 index 0000000..7489a17 --- /dev/null +++ b/models.py @@ -0,0 +1,509 @@ +import copy +import torch +from torch import nn +import numpy as np +from PIL import ImageColor, Image, ImageDraw, ImageFont + +import networks +import tools + +to_np = lambda x: x.detach().cpu().numpy() + + +def symlog(x): + return torch.sign(x) * torch.log(torch.abs(x) + 1.0) + + +def symexp(x): + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) + + +class RewardEMA(object): + """running mean and std""" + + def __init__(self, device, alpha=1e-2): + self.device = device + self.scale = torch.zeros((1,)).to(device) + self.alpha = alpha + self.range = torch.tensor([0.05, 0.95]).to(device) + + def __call__(self, x): + flat_x = torch.flatten(x.detach()) + x_quantile = torch.quantile(input=flat_x, q=self.range) + scale = x_quantile[1] - x_quantile[0] + new_scale = self.alpha * scale + (1 - self.alpha) * self.scale + self.scale = new_scale + return x / torch.clip(self.scale, min=1.0) + + +class WorldModel(nn.Module): + def __init__(self, step, config): + super(WorldModel, self).__init__() + self._step = step + self._use_amp = True if config.precision == 16 else False + self._config = config + self.encoder = networks.ConvEncoder( + config.grayscale, + config.cnn_depth, + config.act, + config.norm, + config.encoder_kernels, + ) + if config.size[0] == 64 and config.size[1] == 64: + embed_size = ( + (64 // 2 ** (len(config.encoder_kernels))) ** 2 + * config.cnn_depth + * 2 ** (len(config.encoder_kernels) - 1) + ) + else: + raise NotImplemented(f"{config.size} is not applicable now") + self.dynamics = networks.RSSM( + config.dyn_stoch, + config.dyn_deter, + config.dyn_hidden, + config.dyn_input_layers, + config.dyn_output_layers, + config.dyn_rec_depth, + config.dyn_shared, + config.dyn_discrete, + config.act, + config.norm, + config.dyn_mean_act, + config.dyn_std_act, + config.dyn_temp_post, + config.dyn_min_std, + config.dyn_cell, + config.unimix_ratio, + config.num_actions, + embed_size, + config.device, + ) + self.heads = nn.ModuleDict() + channels = 1 if config.grayscale else 3 + shape = (channels,) + config.size + if config.dyn_discrete: + feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter + else: + feat_size = config.dyn_stoch + config.dyn_deter + self.heads["image"] = networks.ConvDecoder( + feat_size, # pytorch version + config.cnn_depth, + config.act, + config.norm, + shape, + config.decoder_kernels, + ) + if config.reward_head == "twohot": + self.heads["reward"] = networks.DenseHead( + feat_size, # pytorch version + (255,), + config.reward_layers, + config.units, + config.act, + config.norm, + dist=config.reward_head, + ) + else: + self.heads["reward"] = networks.DenseHead( + feat_size, # pytorch version + [], + config.reward_layers, + config.units, + config.act, + config.norm, + dist=config.reward_head, + ) + # added this + self.heads["reward"].apply(tools.weight_init) + if config.pred_discount: + self.heads["discount"] = networks.DenseHead( + feat_size, # pytorch version + [], + config.discount_layers, + config.units, + config.act, + config.norm, + dist="binary", + ) + for name in config.grad_heads: + assert name in self.heads, name + self._model_opt = tools.Optimizer( + "model", + self.parameters(), + config.model_lr, + config.opt_eps, + config.grad_clip, + config.weight_decay, + opt=config.opt, + use_amp=self._use_amp, + ) + self._scales = dict(reward=config.reward_scale, discount=config.discount_scale) + + def _train(self, data): + # action (batch_size, batch_length, act_dim) + # image (batch_size, batch_length, h, w, ch) + # reward (batch_size, batch_length) + # discount (batch_size, batch_length) + data = self.preprocess(data) + + with tools.RequiresGrad(self): + with torch.cuda.amp.autocast(self._use_amp): + embed = self.encoder(data) + post, prior = self.dynamics.observe(embed, data["action"]) + kl_free = tools.schedule(self._config.kl_free, self._step) + kl_lscale = tools.schedule(self._config.kl_lscale, self._step) + kl_rscale = tools.schedule(self._config.kl_rscale, self._step) + kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss( + post, prior, self._config.kl_forward, kl_free, kl_lscale, kl_rscale + ) + losses = {} + likes = {} + for name, head in self.heads.items(): + grad_head = name in self._config.grad_heads + feat = self.dynamics.get_feat(post) + feat = feat if grad_head else feat.detach() + pred = head(feat) + # if name == 'image': + # losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum') + like = pred.log_prob(data[name]) + likes[name] = like + losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) + model_loss = sum(losses.values()) + kl_loss + metrics = self._model_opt(model_loss, self.parameters()) + + metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()}) + metrics["kl_free"] = kl_free + metrics["kl_lscale"] = kl_lscale + metrics["kl_rscale"] = kl_rscale + metrics["loss_lhs"] = to_np(loss_lhs) + metrics["loss_rhs"] = to_np(loss_rhs) + metrics["kl"] = to_np(torch.mean(kl_value)) + with torch.cuda.amp.autocast(self._use_amp): + metrics["prior_ent"] = to_np( + torch.mean(self.dynamics.get_dist(prior).entropy()) + ) + metrics["post_ent"] = to_np( + torch.mean(self.dynamics.get_dist(post).entropy()) + ) + context = dict( + embed=embed, + feat=self.dynamics.get_feat(post), + kl=kl_value, + postent=self.dynamics.get_dist(post).entropy(), + ) + post = {k: v.detach() for k, v in post.items()} + return post, context, metrics + + def preprocess(self, obs): + obs = obs.copy() + if self._config.obs_trans == "normalize": + obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5 + elif self._config.obs_trans == "identity": + obs["image"] = torch.Tensor(obs["image"]) + elif self._config.obs_trans == "symlog": + obs["image"] = symlog(torch.Tensor(obs["image"])) + else: + raise NotImplemented(f"{self._config.reward_trans} is not implemented") + if self._config.reward_trans == "tanh": + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1) + elif self._config.reward_trans == "identity": + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1) + elif self._config.reward_trans == "symlog": + obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1) + else: + raise NotImplemented(f"{self._config.reward_trans} is not implemented") + if "discount" in obs: + obs["discount"] *= self._config.discount + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1) + obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()} + return obs + + def video_pred(self, data): + data = self.preprocess(data) + embed = self.encoder(data) + + states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5]) + recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6] + reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6] + init = {k: v[:, -1] for k, v in states.items()} + prior = self.dynamics.imagine(data["action"][:6, 5:], init) + openl = self.heads["image"](self.dynamics.get_feat(prior)).mode() + reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode() + # observed image is given until 5 steps + model = torch.cat([recon[:, :5], openl], 1) + if self._config.obs_trans == "normalize": + truth = data["image"][:6] + 0.5 + model += 0.5 + elif self._config.obs_trans == "symlog": + truth = symexp(data["image"][:6]) / 255.0 + model = symexp(model) / 255.0 + error = (model - truth + 1) / 2 + + return torch.cat([truth, model, error], 2) + + +class ImagBehavior(nn.Module): + def __init__(self, config, world_model, stop_grad_actor=True, reward=None): + super(ImagBehavior, self).__init__() + self._use_amp = True if config.precision == 16 else False + self._config = config + self._world_model = world_model + self._stop_grad_actor = stop_grad_actor + self._reward = reward + if config.dyn_discrete: + feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter + else: + feat_size = config.dyn_stoch + config.dyn_deter + self.actor = networks.ActionHead( + feat_size, # pytorch version + config.num_actions, + config.actor_layers, + config.units, + config.act, + config.norm, + config.actor_dist, + config.actor_init_std, + config.actor_min_std, + config.actor_dist, + config.actor_temp, + config.actor_outscale, + ) # action_dist -> action_disc? + if config.value_head == "twohot": + self.value = networks.DenseHead( + feat_size, # pytorch version + (255,), + config.value_layers, + config.units, + config.act, + config.norm, + config.value_head, + ) + else: + self.value = networks.DenseHead( + feat_size, # pytorch version + [], + config.value_layers, + config.units, + config.act, + config.norm, + config.value_head, + ) + self.value.apply(tools.weight_init) + if config.slow_value_target or config.slow_actor_target: + self._slow_value = copy.deepcopy(self.value) + self._updates = 0 + kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) + self._actor_opt = tools.Optimizer( + "actor", + self.actor.parameters(), + config.actor_lr, + config.ac_opt_eps, + config.actor_grad_clip, + **kw, + ) + self._value_opt = tools.Optimizer( + "value", + self.value.parameters(), + config.value_lr, + config.ac_opt_eps, + config.value_grad_clip, + **kw, + ) + if self._config.reward_EMA: + self.reward_ema = RewardEMA(device=self._config.device) + + def _train( + self, + start, + objective=None, + action=None, + reward=None, + imagine=None, + tape=None, + repeats=None, + ): + objective = objective or self._reward + self._update_slow_target() + metrics = {} + + with tools.RequiresGrad(self.actor): + with torch.cuda.amp.autocast(self._use_amp): + imag_feat, imag_state, imag_action = self._imagine( + start, self.actor, self._config.imag_horizon, repeats + ) + reward = objective(imag_feat, imag_state, imag_action) + if self._config.reward_trans == "symlog": + # rescale predicted reward by head['reward'] + reward = symexp(reward) + actor_ent = self.actor(imag_feat).entropy() + state_ent = self._world_model.dynamics.get_dist(imag_state).entropy() + # this target is not scaled + # slow is flag to indicate whether slow_target is used for lambda-return + target, weights = self._compute_target( + imag_feat, + imag_state, + imag_action, + reward, + actor_ent, + state_ent, + self._config.slow_actor_target, + ) + actor_loss, mets = self._compute_actor_loss( + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + ) + metrics.update(mets) + if self._config.slow_value_target != self._config.slow_actor_target: + target, weights = self._compute_target( + imag_feat, + imag_state, + imag_action, + reward, + actor_ent, + state_ent, + self._config.slow_value_target, + ) + value_input = imag_feat + + with tools.RequiresGrad(self.value): + with torch.cuda.amp.autocast(self._use_amp): + value = self.value(value_input[:-1].detach()) + target = torch.stack(target, dim=1) + # only critic target is processed using symlog(not actor) + if self._config.critic_trans == "symlog": + metrics["unscaled_target_mean"] = to_np(torch.mean(target)) + target = symlog(target) + # (time, batch, 1), (time, batch, 1) -> (time, batch) + value_loss = -value.log_prob(target.detach()) + if self._config.value_decay: + value_loss += self._config.value_decay * value.mode() + # (time, batch, 1), (time, batch, 1) -> (1,) + value_loss = torch.mean(weights[:-1] * value_loss[:, :, None]) + + metrics["value_mean"] = to_np(torch.mean(value.mode())) + metrics["value_max"] = to_np(torch.max(value.mode())) + metrics["value_min"] = to_np(torch.min(value.mode())) + metrics["value_std"] = to_np(torch.std(value.mode())) + metrics["target_mean"] = to_np(torch.mean(target)) + metrics["reward_mean"] = to_np(torch.mean(reward)) + metrics["reward_std"] = to_np(torch.std(reward)) + metrics["actor_ent"] = to_np(torch.mean(actor_ent)) + with tools.RequiresGrad(self): + metrics.update(self._actor_opt(actor_loss, self.actor.parameters())) + metrics.update(self._value_opt(value_loss, self.value.parameters())) + return imag_feat, imag_state, imag_action, weights, metrics + + def _imagine(self, start, policy, horizon, repeats=None): + dynamics = self._world_model.dynamics + if repeats: + raise NotImplemented("repeats is not implemented in this version") + flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in start.items()} + + def step(prev, _): + state, _, _ = prev + feat = dynamics.get_feat(state) + inp = feat.detach() if self._stop_grad_actor else feat + action = policy(inp).sample() + succ = dynamics.img_step(state, action, sample=self._config.imag_sample) + return succ, feat, action + + feat = 0 * dynamics.get_feat(start) + action = policy(feat).mode() + succ, feats, actions = tools.static_scan( + step, [torch.arange(horizon)], (start, feat, action) + ) + states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} + if repeats: + raise NotImplemented("repeats is not implemented in this version") + + return feats, states, actions + + def _compute_target( + self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow + ): + if "discount" in self._world_model.heads: + inp = self._world_model.dynamics.get_feat(imag_state) + discount = self._world_model.heads["discount"](inp).mean + else: + discount = self._config.discount * torch.ones_like(reward) + if self._config.future_entropy and self._config.actor_entropy() > 0: + reward += self._config.actor_entropy() * actor_ent + if self._config.future_entropy and self._config.actor_state_entropy() > 0: + reward += self._config.actor_state_entropy() * state_ent + if slow: + value = self._slow_value(imag_feat).mode() + else: + value = self.value(imag_feat).mode() + if self._config.critic_trans == "symlog": + # After adding this line there is issue + value = symexp(value) + target = tools.lambda_return( + reward[:-1], + value[:-1], + discount[:-1], + bootstrap=value[-1], + lambda_=self._config.discount_lambda, + axis=0, + ) + weights = torch.cumprod( + torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0 + ).detach() + return target, weights + + def _compute_actor_loss( + self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights + ): + metrics = {} + inp = imag_feat.detach() if self._stop_grad_actor else imag_feat + policy = self.actor(inp) + actor_ent = policy.entropy() + # Q-val for actor is not transformed using symlog + target = torch.stack(target, dim=1) + if self._config.reward_EMA: + target = self.reward_ema(target) + metrics["EMA_scale"] = to_np(self.reward_ema.scale) + + if self._config.imag_gradient == "dynamics": + actor_target = target + elif self._config.imag_gradient == "reinforce": + actor_target = ( + policy.log_prob(imag_action)[:-1][:, :, None] + * (target - self.value(imag_feat[:-1]).mode()).detach() + ) + elif self._config.imag_gradient == "both": + actor_target = ( + policy.log_prob(imag_action)[:-1][:, :, None] + * (target - self.value(imag_feat[:-1]).mode()).detach() + ) + mix = self._config.imag_gradient_mix() + actor_target = mix * target + (1 - mix) * actor_target + metrics["imag_gradient_mix"] = mix + else: + raise NotImplementedError(self._config.imag_gradient) + if not self._config.future_entropy and (self._config.actor_entropy() > 0): + actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None] + actor_target += actor_entropy + metrics["actor_entropy"] = to_np(torch.mean(actor_entropy)) + if not self._config.future_entropy and (self._config.actor_state_entropy() > 0): + state_entropy = self._config.actor_state_entropy() * state_ent[:-1] + actor_target += state_entropy + metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy)) + actor_loss = -torch.mean(weights[:-1] * actor_target) + return actor_loss, metrics + + def _update_slow_target(self): + if self._config.slow_value_target or self._config.slow_actor_target: + if self._updates % self._config.slow_target_update == 0: + mix = self._config.slow_target_fraction + for s, d in zip(self.value.parameters(), self._slow_value.parameters()): + d.data = mix * s.data + (1 - mix) * d.data + self._updates += 1 diff --git a/networks.py b/networks.py new file mode 100644 index 0000000..30f6817 --- /dev/null +++ b/networks.py @@ -0,0 +1,631 @@ +import math +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torch import distributions as torchd + +import tools + + +class RSSM(nn.Module): + def __init__( + self, + stoch=30, + deter=200, + hidden=200, + layers_input=1, + layers_output=1, + rec_depth=1, + shared=False, + discrete=False, + act=nn.ELU, + norm=nn.LayerNorm, + mean_act="none", + std_act="softplus", + temp_post=True, + min_std=0.1, + cell="gru", + unimix_ratio=0.01, + num_actions=None, + embed=None, + device=None, + ): + super(RSSM, self).__init__() + self._stoch = stoch + self._deter = deter + self._hidden = hidden + self._min_std = min_std + self._layers_input = layers_input + self._layers_output = layers_output + self._rec_depth = rec_depth + self._shared = shared + self._discrete = discrete + self._act = act + self._norm = norm + self._mean_act = mean_act + self._std_act = std_act + self._temp_post = temp_post + self._unimix_ratio = unimix_ratio + self._embed = embed + self._device = device + + inp_layers = [] + if self._discrete: + inp_dim = self._stoch * self._discrete + num_actions + else: + inp_dim = self._stoch + num_actions + if self._shared: + inp_dim += self._embed + for i in range(self._layers_input): + inp_layers.append(nn.Linear(inp_dim, self._hidden)) + inp_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._inp_layers = nn.Sequential(*inp_layers) + + if cell == "gru": + self._cell = GRUCell(self._hidden, self._deter) + elif cell == "gru_layer_norm": + self._cell = GRUCell(self._hidden, self._deter, norm=True) + else: + raise NotImplementedError(cell) + + img_out_layers = [] + inp_dim = self._deter + for i in range(self._layers_output): + img_out_layers.append(nn.Linear(inp_dim, self._hidden)) + img_out_layers.append(self._norm(self._hidden)) + img_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._img_out_layers = nn.Sequential(*img_out_layers) + + obs_out_layers = [] + if self._temp_post: + inp_dim = self._deter + self._embed + else: + inp_dim = self._embed + for i in range(self._layers_output): + obs_out_layers.append(nn.Linear(inp_dim, self._hidden)) + obs_out_layers.append(self._norm(self._hidden)) + obs_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._obs_out_layers = nn.Sequential(*obs_out_layers) + + if self._discrete: + self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + else: + self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + + def initial(self, batch_size): + deter = torch.zeros(batch_size, self._deter).to(self._device) + if self._discrete: + state = dict( + logit=torch.zeros([batch_size, self._stoch, self._discrete]).to( + self._device + ), + stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to( + self._device + ), + deter=deter, + ) + else: + state = dict( + mean=torch.zeros([batch_size, self._stoch]).to(self._device), + std=torch.zeros([batch_size, self._stoch]).to(self._device), + stoch=torch.zeros([batch_size, self._stoch]).to(self._device), + deter=deter, + ) + return state + + def observe(self, embed, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(action.shape[0]) + # (batch, time, ch) -> (time, batch, ch) + embed, action = swap(embed), swap(action) + post, prior = tools.static_scan( + lambda prev_state, prev_act, embed: self.obs_step( + prev_state[0], prev_act, embed + ), + (action, embed), + (state, state), + ) + + # (batch, time, stoch, discrete_num) -> (batch, time, stoch, discrete_num) + post = {k: swap(v) for k, v in post.items()} + prior = {k: swap(v) for k, v in prior.items()} + return post, prior + + def imagine(self, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(action.shape[0]) + assert isinstance(state, dict), state + action = action + action = swap(action) + prior = tools.static_scan(self.img_step, [action], state) + prior = prior[0] + prior = {k: swap(v) for k, v in prior.items()} + return prior + + def get_feat(self, state): + stoch = state["stoch"] + if self._discrete: + shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete] + stoch = stoch.reshape(shape) + return torch.cat([stoch, state["deter"]], -1) + + def get_dist(self, state, dtype=None): + if self._discrete: + logit = state["logit"] + dist = torchd.independent.Independent( + tools.OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1 + ) + else: + mean, std = state["mean"], state["std"] + dist = tools.ContDist( + torchd.independent.Independent(torchd.normal.Normal(mean, std), 1) + ) + return dist + + def obs_step(self, prev_state, prev_action, embed, sample=True): + # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) + # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs + prior = self.img_step(prev_state, prev_action, None, sample) + if self._shared: + post = self.img_step(prev_state, prev_action, embed, sample) + else: + if self._temp_post: + x = torch.cat([prior["deter"], embed], -1) + else: + x = embed + # (batch_size, prior_deter + embed) -> (batch_size, hidden) + x = self._obs_out_layers(x) + # (batch_size, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("obs", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + post = {"stoch": stoch, "deter": prior["deter"], **stats} + return post, prior + + # this is used for making future image + def img_step(self, prev_state, prev_action, embed=None, sample=True): + # (batch, stoch, discrete_num) + prev_stoch = prev_state["stoch"] + if self._discrete: + shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] + # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num) + prev_stoch = prev_stoch.reshape(shape) + if self._shared: + if embed is None: + shape = list(prev_action.shape[:-1]) + [self._embed] + embed = torch.zeros(shape) + # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed) + x = torch.cat([prev_stoch, prev_action, embed], -1) + else: + x = torch.cat([prev_stoch, prev_action], -1) + # (batch, stoch * discrete_num + action, embed) -> (batch, hidden) + x = self._inp_layers(x) + for _ in range(self._rec_depth): # rec depth is not correctly implemented + deter = prev_state["deter"] + # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter) + x, deter = self._cell(x, [deter]) + deter = deter[0] # Keras wraps the state in a list. + # (batch, deter) -> (batch, hidden) + x = self._img_out_layers(x) + # (batch, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("ims", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + prior = {"stoch": stoch, "deter": deter, **stats} + return prior + + def _suff_stats_layer(self, name, x): + if self._discrete: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete]) + return {"logit": logit} + else: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + mean, std = torch.split(x, [self._stoch] * 2, -1) + mean = { + "none": lambda: mean, + "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0), + }[self._mean_act]() + std = { + "softplus": lambda: torch.softplus(std), + "abs": lambda: torch.abs(std + 1), + "sigmoid": lambda: torch.sigmoid(std), + "sigmoid2": lambda: 2 * torch.sigmoid(std / 2), + }[self._std_act]() + std = std + self._min_std + return {"mean": mean, "std": std} + + def kl_loss(self, post, prior, forward, free, lscale, rscale): + kld = torchd.kl.kl_divergence + dist = lambda x: self.get_dist(x) + sg = lambda x: {k: v.detach() for k, v in x.items()} + # forward == false -> (post, prior) + lhs, rhs = (prior, post) if forward else (post, prior) + + # forward == false -> Lrep + value_lhs = value = kld( + dist(lhs) if self._discrete else dist(lhs)._dist, + dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist, + ) + # forward == false -> Ldyn + value_rhs = kld( + dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist, + dist(rhs) if self._discrete else dist(rhs)._dist, + ) + loss_lhs = torch.clip(torch.mean(value_lhs), min=free) + loss_rhs = torch.clip(torch.mean(value_rhs), min=free) + loss = lscale * loss_lhs + rscale * loss_rhs + + return loss, value, loss_lhs, loss_rhs + + +class ConvEncoder(nn.Module): + def __init__( + self, + grayscale=False, + depth=32, + act=nn.ELU, + norm=nn.LayerNorm, + kernels=(3, 3, 3, 3), + ): + super(ConvEncoder, self).__init__() + self._act = act + self._norm = norm + self._depth = depth + self._kernels = kernels + h, w = 64, 64 + layers = [] + for i, kernel in enumerate(self._kernels): + if i == 0: + if grayscale: + inp_dim = 1 + else: + inp_dim = 3 + else: + inp_dim = 2 ** (i - 1) * self._depth + depth = 2**i * self._depth + layers.append( + Conv2dSame( + in_channels=inp_dim, + out_channels=depth, + kernel_size=(kernel, kernel), + stride=(2, 2), + ) + ) + h, w = h // 2, w // 2 + # layers.append(norm([depth, h, w])) + layers.append(act()) + self.layers = nn.Sequential(*layers) + + def __call__(self, obs): + x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:])) + x = x.permute(0, 3, 1, 2) + x = self.layers(x) + # prod: product of all elements + x = x.reshape([x.shape[0], np.prod(x.shape[1:])]) + shape = list(obs["image"].shape[:-3]) + [x.shape[-1]] + return x.reshape(shape) + + +class ConvDecoder(nn.Module): + def __init__( + self, + inp_depth, + depth=32, + act=nn.ELU, + norm=nn.LayerNorm, + shape=(3, 64, 64), + kernels=(3, 3, 3, 3), + ): + super(ConvDecoder, self).__init__() + self._inp_depth = inp_depth + self._act = act + self._norm = norm + self._depth = depth + self._shape = shape + self._kernels = kernels + self._embed_size = ( + (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1) + ) + + self._linear_layer = nn.Linear(inp_depth, self._embed_size) + inp_dim = self._embed_size // 16 + + cnnt_layers = [] + h, w = 4, 4 + for i, kernel in enumerate(self._kernels): + depth = self._embed_size // 16 // (2 ** (i + 1)) + act = self._act + if i == len(self._kernels) - 1: + depth = self._shape[0] + act = None + if i != 0: + inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth + pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1) + pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1) + cnnt_layers.append( + nn.ConvTranspose2d( + inp_dim, + depth, + kernel, + 2, + padding=(pad_h, pad_w), + output_padding=(outpad_h, outpad_w), + ) + ) + h, w = h * 2, w * 2 + # cnnt_layers.append(norm([depth, h, w])) + if act is not None: + cnnt_layers.append(act()) + self._cnnt_layers = nn.Sequential(*cnnt_layers) + + def __call__(self, features, dtype=None): + x = self._linear_layer(features) + x = x.reshape([-1, 4, 4, self._embed_size // 16]) + x = x.permute(0, 3, 1, 2) + x = self._cnnt_layers(x) + mean = x.reshape(features.shape[:-1] + self._shape) + mean = mean.permute(0, 1, 3, 4, 2) + return tools.ContDist( + torchd.independent.Independent( + torchd.normal.Normal(mean, 1), len(self._shape) + ) + ) + + +class DenseHead(nn.Module): + def __init__( + self, + inp_dim, + shape, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="normal", + std=1.0, + unimix_ratio=0.0, + ): + super(DenseHead, self).__init__() + self._shape = (shape,) if isinstance(shape, int) else shape + if len(self._shape) == 0: + self._shape = (1,) + self._layers = layers + self._units = units + self._act = act + self._norm = norm + self._dist = dist + self._std = std + self._unimix_ratio = unimix_ratio + + mean_layers = [] + for index in range(self._layers): + mean_layers.append(nn.Linear(inp_dim, self._units)) + mean_layers.append(norm(self._units)) + mean_layers.append(act()) + if index == 0: + inp_dim = self._units + mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape))) + self._mean_layers = nn.Sequential(*mean_layers) + + if self._std == "learned": + self._std_layer = nn.Linear(self._units, np.prod(self._shape)) + + def __call__(self, features, dtype=None): + x = features + mean = self._mean_layers(x) + if self._std == "learned": + std = self._std_layer(x) + std = torch.softplus(std) + 0.01 + else: + std = self._std + if self._dist == "normal": + return tools.ContDist( + torchd.independent.Independent( + torchd.normal.Normal(mean, std), len(self._shape) + ) + ) + if self._dist == "huber": + return tools.ContDist( + torchd.independent.Independent( + tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape) + ) + ) + if self._dist == "binary": + return tools.Bernoulli( + torchd.independent.Independent( + torchd.bernoulli.Bernoulli(logits=mean), len(self._shape) + ) + ) + if self._dist == "twohot": + return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio) + raise NotImplementedError(self._dist) + + +class ActionHead(nn.Module): + def __init__( + self, + inp_dim, + size, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="trunc_normal", + init_std=0.0, + min_std=0.1, + action_disc=5, + temp=0.1, + outscale=0, + ): + super(ActionHead, self).__init__() + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._norm = norm + self._min_std = min_std + self._init_std = init_std + self._action_disc = action_disc + self._temp = temp() if callable(temp) else temp + self._outscale = outscale + + pre_layers = [] + for index in range(self._layers): + pre_layers.append(nn.Linear(inp_dim, self._units)) + pre_layers.append(norm(self._units)) + pre_layers.append(act()) + if index == 0: + inp_dim = self._units + self._pre_layers = nn.Sequential(*pre_layers) + + if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: + self._dist_layer = nn.Linear(self._units, 2 * self._size) + elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: + self._dist_layer = nn.Linear(self._units, self._size) + + def __call__(self, features, dtype=None): + x = features + x = self._pre_layers(x) + if self._dist == "tanh_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = torch.tanh(mean) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution( + dist, tools.TanhBijector() + ) + dist = torchd.independent.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == "tanh_normal_5": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std + 5) + 5 + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution( + dist, tools.TanhBijector() + ) + dist = torchd.independent.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == "normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "normal_1": + x = self._dist_layer(x) + dist = torchd.normal.Normal(mean, 1) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "trunc_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + mean = torch.tanh(mean) + std = 2 * torch.sigmoid(std / 2) + self._min_std + dist = tools.SafeTruncatedNormal(mean, std, -1, 1) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "onehot": + x = self._dist_layer(x) + dist = tools.OneHotDist(x) + elif self._dist == "onehot_gumble": + x = self._dist_layer(x) + temp = self._temp + dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) + else: + raise NotImplementedError(self._dist) + return dist + + +class GRUCell(nn.Module): + def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1): + super(GRUCell, self).__init__() + self._inp_size = inp_size + self._size = size + self._act = act + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None) + if norm: + self._norm = nn.LayerNorm(3 * size) + + @property + def state_size(self): + return self._size + + def forward(self, inputs, state): + state = state[0] # Keras wraps the state in a list. + parts = self._layer(torch.cat([inputs, state], -1)) + if self._norm: + parts = self._norm(parts) + reset, cand, update = torch.split(parts, [self._size] * 3, -1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, [output] + + +class Conv2dSame(torch.nn.Conv2d): + def calc_same_pad(self, i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x): + ih, iw = x.size()[-2:] + pad_h = self.calc_same_pad( + i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] + ) + pad_w = self.calc_same_pad( + i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + + ret = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return ret + + +def calc_same_pad(k, s, d): + val = d * (k - 1) - s + 1 + pad = math.ceil(val / 2) + outpad = pad * 2 - val + return pad, outpad diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..30f5b3d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +torch==1.13.0 +numpy==1.20.1 +torchvision==0.14.0 +tensorboard==2.5.0 +pandas==1.2.4 +matplotlib==3.4.1 +ruamel.yaml==0.17.4 +gym[atari]==0.18.0 +moviepy==1.0.3 +einops==0.3.0 +protobuf==3.20.0 +dm_control==1.0.9 \ No newline at end of file diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..13b351e --- /dev/null +++ b/tools.py @@ -0,0 +1,700 @@ +import datetime +import io +import json +import pathlib +import pickle +import re +import time +import uuid + +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +from torch import distributions as torchd +from torch.utils.data import Dataset +from torch.utils.tensorboard import SummaryWriter + + +class RequiresGrad: + + def __init__(self, model): + self._model = model + + def __enter__(self): + self._model.requires_grad_(requires_grad=True) + + def __exit__(self, *args): + self._model.requires_grad_(requires_grad=False) + + +class TimeRecording: + + def __init__(self, comment): + self._comment = comment + + def __enter__(self): + self._st = torch.cuda.Event(enable_timing=True) + self._nd = torch.cuda.Event(enable_timing=True) + self._st.record() + + def __exit__(self, *args): + self._nd.record() + torch.cuda.synchronize() + print(self._comment, self._st.elapsed_time(self._nd)/1000) + + +class Logger: + + def __init__(self, logdir, step): + self._logdir = logdir + self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000) + self._last_step = None + self._last_time = None + self._scalars = {} + self._images = {} + self._videos = {} + self.step = step + + def scalar(self, name, value): + self._scalars[name] = float(value) + + def image(self, name, value): + self._images[name] = np.array(value) + + def video(self, name, value): + self._videos[name] = np.array(value) + + def write(self, fps=False): + scalars = list(self._scalars.items()) + if fps: + scalars.append(('fps', self._compute_fps(self.step))) + print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars)) + with (self._logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n') + for name, value in scalars: + self._writer.add_scalar('scalars/' + name, value, self.step) + for name, value in self._images.items(): + self._writer.add_image(name, value, self.step) + for name, value in self._videos.items(): + name = name if isinstance(name, str) else name.decode('utf-8') + if np.issubdtype(value.dtype, np.floating): + value = np.clip(255 * value, 0, 255).astype(np.uint8) + B, T, H, W, C = value.shape + value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) + self._writer.add_video(name, value, self.step, 16) + + self._writer.flush() + self._scalars = {} + self._images = {} + self._videos = {} + + def _compute_fps(self, step): + if self._last_step is None: + self._last_time = time.time() + self._last_step = step + return 0 + steps = step - self._last_step + duration = time.time() - self._last_time + self._last_time += duration + self._last_step = step + return steps / duration + + def offline_scalar(self, name, value, step): + self._writer.add_scalar('scalars/'+name, value, step) + + def offline_video(self, name, value, step): + if np.issubdtype(value.dtype, np.floating): + value = np.clip(255 * value, 0, 255).astype(np.uint8) + B, T, H, W, C = value.shape + value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) + self._writer.add_video(name, value, step, 16) + + +def simulate(agent, envs, steps=0, episodes=0, state=None): + # Initialize or unpack simulation state. + if state is None: + step, episode = 0, 0 + done = np.ones(len(envs), np.bool) + length = np.zeros(len(envs), np.int32) + obs = [None] * len(envs) + agent_state = None + reward = [0]*len(envs) + else: + step, episode, done, length, obs, agent_state, reward = state + while (steps and step < steps) or (episodes and episode < episodes): + # Reset envs if necessary. + if done.any(): + indices = [index for index, d in enumerate(done) if d] + results = [envs[i].reset() for i in indices] + for index, result in zip(indices, results): + obs[index] = result + reward = [reward[i]*(1-done[i]) for i in range(len(envs))] + # Step agents. + obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} + action, agent_state = agent(obs, done, agent_state, reward) + if isinstance(action, dict): + action = [ + {k: np.array(action[k][i].detach().cpu()) for k in action} + for i in range(len(envs))] + else: + action = np.array(action) + assert len(action) == len(envs) + # Step envs. + results = [e.step(a) for e, a in zip(envs, action)] + obs, reward, done = zip(*[p[:3] for p in results]) + obs = list(obs) + reward = list(reward) + done = np.stack(done) + episode += int(done.sum()) + length += 1 + step += (done * length).sum() + length *= (1 - done) + + return (step - steps, episode - episodes, done, length, obs, agent_state, reward) + + +def save_episodes(directory, episodes): + directory = pathlib.Path(directory).expanduser() + directory.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + filenames = [] + for episode in episodes: + identifier = str(uuid.uuid4().hex) + length = len(episode['reward']) + filename = directory / f'{timestamp}-{identifier}-{length}.npz' + with io.BytesIO() as f1: + np.savez_compressed(f1, **episode) + f1.seek(0) + with filename.open('wb') as f2: + f2.write(f1.read()) + filenames.append(filename) + return filenames + + +def from_generator(generator, batch_size): + while True: + batch = [] + for _ in range(batch_size): + batch.append(next(generator)) + data = {} + for key in batch[0].keys(): + data[key] = [] + for i in range(batch_size): + data[key].append(batch[i][key]) + data[key] = np.stack(data[key], 0) + yield data + + +def sample_episodes(episodes, length=None, balance=False, seed=0): + random = np.random.RandomState(seed) + while True: + episode = random.choice(list(episodes.values())) + if length: + total = len(next(iter(episode.values()))) + available = total - length + if available < 1: + print(f'Skipped short episode of length {available}.') + continue + if balance: + index = min(random.randint(0, total), available) + else: + index = int(random.randint(0, available + 1)) + episode = {k: v[index: index + length] for k, v in episode.items()} + yield episode + + +def load_episodes(directory, limit=None, reverse=True): + directory = pathlib.Path(directory).expanduser() + episodes = {} + total = 0 + if reverse: + for filename in reversed(sorted(directory.glob('*.npz'))): + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + episodes[str(filename)] = episode + total += len(episode['reward']) - 1 + if limit and total >= limit: + break + else: + for filename in sorted(directory.glob('*.npz')): + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + episodes[str(filename)] = episode + total += len(episode['reward']) - 1 + if limit and total >= limit: + break + return episodes + + +class SampleDist: + + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return torch.mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return sample[torch.argmax(logprob)][0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -torch.mean(logprob, 0) + + +class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0): + if logits is not None and probs is None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = None + super().__init__(logits=logits, probs=probs) + + def mode(self): + _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) + return _mode.detach() + super().logits - super().logits.detach() + + def sample(self, sample_shape=(), seed=None): + if seed is not None: + raise ValueError('need to check') + sample = super().sample(sample_shape) + probs = super().probs + while len(probs.shape) < len(sample.shape): + probs = probs[None] + sample += probs - probs.detach() + return sample + + +class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'): + if logits is not None and probs is None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = None + super().__init__(logits=logits, probs=probs) + + self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device) + self.width = (self.buckets[-1] - self.buckets[0]) / 255 + + def mode(self): + _mode = super().probs * self.buckets + return torch.sum(_mode, dim=-1, keepdim=True) + + # Inside OneHotCategorical, log_prob is calculated using only max element in targets + def log_prob(self, x): + # x(time, batch, 1) + x = (x - self.buckets[0]) / self.width + lower_indices = (x).to(torch.int64) + # lower_indices is idnside 0 ~ len(buckets)-2 + lower_indices = torch.clip(lower_indices, max=len(self.buckets)-2) + # upper_indices is inside 1 ~ len(buckets)-1 + upper_indices = lower_indices + 1 + lower_weight = torch.abs(x - upper_indices).squeeze(-1) + upper_weight = torch.abs(x - lower_indices).squeeze(-1) + # (time, batch, 1) -> (time, batch, bucket_class) + lower_log_prob = super().log_prob(F.one_hot(lower_indices.squeeze(-1), num_classes=len(self.buckets))) + upper_log_prob = super().log_prob(F.one_hot(upper_indices.squeeze(-1), num_classes=len(self.buckets))) + + # label = lower_log_prob * lower_weight + upper_log_prob * upper_weight + # # (time, batch, bucket_class) -> (time, batch) + # cross_entropy = torch.sum(torch.log(super().probs) * label, axis=-1) + + return lower_weight * lower_log_prob + upper_weight * upper_log_prob + +class ContDist: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + return self._dist.mean + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + return self._dist.log_prob(x) + + +class Bernoulli: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + _mode = torch.round(self._dist.mean) + return _mode.detach() +self._dist.mean - self._dist.mean.detach() + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + _logits = self._dist.base_dist.logits + log_probs0 = -F.softplus(_logits) + log_probs1 = -F.softplus(-_logits) + + return log_probs0 * (1-x) + log_probs1 * x + + +class UnnormalizedHuber(torchd.normal.Normal): + + def __init__(self, loc, scale, threshold=1, **kwargs): + super().__init__(loc, scale, **kwargs) + self._threshold = threshold + + def log_prob(self, event): + return -(torch.sqrt( + (event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) + + def mode(self): + return self.mean + + +class SafeTruncatedNormal(torchd.normal.Normal): + + def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): + super().__init__(loc, scale) + self._low = low + self._high = high + self._clip = clip + self._mult = mult + + def sample(self, sample_shape): + event = super().sample(sample_shape) + if self._clip: + clipped = torch.clip(event, self._low + self._clip, + self._high - self._clip) + event = event - event.detach() + clipped.detach() + if self._mult: + event *= self._mult + return event + + +class TanhBijector(torchd.Transform): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__() + + def _forward(self, x): + return torch.tanh(x) + + def _inverse(self, y): + y = torch.where( + (torch.abs(y) <= 1.), + torch.clamp(y, -0.99999997, 0.99999997), y) + y = torch.atanh(y) + return y + + def _forward_log_det_jacobian(self, x): + log2 = torch.math.log(2.0) + return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) + + +def static_scan_for_lambda_return(fn, inputs, start): + last = start + indices = range(inputs[0].shape[0]) + indices = reversed(indices) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) + last = fn(last, *inp(index)) + if flag: + outputs = last + flag = False + else: + outputs = torch.cat([outputs, last], dim=-1) + outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1]) + outputs = torch.unbind(outputs, dim=0) + return outputs + + +def lambda_return( + reward, value, pcont, bootstrap, lambda_, axis): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. + #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) + assert len(reward.shape) == len(value.shape), (reward.shape, value.shape) + if isinstance(pcont, (int, float)): + pcont = pcont * torch.ones_like(reward) + dims = list(range(len(reward.shape))) + dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] + if axis != 0: + reward = reward.permute(dims) + value = value.permute(dims) + pcont = pcont.permute(dims) + if bootstrap is None: + bootstrap = torch.zeros_like(value[-1]) + next_values = torch.cat([value[1:], bootstrap[None]], 0) + inputs = reward + pcont * next_values * (1 - lambda_) + #returns = static_scan( + # lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, + # (inputs, pcont), bootstrap, reverse=True) + # reimplement to optimize performance + returns = static_scan_for_lambda_return( + lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, + (inputs, pcont), bootstrap) + if axis != 0: + returns = returns.permute(dims) + return returns + + +class Optimizer(): + + def __init__( + self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*', + opt='adam', use_amp=False): + assert 0 <= wd < 1 + assert not clip or 1 <= clip + self._name = name + self._parameters = parameters + self._clip = clip + self._wd = wd + self._wd_pattern = wd_pattern + self._opt = { + 'adam': lambda: torch.optim.Adam(parameters, + lr=lr, + eps=eps), + 'nadam': lambda: NotImplemented( + f'{config.opt} is not implemented'), + 'adamax': lambda: torch.optim.Adamax(parameters, + lr=lr, + eps=eps), + 'sgd': lambda: torch.optim.SGD(parameters, + lr=lr), + 'momentum': lambda: torch.optim.SGD(parameters, + lr=lr, + momentum=0.9), + }[opt]() + self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + + def __call__(self, loss, params, retain_graph=False): + assert len(loss.shape) == 0, loss.shape + metrics = {} + metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy() + self._scaler.scale(loss).backward() + self._scaler.unscale_(self._opt) + #loss.backward(retain_graph=retain_graph) + norm = torch.nn.utils.clip_grad_norm_(params, self._clip) + if self._wd: + self._apply_weight_decay(params) + self._scaler.step(self._opt) + self._scaler.update() + #self._opt.step() + self._opt.zero_grad() + metrics[f'{self._name}_grad_norm'] = norm.item() + return metrics + + def _apply_weight_decay(self, varibs): + nontrivial = (self._wd_pattern != r'.*') + if nontrivial: + raise NotImplementedError + for var in varibs: + var.data = (1 - self._wd) * var.data + + +def args_type(default): + def parse_string(x): + if default is None: + return x + if isinstance(default, bool): + return bool(['False', 'True'].index(x)) + if isinstance(default, int): + return float(x) if ('e' in x or '.' in x) else int(x) + if isinstance(default, (list, tuple)): + return tuple(args_type(default[0])(y) for y in x.split(',')) + return type(default)(x) + def parse_object(x): + if isinstance(default, (list, tuple)): + return tuple(x) + return x + return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x) + + +def static_scan(fn, inputs, start): + last = start + indices = range(inputs[0].shape[0]) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) + last = fn(last, *inp(index)) + if flag: + if type(last) == type({}): + outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} + else: + outputs = [] + for _last in last: + if type(_last) == type({}): + outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) + else: + outputs.append(_last.clone().unsqueeze(0)) + flag = False + else: + if type(last) == type({}): + for key in last.keys(): + outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) + else: + for j in range(len(outputs)): + if type(last[j]) == type({}): + for key in last[j].keys(): + outputs[j][key] = torch.cat([outputs[j][key], + last[j][key].unsqueeze(0)], dim=0) + else: + outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) + if type(last) == type({}): + outputs = [outputs] + return outputs + + +# Original version +#def static_scan2(fn, inputs, start, reverse=False): +# last = start +# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))] +# indices = range(inputs[0].shape[0]) +# if reverse: +# indices = reversed(indices) +# for index in indices: +# inp = lambda x: (_input[x] for _input in inputs) +# last = fn(last, *inp(index)) +# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)] +# if reverse: +# outputs = [list(reversed(x)) for x in outputs] +# res = [[]] * len(outputs) +# for i in range(len(outputs)): +# if type(outputs[i][0]) == type({}): +# _res = {} +# for key in outputs[i][0].keys(): +# _res[key] = [] +# for j in range(len(outputs[i])): +# _res[key].append(outputs[i][j][key]) +# #_res[key] = torch.stack(_res[key], 0) +# _res[key] = faster_stack(_res[key], 0) +# else: +# _res = outputs[i] +# #_res = torch.stack(_res, 0) +# _res = faster_stack(_res, 0) +# res[i] = _res +# return res + + +class Every: + + def __init__(self, every): + self._every = every + self._last = None + + def __call__(self, step): + if not self._every: + return False + if self._last is None: + self._last = step + return True + if step >= self._last + self._every: + self._last += self._every + return True + return False + + +class Once: + + def __init__(self): + self._once = True + + def __call__(self): + if self._once: + self._once = False + return True + return False + + +class Until: + + def __init__(self, until): + self._until = until + + def __call__(self, step): + if not self._until: + return True + return step < self._until + + +def schedule(string, step): + try: + return float(string) + except ValueError: + match = re.match(r'linear\((.+),(.+),(.+)\)', string) + if match: + initial, final, duration = [float(group) for group in match.groups()] + mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0] + return (1 - mix) * initial + mix * final + match = re.match(r'warmup\((.+),(.+)\)', string) + if match: + warmup, value = [float(group) for group in match.groups()] + scale = torch.clip(step / warmup, 0, 1) + return scale * value + match = re.match(r'exp\((.+),(.+),(.+)\)', string) + if match: + initial, final, halflife = [float(group) for group in match.groups()] + return (initial - final) * 0.5 ** (step / halflife) + final + match = re.match(r'horizon\((.+),(.+),(.+)\)', string) + if match: + initial, final, duration = [float(group) for group in match.groups()] + mix = torch.clip(step / duration, 0, 1) + horizon = (1 - mix) * initial + mix * final + return 1 - 1 / horizon + raise NotImplementedError(string) + +def weight_init(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + gain = nn.init.calculate_gain('relu') + nn.init.orthogonal_(m.weight.data, gain) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) \ No newline at end of file diff --git a/wrappers.py b/wrappers.py new file mode 100644 index 0000000..38448cc --- /dev/null +++ b/wrappers.py @@ -0,0 +1,419 @@ +import threading + +import gym +import numpy as np + + +class DeepMindLabyrinth(object): + ACTION_SET_DEFAULT = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward + (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward + (0, 0, 0, 0, 1, 0, 0), # Fire + ) + + ACTION_SET_MEDIUM = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (0, 0, 0, 0, 0, 0, 0), # Idle. + ) + + ACTION_SET_SMALL = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + ) + + def __init__( + self, + level, + mode, + action_repeat=4, + render_size=(64, 64), + action_set=ACTION_SET_DEFAULT, + level_cache=None, + seed=None, + runfiles_path=None, + ): + assert mode in ("train", "test") + import deepmind_lab + + if runfiles_path: + print("Setting DMLab runfiles path:", runfiles_path) + deepmind_lab.set_runfiles_path(runfiles_path) + self._config = {} + self._config["width"] = render_size[0] + self._config["height"] = render_size[1] + self._config["logLevel"] = "WARN" + if mode == "test": + self._config["allowHoldOutLevels"] = "true" + self._config["mixerSeed"] = 0x600D5EED + self._action_repeat = action_repeat + self._random = np.random.RandomState(seed) + self._env = deepmind_lab.Lab( + level="contributed/dmlab30/" + level, + observations=["RGB_INTERLEAVED"], + config={k: str(v) for k, v in self._config.items()}, + level_cache=level_cache, + ) + self._action_set = action_set + self._last_image = None + self._done = True + + @property + def observation_space(self): + shape = (self._config["height"], self._config["width"], 3) + space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) + return gym.spaces.Dict({"image": space}) + + @property + def action_space(self): + return gym.spaces.Discrete(len(self._action_set)) + + def reset(self): + self._done = False + self._env.reset(seed=self._random.randint(0, 2**31 - 1)) + obs = self._get_obs() + return obs + + def step(self, action): + raw_action = np.array(self._action_set[action], np.intc) + reward = self._env.step(raw_action, num_steps=self._action_repeat) + self._done = not self._env.is_running() + obs = self._get_obs() + return obs, reward, self._done, {} + + def render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + del args # Unused + del kwargs # Unused + return self._last_image + + def close(self): + self._env.close() + + def _get_obs(self): + if self._done: + image = 0 * self._last_image + else: + image = self._env.observations()["RGB_INTERLEAVED"] + self._last_image = image + return {"image": image} + + +class DeepMindControl: + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): + domain, task = name.split("_", 1) + if domain == "cup": # Only domain with multiple words. + domain = "ball_in_cup" + if isinstance(domain, str): + from dm_control import suite + + self._env = suite.load(domain, task) + else: + assert task is None + self._env = domain() + self._action_repeat = action_repeat + self._size = size + if camera is None: + camera = dict(quadruped=2).get(domain, 0) + self._camera = camera + + @property + def observation_space(self): + spaces = {} + for key, value in self._env.observation_spec().items(): + spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32) + spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + spec = self._env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + + def step(self, action): + assert np.isfinite(action).all(), action + reward = 0 + for _ in range(self._action_repeat): + time_step = self._env.step(action) + reward += time_step.reward or 0 + if time_step.last(): + break + obs = dict(time_step.observation) + obs["image"] = self.render() + done = time_step.last() + info = {"discount": np.array(time_step.discount, np.float32)} + return obs, reward, done, info + + def reset(self): + time_step = self._env.reset() + obs = dict(time_step.observation) + obs["image"] = self.render() + return obs + + def render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + return self._env.physics.render(*self._size, camera_id=self._camera) + + +class Atari: + LOCK = threading.Lock() + + def __init__( + self, + name, + action_repeat=4, + size=(84, 84), + grayscale=True, + noops=30, + life_done=False, + sticky_actions=True, + all_actions=False, + ): + assert size[0] == size[1] + import gym.wrappers + import gym.envs.atari + + if name == "james_bond": + name = "jamesbond" + with self.LOCK: + env = gym.envs.atari.AtariEnv( + game=name, + obs_type="image", + frameskip=1, + repeat_action_probability=0.25 if sticky_actions else 0.0, + full_action_space=all_actions, + ) + # Avoid unnecessary rendering in inner env. + env._get_obs = lambda: None + # Tell wrapper that the inner env has no action repeat. + env.spec = gym.envs.registration.EnvSpec("NoFrameskip-v0") + env = gym.wrappers.AtariPreprocessing( + env, noops, action_repeat, size[0], life_done, grayscale + ) + self._env = env + self._grayscale = grayscale + + @property + def observation_space(self): + return gym.spaces.Dict( + { + "image": self._env.observation_space, + "ram": gym.spaces.Box(0, 255, (128,), np.uint8), + } + ) + + @property + def action_space(self): + return self._env.action_space + + def close(self): + return self._env.close() + + def reset(self): + with self.LOCK: + image = self._env.reset() + if self._grayscale: + image = image[..., None] + obs = {"image": image, "ram": self._env.env._get_ram()} + return obs + + def step(self, action): + image, reward, done, info = self._env.step(action) + if self._grayscale: + image = image[..., None] + obs = {"image": image, "ram": self._env.env._get_ram()} + return obs, reward, done, info + + def render(self, mode): + return self._env.render(mode) + + +class CollectDataset: + def __init__(self, env, callbacks=None, precision=32): + self._env = env + self._callbacks = callbacks or () + self._precision = precision + self._episode = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = {k: self._convert(v) for k, v in obs.items()} + transition = obs.copy() + if isinstance(action, dict): + transition.update(action) + else: + transition["action"] = action + transition["reward"] = reward + transition["discount"] = info.get("discount", np.array(1 - float(done))) + self._episode.append(transition) + if done: + for key, value in self._episode[1].items(): + if key not in self._episode[0]: + self._episode[0][key] = 0 * value + episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} + episode = {k: self._convert(v) for k, v in episode.items()} + info["episode"] = episode + for callback in self._callbacks: + callback(episode) + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + transition = obs.copy() + # Missing keys will be filled with a zeroed out version of the first + # transition, because we do not know what action information the agent will + # pass yet. + transition["reward"] = 0.0 + transition["discount"] = 1.0 + self._episode = [transition] + return obs + + def _convert(self, value): + value = np.array(value) + if np.issubdtype(value.dtype, np.floating): + dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] + elif np.issubdtype(value.dtype, np.signedinteger): + dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] + elif np.issubdtype(value.dtype, np.uint8): + dtype = np.uint8 + else: + raise NotImplementedError(value.dtype) + return value.astype(dtype) + + +class TimeLimit: + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + assert self._step is not None, "Must reset environment." + obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if "discount" not in info: + info["discount"] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info + + def reset(self): + self._step = 0 + return self._env.reset() + + +class NormalizeActions: + def __init__(self, env): + self._env = env + self._mask = np.logical_and( + np.isfinite(env.action_space.low), np.isfinite(env.action_space.high) + ) + self._low = np.where(self._mask, env.action_space.low, -1) + self._high = np.where(self._mask, env.action_space.high, 1) + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + low = np.where(self._mask, -np.ones_like(self._low), self._low) + high = np.where(self._mask, np.ones_like(self._low), self._high) + return gym.spaces.Box(low, high, dtype=np.float32) + + def step(self, action): + original = (action + 1) / 2 * (self._high - self._low) + self._low + original = np.where(self._mask, original, action) + return self._env.step(original) + + +class OneHotAction: + def __init__(self, env): + assert isinstance(env.action_space, gym.spaces.Discrete) + self._env = env + self._random = np.random.RandomState() + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + shape = (self._env.action_space.n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + space.discrete = True + return space + + def step(self, action): + index = np.argmax(action).astype(int) + reference = np.zeros_like(action) + reference[index] = 1 + if not np.allclose(reference, action): + raise ValueError(f"Invalid one-hot action:\n{action}") + return self._env.step(index) + + def reset(self): + return self._env.reset() + + def _sample_action(self): + actions = self._env.action_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference + + +class RewardObs: + def __init__(self, env): + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def observation_space(self): + spaces = self._env.observation_space.spaces + assert "reward" not in spaces + spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + return gym.spaces.Dict(spaces) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs["reward"] = reward + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + obs["reward"] = 0.0 + return obs + + +class SelectAction: + def __init__(self, env, key): + self._env = env + self._key = key + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + return self._env.step(action[self._key])