naive support for pixels

This commit is contained in:
Nicklas Hansen
2023-12-22 07:34:40 -08:00
parent 445af9d81d
commit bfb1971898
6 changed files with 118 additions and 14 deletions

View File

@@ -24,6 +24,44 @@ class Ensemble(nn.Module):
return 'Vectorized ' + self._repr
class ShiftAug(nn.Module):
    """
    Random shift augmentation for batches of square images.

    Every image is replicate-padded by `pad` pixels on each side and then
    resampled at a random whole-pixel offset in [0, 2 * pad], drawn separately
    for each image and each axis, so the output keeps the input's spatial size.
    Adapted from https://github.com/facebookresearch/drqv2
    """

    def __init__(self, pad=3):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        x = x.float()
        batch, _, h, w = x.size()
        assert h == w
        padded_size = h + 2 * self.pad
        x = F.pad(x, (self.pad,) * 4, 'replicate')
        # Pixel-centre coordinates of the padded image in grid_sample's
        # normalized [-1, 1] range; only the first h of them form the base window.
        eps = 1.0 / padded_size
        coords = torch.linspace(-1.0 + eps, 1.0 - eps, padded_size, device=x.device, dtype=x.dtype)[:h]
        xs = coords.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        ys = xs.transpose(1, 0)
        base_grid = torch.cat([xs, ys], dim=2).unsqueeze(0).repeat(batch, 1, 1, 1)
        # One (x, y) offset per image, converted from pixels to normalized units.
        offsets = torch.randint(0, 2 * self.pad + 1, size=(batch, 1, 1, 2), device=x.device, dtype=x.dtype)
        offsets *= 2.0 / padded_size
        return F.grid_sample(x, base_grid + offsets, padding_mode='zeros', align_corners=False)
class PixelPreprocess(nn.Module):
    """
    Normalizes pixel observations to [-0.5, 0.5].

    Values are divided by 255 and shifted by -0.5, so inputs in [0, 255] map
    to [-0.5, 0.5]. The input tensor is left unmodified and integer inputs
    (e.g. uint8 frames) are accepted; the result is a floating-point tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Divide out-of-place: an in-place div_ would overwrite the caller's
        # tensor and fails outright on integer dtypes. The subtraction can
        # stay in-place because it acts on the freshly allocated result.
        return x.div(255.).sub_(0.5)
class SimNorm(nn.Module):
"""
Simplicial normalization.
@@ -69,16 +107,6 @@ class NormedLinear(nn.Linear):
f"act={self.act.__class__.__name__})"
def enc(cfg, out=None):
    """
    Returns a dictionary of encoders for each observation in the dict.

    Args:
        cfg: config exposing `obs_shape` (mapping of observation name to
            shape), `task_dim`, `num_enc_layers`, `enc_dim` and `latent_dim`.
        out: optional dict to add the encoders to; a new dict is created
            when omitted.

    Returns:
        nn.ModuleDict mapping each observation name to its encoder.
    """
    # Create the dict per call: a mutable default would be shared by every
    # call and carry encoders over from one model to the next.
    out = {} if out is None else out
    for k in cfg.obs_shape.keys():
        assert k == 'state'
        out[k] = mlp(cfg.obs_shape[k][0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
    return nn.ModuleDict(out)
def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
"""
Basic building block of TD-MPC2.
@@ -92,3 +120,34 @@ def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
mlp.append(NormedLinear(dims[i], dims[i+1], dropout=dropout*(i==0)))
mlp.append(NormedLinear(dims[-2], dims[-1], act=act) if act else nn.Linear(dims[-2], dims[-1]))
return nn.Sequential(*mlp)
def conv(in_shape, num_channels, act=None):
    """
    Convolutional encoder for TD-MPC2 with raw image observations.
    Applies shift augmentation and pixel normalization, then 4 convolutions
    (ReLU after the first three) whose output is flattened. When `act` is
    given it is appended as the final module.
    """
    assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
    # (input channels, kernel size, stride) of each convolution, in order.
    conv_specs = (
        (in_shape[0], 7, 2),
        (num_channels, 5, 2),
        (num_channels, 3, 2),
        (num_channels, 3, 1),
    )
    modules = [ShiftAug(), PixelPreprocess()]
    last = len(conv_specs) - 1
    for idx, (c_in, kernel, stride) in enumerate(conv_specs):
        modules.append(nn.Conv2d(c_in, num_channels, kernel, stride=stride))
        # The final convolution is followed by a flatten instead of a ReLU.
        modules.append(nn.Flatten() if idx == last else nn.ReLU(inplace=True))
    if act:
        modules.append(act)
    return nn.Sequential(*modules)
def enc(cfg, out=None):
    """
    Returns a dictionary of encoders for each observation in the dict.

    Args:
        cfg: config exposing `obs_shape` (mapping of observation name to
            shape) plus the encoder hyperparameters read below.
        out: optional dict to add the encoders to; a new dict is created
            when omitted.

    Returns:
        nn.ModuleDict mapping each observation name to its encoder.

    Raises:
        NotImplementedError: if an observation name is neither 'state' nor 'rgb'.
    """
    # Create the dict per call: a mutable default would be shared by every
    # call, so encoders built for one config would leak into the next.
    out = {} if out is None else out
    for k in cfg.obs_shape.keys():
        if k == 'state':
            out[k] = mlp(cfg.obs_shape[k][0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
        elif k == 'rgb':
            out[k] = conv(cfg.obs_shape[k], cfg.num_channels, act=SimNorm(cfg))
        else:
            raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
    return nn.ModuleDict(out)

View File

@@ -49,11 +49,11 @@ def print_run(cfg):
prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
)
obs_dim = cfg.obs_shape['state'][0] if 'state' in cfg.obs_shape else cfg.obs_shape[0]
observations = ", ".join([str(v) for v in cfg.obs_shape.values()])
kvs = [
("task", cfg.task_title),
("steps", f"{int(cfg.steps):,}"),
("observations", obs_dim),
("observations", observations),
("actions", cfg.action_dim),
("experiment", cfg.exp_name),
]

View File

@@ -97,7 +97,9 @@ class WorldModel(nn.Module):
"""
if self.cfg.multitask:
obs = self.task_emb(obs, task)
return self._encoder['state'](obs)
if self.cfg.obs == 'rgb' and obs.ndim == 5:
return torch.stack([self._encoder[self.cfg.obs](o) for o in obs])
return self._encoder[self.cfg.obs](obs)
def next(self, z, a, task):
"""

View File

@@ -3,6 +3,7 @@ defaults:
# environment
task: dog-run
obs: state
# evaluation
checkpoint: ???
@@ -52,6 +53,7 @@ vmax: +10
model_size: ???
num_enc_layers: 2
enc_dim: 256
num_channels: 32
mlp_dim: 512
latent_dim: 512
task_dim: 96

View File

@@ -4,6 +4,7 @@ import warnings
import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
from envs.dmcontrol import make_env as make_dm_control_env
from envs.maniskill import make_env as make_maniskill_env
@@ -52,10 +53,12 @@ def make_env(cfg):
if env is None:
raise UnknownTaskError(cfg.task)
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box
cfg.obs_shape = {'state': env.observation_space.shape}
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length)

View File

@@ -0,0 +1,38 @@
from collections import deque
import gym
import numpy as np
import torch
class PixelWrapper(gym.Wrapper):
    """
    Replaces the wrapped environment's observations with a stack of the most
    recent rendered frames. Compatible with DMControl environments.
    """

    def __init__(self, cfg, env, num_frames=3, render_size=64):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        self._render_size = render_size
        # Bounded deque: appending a new frame drops the oldest one.
        self._frames = deque([], maxlen=num_frames)
        stacked_shape = (num_frames*3, render_size, render_size)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=stacked_shape, dtype=np.uint8
        )

    def _get_obs(self):
        """Render the current frame, push it onto the stack and return the stack as a tensor."""
        size = self._render_size
        rendered = self.env.render(mode='rgb_array', width=size, height=size)
        # Move the last axis first (presumably HWC -> CHW; confirm against the renderer).
        self._frames.append(rendered.transpose(2, 0, 1))
        stacked = np.concatenate(self._frames)
        return torch.from_numpy(stacked)

    def reset(self):
        """Reset the wrapped environment and refill the whole frame stack."""
        self.env.reset()
        # Render once per slot so no frame from before the reset remains in the stack.
        for _ in range(self._frames.maxlen):
            pixels = self._get_obs()
        return pixels

    def step(self, action):
        """Step the wrapped environment, returning the frame stack in place of its observation."""
        _unused_obs, reward, done, info = self.env.step(action)
        pixels = self._get_obs()
        return pixels, reward, done, info