class ShiftAug(nn.Module):
    """
    Random shift image augmentation.
    Adapted from https://github.com/facebookresearch/drqv2

    Every image is replicate-padded by ``pad`` pixels on each side and then
    re-sampled at a random whole-pixel offset, so the output keeps the
    spatial size of the input.
    """

    def __init__(self, pad=3):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """
        Apply an independent random shift to each image in the batch.

        Args:
            x: tensor of shape (n, c, h, w) with h == w. It is cast to float.

        Returns:
            Float tensor of shape (n, c, h, w).
        """
        x = x.float()
        n, _, h, w = x.size()
        assert h == w
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, 'replicate')
        padded = h + 2 * self.pad
        # Normalized coordinates of the pixel centers of the padded image
        # (align_corners=False convention); keep the first h of them.
        eps = 1.0 / padded
        arange = torch.linspace(-1.0 + eps, 1.0 - eps, padded, device=x.device, dtype=x.dtype)[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
        # One (x, y) offset per image, in whole pixels within [0, 2 * pad],
        # converted to normalized grid units.
        shift = torch.randint(0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype)
        shift *= 2.0 / padded
        grid = base_grid + shift
        return F.grid_sample(x, grid, padding_mode='zeros', align_corners=False)


class PixelPreprocess(nn.Module):
    """
    Normalizes pixel observations to [-0.5, 0.5].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Map pixel values from [0, 255] to [-0.5, 0.5].

        Out-of-place ops are used on purpose: the in-place variants would
        overwrite the caller's tensor and raise on integer (e.g. uint8)
        inputs, whose dtype cannot hold the fractional result.

        Args:
            x: tensor of pixel values; integer inputs are promoted to float.

        Returns:
            A new floating-point tensor with the same shape as ``x``.
        """
        return x.div(255.).sub(0.5)
def conv(in_shape, num_channels, act=None):
    """
    Basic convolutional encoder for TD-MPC2 with raw image observations.

    Applies random-shift augmentation and pixel normalization, then 4 layers
    of convolution (ReLU after the first three) and a flatten. ``act``, if
    given, is appended as the final layer.

    Args:
        in_shape: observation shape; ``in_shape[0]`` is used as the number of
            input channels and ``in_shape[-1]`` must be 64.
        num_channels: number of output channels of every convolution.
        act: optional module appended after the flatten.

    Returns:
        An ``nn.Sequential`` holding the layers described above.
    """
    assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
    layers = [
        ShiftAug(), PixelPreprocess(),
        nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()]
    if act:
        layers.append(act)
    return nn.Sequential(*layers)


def enc(cfg, out=None):
    """
    Returns a dictionary of encoders for each observation in the dict.

    Args:
        cfg: config object; reads ``obs_shape`` (mapping of observation name
            to shape) and, depending on the observation type, ``task_dim``,
            ``num_enc_layers``, ``enc_dim``, ``latent_dim`` and
            ``num_channels``.
        out: optional dict to add the encoders to. It is filled in place.
            When omitted, a fresh dict is created for every call so that
            encoders never leak from one call into the next.

    Returns:
        An ``nn.ModuleDict`` mapping each observation name to its encoder.

    Raises:
        NotImplementedError: for observation names other than 'state'
            and 'rgb'.
    """
    if out is None:
        out = {}
    for k in cfg.obs_shape.keys():
        if k == 'state':
            out[k] = mlp(cfg.obs_shape[k][0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
        elif k == 'rgb':
            out[k] = conv(cfg.obs_shape[k], cfg.num_channels, act=SimNorm(cfg))
        else:
            raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
    return nn.ModuleDict(out)
class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Compatible with DMControl environments.

    Observations returned by the wrapped environment are discarded and
    replaced by a stack of the ``num_frames`` most recent rendered frames,
    concatenated along the channel axis.
    """

    def __init__(self, cfg, env, num_frames=3, render_size=64):
        """
        Args:
            cfg: config object; stored on the wrapper as ``self.cfg``.
            env: environment to wrap; must support
                ``render(mode='rgb_array', width=..., height=...)``.
            num_frames: number of consecutive frames to stack (at least 1).
            render_size: width and height of the rendered frames in pixels.

        Raises:
            ValueError: if ``num_frames`` is smaller than 1.
        """
        if num_frames < 1:
            # An empty frame buffer would leave reset() with nothing to return.
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8
        )
        # Oldest frames are dropped automatically once the buffer is full.
        self._frames = deque([], maxlen=num_frames)
        self._render_size = render_size

    def _get_obs(self):
        """
        Render the current frame, push it onto the frame buffer and return
        all buffered frames concatenated along the first axis as a tensor.
        """
        # Move the channel axis first: (H, W, C) -> (C, H, W).
        frame = self.env.render(
            mode='rgb_array', width=self._render_size, height=self._render_size
        ).transpose(2, 0, 1)
        self._frames.append(frame)
        return torch.from_numpy(np.concatenate(self._frames))

    def reset(self, *args, **kwargs):
        """
        Reset the wrapped environment and return the initial frame stack.

        Any arguments are forwarded unchanged to the wrapped environment's
        ``reset`` so that callers relying on its reset options keep working.
        """
        self.env.reset(*args, **kwargs)
        # Render once per buffer slot so frames from the previous episode are
        # fully flushed and the stack has its final shape from the first step.
        for _ in range(self._frames.maxlen):
            obs = self._get_obs()
        return obs

    def step(self, action):
        """
        Step the wrapped environment and return the updated frame stack in
        place of its observation, together with reward, done and info.
        """
        _, reward, done, info = self.env.step(action)
        return self._get_obs(), reward, done, info