naive support for pixels
This commit is contained in:
@@ -24,6 +24,44 @@ class Ensemble(nn.Module):
|
|||||||
return 'Vectorized ' + self._repr
|
return 'Vectorized ' + self._repr
|
||||||
|
|
||||||
|
|
||||||
|
class ShiftAug(nn.Module):
	"""
	Random shift image augmentation.
	Adapted from https://github.com/facebookresearch/drqv2

	The input is replicate-padded by `pad` pixels on every side and then
	resampled on a grid of the original size whose origin is offset by a
	random whole number of pixels in [0, 2*pad], drawn per sample and per axis.
	"""
	def __init__(self, pad=3):
		super().__init__()
		self.pad = pad

	def forward(self, x):
		# Work in float32; the input is unpacked as a 4-D tensor with square spatial dims.
		x = x.float()
		batch, _, height, width = x.size()
		assert height == width
		padded_size = height + 2 * self.pad
		x = F.pad(x, (self.pad,) * 4, 'replicate')
		# Normalized coordinates of the first `height` pixel centers of the padded image.
		eps = 1.0 / padded_size
		coords = torch.linspace(-1.0 + eps, 1.0 - eps, padded_size, device=x.device, dtype=x.dtype)[:height]
		xs = coords.unsqueeze(0).repeat(height, 1).unsqueeze(2)
		ys = xs.transpose(1, 0)
		base_grid = torch.cat([xs, ys], dim=2).unsqueeze(0).repeat(batch, 1, 1, 1)
		# One integer pixel offset per sample and axis, rescaled to normalized grid units.
		shift = torch.randint(0, 2 * self.pad + 1, size=(batch, 1, 1, 2), device=x.device, dtype=x.dtype)
		shift *= 2.0 / padded_size
		return F.grid_sample(x, base_grid + shift, padding_mode='zeros', align_corners=False)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelPreprocess(nn.Module):
	"""
	Normalizes pixel observations to [-0.5, 0.5].

	Input values are divided by 255 and shifted by -0.5, so inputs in
	[0, 255] map onto [-0.5, 0.5]. The input tensor is left untouched and
	integer inputs (e.g. uint8 frames) are promoted to floating point.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		# Out-of-place division: an in-place `div_` would overwrite the caller's
		# tensor and fails on integer dtypes, since the float result cannot be
		# stored back into an integer tensor. The subtraction can stay in-place
		# because it operates on the freshly allocated quotient.
		return x.div(255.).sub_(0.5)
|
||||||
|
|
||||||
|
|
||||||
class SimNorm(nn.Module):
|
class SimNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
Simplicial normalization.
|
Simplicial normalization.
|
||||||
@@ -69,16 +107,6 @@ class NormedLinear(nn.Linear):
|
|||||||
f"act={self.act.__class__.__name__})"
|
f"act={self.act.__class__.__name__})"
|
||||||
|
|
||||||
|
|
||||||
def enc(cfg, out=None):
	"""
	Returns a dictionary of encoders for each observation in the dict.

	`out` may be a dict to add the encoders to; when omitted, a fresh dict is
	created on every call.
	"""
	# A `{}` default would be created once and shared by every call, so
	# encoders built by an earlier call would leak into later ones.
	if out is None:
		out = {}
	for k in cfg.obs_shape.keys():
		assert k == 'state'
		out[k] = mlp(cfg.obs_shape[k][0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
	return nn.ModuleDict(out)
|
|
||||||
|
|
||||||
|
|
||||||
def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
|
def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
|
||||||
"""
|
"""
|
||||||
Basic building block of TD-MPC2.
|
Basic building block of TD-MPC2.
|
||||||
@@ -92,3 +120,34 @@ def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
|
|||||||
mlp.append(NormedLinear(dims[i], dims[i+1], dropout=dropout*(i==0)))
|
mlp.append(NormedLinear(dims[i], dims[i+1], dropout=dropout*(i==0)))
|
||||||
mlp.append(NormedLinear(dims[-2], dims[-1], act=act) if act else nn.Linear(dims[-2], dims[-1]))
|
mlp.append(NormedLinear(dims[-2], dims[-1], act=act) if act else nn.Linear(dims[-2], dims[-1]))
|
||||||
return nn.Sequential(*mlp)
|
return nn.Sequential(*mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def conv(in_shape, num_channels, act=None):
	"""
	Basic convolutional encoder for TD-MPC2 with raw image observations.
	Applies ShiftAug and PixelPreprocess, then 4 convolution layers (kernel
	sizes 7, 5, 3, 3) with ReLU after the first three, and flattens the result.
	If `act` is given it is appended as the final module.
	"""
	assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
	modules = [ShiftAug(), PixelPreprocess()]
	channels_in = in_shape[0]
	# The three strided convolutions are each followed by an in-place ReLU.
	for kernel_size in (7, 5, 3):
		modules.append(nn.Conv2d(channels_in, num_channels, kernel_size, stride=2))
		modules.append(nn.ReLU(inplace=True))
		channels_in = num_channels
	# The last convolution has stride 1 and no ReLU; its output is flattened.
	modules.append(nn.Conv2d(num_channels, num_channels, 3, stride=1))
	modules.append(nn.Flatten())
	if act:
		modules.append(act)
	return nn.Sequential(*modules)
|
||||||
|
|
||||||
|
|
||||||
|
def enc(cfg, out=None):
	"""
	Returns a dictionary of encoders for each observation in the dict.

	A 'state' observation gets an MLP encoder and an 'rgb' observation gets a
	convolutional encoder; both end in a SimNorm activation. Any other key
	raises NotImplementedError.

	`out` may be a dict to add the encoders to; when omitted, a fresh dict is
	created on every call.
	"""
	# A `{}` default would be created once and shared by every call, so
	# encoders built by an earlier call would leak into later ones.
	if out is None:
		out = {}
	for k, shape in cfg.obs_shape.items():
		if k == 'state':
			out[k] = mlp(shape[0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
		elif k == 'rgb':
			out[k] = conv(shape, cfg.num_channels, act=SimNorm(cfg))
		else:
			raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
	return nn.ModuleDict(out)
|
||||||
|
|||||||
@@ -49,11 +49,11 @@ def print_run(cfg):
|
|||||||
prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
|
prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
|
||||||
)
|
)
|
||||||
|
|
||||||
obs_dim = cfg.obs_shape['state'][0] if 'state' in cfg.obs_shape else cfg.obs_shape[0]
|
observations = ", ".join([str(v) for v in cfg.obs_shape.values()])
|
||||||
kvs = [
|
kvs = [
|
||||||
("task", cfg.task_title),
|
("task", cfg.task_title),
|
||||||
("steps", f"{int(cfg.steps):,}"),
|
("steps", f"{int(cfg.steps):,}"),
|
||||||
("observations", obs_dim),
|
("observations", observations),
|
||||||
("actions", cfg.action_dim),
|
("actions", cfg.action_dim),
|
||||||
("experiment", cfg.exp_name),
|
("experiment", cfg.exp_name),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -97,7 +97,9 @@ class WorldModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
obs = self.task_emb(obs, task)
|
obs = self.task_emb(obs, task)
|
||||||
return self._encoder['state'](obs)
|
if self.cfg.obs == 'rgb' and obs.ndim == 5:
|
||||||
|
return torch.stack([self._encoder[self.cfg.obs](o) for o in obs])
|
||||||
|
return self._encoder[self.cfg.obs](obs)
|
||||||
|
|
||||||
def next(self, z, a, task):
|
def next(self, z, a, task):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ defaults:
|
|||||||
|
|
||||||
# environment
|
# environment
|
||||||
task: dog-run
|
task: dog-run
|
||||||
|
obs: state
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
@@ -52,6 +53,7 @@ vmax: +10
|
|||||||
model_size: ???
|
model_size: ???
|
||||||
num_enc_layers: 2
|
num_enc_layers: 2
|
||||||
enc_dim: 256
|
enc_dim: 256
|
||||||
|
num_channels: 32
|
||||||
mlp_dim: 512
|
mlp_dim: 512
|
||||||
latent_dim: 512
|
latent_dim: 512
|
||||||
task_dim: 96
|
task_dim: 96
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import warnings
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
|
from envs.wrappers.pixels import PixelWrapper
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
from envs.dmcontrol import make_env as make_dm_control_env
|
from envs.dmcontrol import make_env as make_dm_control_env
|
||||||
from envs.maniskill import make_env as make_maniskill_env
|
from envs.maniskill import make_env as make_maniskill_env
|
||||||
@@ -52,10 +53,12 @@ def make_env(cfg):
|
|||||||
if env is None:
|
if env is None:
|
||||||
raise UnknownTaskError(cfg.task)
|
raise UnknownTaskError(cfg.task)
|
||||||
env = TensorWrapper(env)
|
env = TensorWrapper(env)
|
||||||
|
if cfg.get('obs', 'state') == 'rgb':
|
||||||
|
env = PixelWrapper(cfg, env)
|
||||||
try: # Dict
|
try: # Dict
|
||||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||||
except: # Box
|
except: # Box
|
||||||
cfg.obs_shape = {'state': env.observation_space.shape}
|
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
||||||
cfg.action_dim = env.action_space.shape[0]
|
cfg.action_dim = env.action_space.shape[0]
|
||||||
cfg.episode_length = env.max_episode_steps
|
cfg.episode_length = env.max_episode_steps
|
||||||
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
||||||
|
|||||||
38
tdmpc2/envs/wrappers/pixels.py
Normal file
38
tdmpc2/envs/wrappers/pixels.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class PixelWrapper(gym.Wrapper):
	"""
	Wrapper for pixel observations. Compatible with DMControl environments.

	Observations returned by `reset` and `step` are the most recent rendered
	frames concatenated along the first axis and converted to a torch tensor.
	"""

	def __init__(self, cfg, env, num_frames=3, render_size=64):
		super().__init__(env)
		self.cfg = cfg
		self.env = env
		self._render_size = render_size
		# Rolling buffer: once full, appending a frame evicts the oldest one.
		self._frames = deque([], maxlen=num_frames)
		stacked_shape = (num_frames*3, render_size, render_size)
		self.observation_space = gym.spaces.Box(
			low=0, high=255, shape=stacked_shape, dtype=np.uint8
		)

	def _get_obs(self):
		# Render a square frame, store it, and return the stacked buffer.
		size = self._render_size
		image = self.env.render(mode='rgb_array', width=size, height=size)
		# Moves the last axis to the front (presumably HWC -> CHW; confirm against the renderer).
		self._frames.append(image.transpose(2, 0, 1))
		stacked = np.concatenate(self._frames)
		return torch.from_numpy(stacked)

	def reset(self):
		self.env.reset()
		# Render once per buffer slot so the stack is completely refilled.
		for _ in range(self._frames.maxlen):
			stacked_obs = self._get_obs()
		return stacked_obs

	def step(self, action):
		# The wrapped environment's own observation is dropped in favour of pixels.
		_, reward, done, info = self.env.step(action)
		pixels = self._get_obs()
		return pixels, reward, done, info
|
||||||
Reference in New Issue
Block a user