first commit

This commit is contained in:
Nicklas Hansen
2023-10-25 18:26:00 -07:00
commit b67b21c5c6
165 changed files with 16364 additions and 0 deletions

0
tdmpc2/__init__.py Executable file
View File

60
tdmpc2/common/__init__.py Normal file
View File

@@ -0,0 +1,60 @@
MODEL_SIZE = { # parameters (M)
1: {'enc_dim': 256,
'mlp_dim': 384,
'latent_dim': 128,
'num_enc_layers': 2,
'num_q': 2},
5: {'enc_dim': 256,
'mlp_dim': 512,
'latent_dim': 512,
'num_enc_layers': 2},
19: {'enc_dim': 1024,
'mlp_dim': 1024,
'latent_dim': 768,
'num_enc_layers': 3},
48: {'enc_dim': 1792,
'mlp_dim': 1792,
'latent_dim': 768,
'num_enc_layers': 4},
317: {'enc_dim': 4096,
'mlp_dim': 4096,
'latent_dim': 1376,
'num_enc_layers': 5,
'num_q': 8},
}
TASK_SET = {
'mt30': [
# 19 original dmcontrol tasks
'walker-stand', 'walker-walk', 'walker-run', 'cheetah-run', 'reacher-easy',
'reacher-hard', 'acrobot-swingup', 'pendulum-swingup', 'cartpole-balance', 'cartpole-balance-sparse',
'cartpole-swingup', 'cartpole-swingup-sparse', 'cup-catch', 'finger-spin', 'finger-turn-easy',
'finger-turn-hard', 'fish-swim', 'hopper-stand', 'hopper-hop',
# 11 custom dmcontrol tasks
'walker-walk-backwards', 'walker-run-backwards', 'cheetah-run-backwards', 'cheetah-run-front', 'cheetah-run-back',
'cheetah-jump', 'hopper-hop-backwards', 'reacher-three-easy', 'reacher-three-hard', 'cup-spin',
'pendulum-spin',
],
'mt80': [
# 19 original dmcontrol tasks
'walker-stand', 'walker-walk', 'walker-run', 'cheetah-run', 'reacher-easy',
'reacher-hard', 'acrobot-swingup', 'pendulum-swingup', 'cartpole-balance', 'cartpole-balance-sparse',
'cartpole-swingup', 'cartpole-swingup-sparse', 'cup-catch', 'finger-spin', 'finger-turn-easy',
'finger-turn-hard', 'fish-swim', 'hopper-stand', 'hopper-hop',
# 11 custom dmcontrol tasks
'walker-walk-backwards', 'walker-run-backwards', 'cheetah-run-backwards', 'cheetah-run-front', 'cheetah-run-back',
'cheetah-jump', 'hopper-hop-backwards', 'reacher-three-easy', 'reacher-three-hard', 'cup-spin',
'pendulum-spin',
# meta-world mt50
'mw-assembly', 'mw-basketball', 'mw-button-press-topdown', 'mw-button-press-topdown-wall', 'mw-button-press',
'mw-button-press-wall', 'mw-coffee-button', 'mw-coffee-pull', 'mw-coffee-push', 'mw-dial-turn',
'mw-disassemble', 'mw-door-open', 'mw-door-close', 'mw-drawer-close', 'mw-drawer-open',
'mw-faucet-open', 'mw-faucet-close', 'mw-hammer', 'mw-handle-press-side', 'mw-handle-press',
'mw-handle-pull-side', 'mw-handle-pull', 'mw-lever-pull', 'mw-peg-insert-side', 'mw-peg-unplug-side',
'mw-pick-out-of-hole', 'mw-pick-place', 'mw-pick-place-wall', 'mw-plate-slide', 'mw-plate-slide-side',
'mw-plate-slide-back', 'mw-plate-slide-back-side', 'mw-push-back', 'mw-push', 'mw-push-wall',
'mw-reach', 'mw-reach-wall', 'mw-shelf-place', 'mw-soccer', 'mw-stick-push',
'mw-stick-pull', 'mw-sweep-into', 'mw-sweep', 'mw-window-open', 'mw-window-close',
'mw-bin-picking', 'mw-box-close', 'mw-door-lock', 'mw-door-unlock', 'mw-hand-insert',
],
}

115
tdmpc2/common/buffer.py Normal file
View File

@@ -0,0 +1,115 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
class DataPrepTransform(Transform):
"""
Preprocesses data for TD-MPC2 training.
Replay data is expected to be a TensorDict with the following keys:
obs: observations
action: actions
reward: rewards
task: task IDs (optional)
A TensorDict with T time steps has T+1 observations and T actions and rewards.
The first actions and rewards in each TensorDict are dummies and should be ignored.
"""
def __init__(self):
super().__init__([])
def forward(self, td):
td = td.permute(1,0)
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
class Buffer():
"""
Create a replay buffer for TD-MPC2 training.
Uses CUDA memory if available, and CPU memory otherwise.
"""
def __init__(self, cfg):
self.cfg = cfg
self._device = torch.device('cuda')
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
self._num_eps = 0
@property
def capacity(self):
"""Return the capacity of the buffer."""
return self._capacity
@property
def num_eps(self):
"""Return the number of episodes in the buffer."""
return self._num_eps
def _reserve_buffer(self, storage):
"""
Reserve a buffer with the given storage.
Uses the RandomSampler to sample trajectories,
and the RandomCropTensorDict transform to crop trajectories to the desired length.
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
"""
return ReplayBuffer(
storage=storage,
sampler=RandomSampler(),
pin_memory=True,
prefetch=1,
transform=Compose(
RandomCropTensorDict(self.cfg.horizon+1, -1),
DataPrepTransform(),
),
batch_size=self.cfg.batch_size,
)
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
print('Using CPU memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
)
else: # Sufficient CUDA memory
print('Using CUDA memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
)
def add(self, tds):
"""Add an episode to the buffer. All episodes are expected to have the same length."""
if self._num_eps == 0:
self._buffer = self._init(tds)
self._buffer.add(tds)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
return obs.to(self._device, non_blocking=True), \
action.to(self._device, non_blocking=True), \
reward.to(self._device, non_blocking=True), \
task.to(self._device, non_blocking=True) if task is not None else None
def save(self):
"""Save the buffer to disk. Useful for storing offline datasets."""
td = self._buffer._storage._storage.cpu()
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
torch.save(td, fp)

22
tdmpc2/common/init.py Normal file
View File

@@ -0,0 +1,22 @@
import torch.nn as nn
def weight_init(m):
"""Custom weight initialization for TD-MPC2."""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -0.02, 0.02)
elif isinstance(m, nn.ParameterList):
for i,p in enumerate(m):
if p.dim() == 3: # Linear
nn.init.trunc_normal_(p, std=0.02) # Weight
nn.init.constant_(m[i+1], 0) # Bias
def zero_(params):
"""Initialize parameters to zero."""
for p in params:
p.data.fill_(0)

97
tdmpc2/common/layers.py Normal file
View File

@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import combine_state_for_ensemble
class Ensemble(nn.Module):
"""
Vectorized ensemble of modules.
"""
def __init__(self, modules, **kwargs):
super().__init__()
modules = nn.ModuleList(modules)
fn, params, _ = combine_state_for_ensemble(modules)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
def modules(self):
return self.vmap.__wrapped__.stateless_model
def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs)
def __repr__(self):
return 'Vectorized ' + self._repr
class SimNorm(nn.Module):
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""
def __init__(self, cfg):
super().__init__()
self.dim = cfg.simnorm_dim
def forward(self, x):
shp = x.shape
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
def __repr__(self):
return f"SimNorm(dim={self.dim})"
class NormedLinear(nn.Linear):
"""
Linear layer with LayerNorm, activation, and optionally dropout.
"""
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features)
self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
def forward(self, x):
x = super().forward(x)
if self.dropout:
x = self.dropout(x)
return self.act(self.ln(x))
def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return f"NormedLinear(in_features={self.in_features}, "\
f"out_features={self.out_features}, "\
f"bias={self.bias is not None}{repr_dropout}, "\
f"act={self.act.__class__.__name__})"
def enc(cfg, out={}):
"""
Returns a dictionary of encoders for each observation in the dict.
"""
for k in cfg.obs_shape.keys():
assert k == 'state'
out[k] = mlp(cfg.obs_shape[k][0] + cfg.task_dim, max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim], cfg.latent_dim, act=SimNorm(cfg))
return nn.ModuleDict(out)
def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
"""
Basic building block of TD-MPC2.
MLP with LayerNorm, Mish activations, and optionally dropout.
"""
if isinstance(mlp_dims, int):
mlp_dims = [mlp_dims]
dims = [in_dim] + mlp_dims + [out_dim]
mlp = nn.ModuleList()
for i in range(len(dims) - 2):
mlp.append(NormedLinear(dims[i], dims[i+1], dropout=dropout*(i==0)))
mlp.append(NormedLinear(dims[-2], dims[-1], act=act) if act else nn.Linear(dims[-2], dims[-1]))
return nn.Sequential(*mlp)

238
tdmpc2/common/logger.py Executable file
View File

@@ -0,0 +1,238 @@
import os
import datetime
import re
import numpy as np
import pandas as pd
from termcolor import colored
from omegaconf import OmegaConf
from common import TASK_SET
CONSOLE_FORMAT = [
("iteration", "I", "int"),
("episode", "E", "int"),
("step", "I", "int"),
("episode_reward", "R", "float"),
("episode_success", "S", "float"),
("total_time", "T", "time"),
]
CAT_TO_COLOR = {
"pretrain": "yellow",
"train": "blue",
"eval": "green",
}
def make_dir(dir_path):
"""Create directory if it does not already exist."""
try:
os.makedirs(dir_path)
except OSError:
pass
return dir_path
def print_run(cfg):
"""
Pretty-printing of current run information.
Logger calls this method at initialization.
"""
prefix, color, attrs = " ", "green", ["bold"]
def _limstr(s, maxlen=36):
return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
def _pprint(k, v):
print(
prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
)
obs_dim = cfg.obs_shape['state'][0] if 'state' in cfg.obs_shape else cfg.obs_shape[0]
kvs = [
("task", cfg.task_title),
("steps", f"{int(cfg.steps):,}"),
("observations", obs_dim),
("actions", cfg.action_dim),
("experiment", cfg.exp_name),
]
w = np.max([len(_limstr(str(kv[1]))) for kv in kvs]) + 25
div = "-" * w
print(div)
for k, v in kvs:
_pprint(k, v)
print(div)
def cfg_to_group(cfg, return_list=False):
"""
Return a wandb-safe group name for logging.
Optionally returns group name as list.
"""
lst = [cfg.task, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
return lst if return_list else "-".join(lst)
class VideoRecorder:
"""Utility class for logging evaluation videos."""
def __init__(self, cfg, wandb, fps=15):
self.cfg = cfg
self._save_dir = make_dir(cfg.work_dir / 'eval_video')
self._wandb = wandb
self.fps = fps
self.frames = []
self.enabled = False
def init(self, env, enabled=True):
self.frames = []
self.enabled = self._save_dir and self._wandb and enabled
self.record(env)
def record(self, env):
if self.enabled:
self.frames.append(env.render())
def save(self, step, key='videos/eval_video'):
if self.enabled and len(self.frames) > 0:
frames = np.stack(self.frames)
return self._wandb.log(
{key: self._wandb.Video(frames.transpose(0, 3, 1, 2), fps=self.fps, format='mp4')}, step=step
)
class Logger:
"""Primary logging object. Logs either locally or using wandb."""
def __init__(self, cfg):
self._log_dir = make_dir(cfg.work_dir)
self._model_dir = make_dir(self._log_dir / "models")
self._save_csv = cfg.save_csv
self._save_agent = cfg.save_agent
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._eval = []
print_run(cfg)
self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False
cfg.save_video = False
self._wandb = None
self._video = None
return
os.environ["WANDB_SILENT"] = "true" if cfg.wandb_silent else "false"
import wandb
wandb.init(
project=self.project,
entity=self.entity,
name=str(cfg.seed),
group=self._group,
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb
self._video = (
VideoRecorder(cfg, self._wandb)
if self._wandb and cfg.save_video
else None
)
@property
def video(self):
return self._video
@property
def model_dir(self):
return self._model_dir
def save_agent(self, agent=None, identifier='final'):
if self._save_agent and agent:
fp = self._model_dir / f'{str(identifier)}.pt'
agent.save(fp)
if self._wandb:
artifact = self._wandb.Artifact(
self._group + '-' + str(self._seed) + '-' + str(identifier),
type='model',
)
artifact.add_file(fp)
self._wandb.log_artifact(artifact)
def finish(self, agent=None):
try:
self.save_agent(agent)
except Exception as e:
print(colored(f"Failed to save model: {e}", "red"))
if self._wandb:
self._wandb.finish()
def _format(self, key, value, ty):
if ty == "int":
return f'{colored(key+":", "blue")} {int(value):,}'
elif ty == "float":
return f'{colored(key+":", "blue")} {value:.01f}'
elif ty == "time":
value = str(datetime.timedelta(seconds=int(value)))
return f'{colored(key+":", "blue")} {value}'
else:
raise f"invalid log format type: {ty}"
def _print(self, d, category):
category = colored(category, CAT_TO_COLOR[category])
pieces = [f" {category:<14}"]
for k, disp_k, ty in CONSOLE_FORMAT:
if k in d:
pieces.append(f"{self._format(disp_k, d[k], ty):<22}")
print(" ".join(pieces))
def pprint_multitask(self, d, cfg):
"""Pretty-print evaluation metrics for multi-task training."""
print(colored(f'Evaluated agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
dmcontrol_reward = []
metaworld_reward = []
metaworld_success = []
for k, v in d.items():
if '+' not in k:
continue
task = k.split('+')[1]
if task in TASK_SET['mt30'] and k.startswith('episode_reward'): # DMControl
dmcontrol_reward.append(v)
print(colored(f' {task:<22}\tR: {v:.01f}', 'yellow'))
elif task in TASK_SET['mt80'] and task not in TASK_SET['mt30']: # Meta-World
if k.startswith('episode_reward'):
metaworld_reward.append(v)
elif k.startswith('episode_success'):
metaworld_success.append(v)
print(colored(f' {task:<22}\tS: {v:.02f}', 'yellow'))
dmcontrol_reward = np.nanmean(dmcontrol_reward)
d['episode_reward+avg_dmcontrol'] = dmcontrol_reward
print(colored(f' {"dmcontrol":<22}\tR: {dmcontrol_reward:.01f}', 'yellow', attrs=['bold']))
if cfg.task == 'mt80':
metaworld_reward = np.nanmean(metaworld_reward)
metaworld_success = np.nanmean(metaworld_success)
d['episode_reward+avg_metaworld'] = metaworld_reward
d['episode_success+avg_metaworld'] = metaworld_success
print(colored(f' {"metaworld":<22}\tR: {metaworld_reward:.01f}', 'yellow', attrs=['bold']))
print(colored(f' {"metaworld":<22}\tS: {metaworld_success:.02f}', 'yellow', attrs=['bold']))
def log(self, d, category="train"):
assert category in CAT_TO_COLOR.keys(), f"invalid category: {category}"
if self._wandb:
if category in {"train", "eval"}:
xkey = "step"
elif category == "pretrain":
xkey = "iteration"
for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d[xkey])
if category == "eval" and self._save_csv:
keys = ["step", "episode_reward"]
self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
pd.DataFrame(np.array(self._eval)).to_csv(
self._log_dir / "eval.csv", header=keys, index=None
)
self._print(d, category)

95
tdmpc2/common/math.py Normal file
View File

@@ -0,0 +1,95 @@
import torch
import torch.nn.functional as F
def soft_ce(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
return -(target * pred).sum(-1, keepdim=True)
@torch.jit.script
def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1)
@torch.jit.script
def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std
@torch.jit.script
def _gaussian_logprob(residual):
return residual - 0.5 * torch.log(2 * torch.pi)
def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None:
size = eps.size(-1)
return _gaussian_logprob(residual) * size
@torch.jit.script
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True)
return mu, pi, log_pi
@torch.jit.script
def symlog(x):
"""
Symmetric logarithmic function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * torch.log(1 + torch.abs(x))
@torch.jit.script
def symexp(x):
"""
Symmetric exponential function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
def two_hot(x, cfg):
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
DREG_BINS = None
def two_hot_inv(x, cfg):
"""Converts a batch of soft two-hot encoded vectors to scalars."""
global DREG_BINS
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symexp(x)
if DREG_BINS is None:
DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
x = F.softmax(x, dim=-1)
x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
return symexp(x)

60
tdmpc2/common/parser.py Executable file
View File

@@ -0,0 +1,60 @@
import re
from pathlib import Path
import hydra
from omegaconf import OmegaConf
from common import MODEL_SIZE, TASK_SET
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
"""
Parses a Hydra config. Mostly for convenience.
"""
# Logic
for k in cfg.keys():
try:
v = cfg[k]
if v == None:
v = True
except:
pass
# Algebraic expressions
for k in cfg.keys():
try:
v = cfg[k]
if isinstance(v, str):
match = re.match(r"(\d+)([+\-*/])(\d+)", v)
if match:
cfg[k] = eval(match.group(1) + match.group(2) + match.group(3))
if isinstance(cfg[k], float) and cfg[k].is_integer():
cfg[k] = int(cfg[k])
except:
pass
# Convenience
cfg.work_dir = Path(hydra.utils.get_original_cwd()) / 'logs' / cfg.task / str(cfg.seed) / cfg.exp_name
cfg.task_title = cfg.task.replace("-", " ").title()
cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins-1) # Bin size for discrete regression
# Model size
assert cfg.model_size in MODEL_SIZE.keys(), \
f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}'
for k, v in MODEL_SIZE[cfg.model_size].items():
cfg[k] = v
if cfg.task == 'mt30' and cfg.model_size == 19:
cfg.latent_dim = 512 # This checkpoint is slightly smaller
# Multi-task
cfg.multitask = cfg.task in TASK_SET.keys()
if cfg.multitask:
cfg.task_title = cfg.task.upper()
# Account for slight inconsistency in task_dim for the mt30 experiments
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64
else:
cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
return cfg

48
tdmpc2/common/scale.py Normal file
View File

@@ -0,0 +1,48 @@
import torch
class RunningScale:
"""Running trimmed scale estimator."""
def __init__(self, cfg):
self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles)
def load_state_dict(self, state_dict):
self._value.data.copy_(state_dict['value'])
self._percentiles.data.copy_(state_dict['percentiles'])
@property
def value(self):
return self._value.cpu().item()
def _percentile(self, x):
x_dtype, x_shape = x.dtype, x.shape
x = x.view(x.shape[0], -1)
in_sorted, _ = torch.sort(x, dim=0)
positions = self._percentiles * (x.shape[0]-1) / 100
floored = torch.floor(positions)
ceiled = floored + 1
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
weight_ceiled = positions-floored
weight_floored = 1.0 - weight_ceiled
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype)
def update(self, x):
percentiles = self._percentile(x.detach())
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
self._value.data.lerp_(value, self.cfg.tau)
def __call__(self, x, update=False):
if update:
self.update(x)
return x * (1/self.value)
def __repr__(self):
return f'RunningScale(S: {self.value})'

12
tdmpc2/common/seed.py Normal file
View File

@@ -0,0 +1,12 @@
import random
import numpy as np
import torch
def set_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@@ -0,0 +1,174 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from common import layers, math, init
class WorldModel(nn.Module):
"""
TD-MPC2 implicit world model architecture.
Can be used for both single-task and multi-task experiments.
"""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min)
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
@property
def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def to(self, *args, **kwargs):
"""
Overriding `to` method to also move additional tensors to device.
"""
super().to(*args, **kwargs)
if self.cfg.multitask:
self._action_masks = self._action_masks.to(*args, **kwargs)
self.log_std_min = self.log_std_min.to(*args, **kwargs)
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
return self
def train(self, mode=True):
"""
Overriding `train` method to keep target Q-networks in eval mode.
"""
super().train(mode)
self._target_Qs.train(False)
return self
def track_q_grad(self, mode=True):
"""
Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings,
and sets the dropout probability to 0 if `mode` is False.
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.cfg.multitask:
for p in self._task_emb.parameters():
p.requires_grad_(mode)
for m in self._Qs.modules():
if isinstance(m, nn.Dropout):
m.p = self.cfg.dropout if mode else 0
def soft_update_target_Q(self):
"""
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
p_target.data.lerp_(p.data, self.cfg.tau)
def task_emb(self, x, task):
"""
Continuous task embedding for multi-task experiments.
Retrieves the task embedding for a given task ID `task`
and concatenates it to the input `x`.
"""
if isinstance(task, int):
task = torch.tensor([task], device=x.device)
emb = self._task_emb(task.long())
if x.ndim == 3:
emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)
elif emb.shape[0] == 1:
emb = emb.repeat(x.shape[0], 1)
return torch.cat([x, emb], dim=-1)
def encode(self, obs, task):
"""
Encodes an observation into its latent representation.
This implementation assumes a single state-based observation.
"""
if self.cfg.multitask:
obs = self.task_emb(obs, task)
return self._encoder['state'](obs)
def next(self, z, a, task):
"""
Predicts the next latent state given the current latent state and action.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
return self._dynamics(z)
def reward(self, z, a, task):
"""
Predicts instantaneous (single-step) reward.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def pi(self, z, task):
"""
Samples an action from the policy prior.
The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
# Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
eps = torch.randn_like(mu)
if self.cfg.multitask: # Mask out unused action dimensions
mu = mu * self._action_masks[task]
log_std = log_std * self._action_masks[task]
eps = eps * self._action_masks[task]
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
else: # No masking
action_dims = None
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
pi = mu + eps * log_std.exp()
mu, pi, log_pi = math.squash(mu, pi, log_pi)
return mu, pi, log_pi, log_std
def Q(self, z, a, task, return_type='min', target=False):
"""
Predict state-action value.
`return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values.
- `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not.
"""
assert return_type in {'min', 'avg', 'all'}
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
out = (self._target_Qs if target else self._Qs)(z)
if return_type == 'all':
return out
Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)]
Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg)
return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2

86
tdmpc2/config.yaml Executable file
View File

@@ -0,0 +1,86 @@
defaults:
- override hydra/launcher: submitit_local
# environment
task: dog-run
# evaluation
checkpoint: ???
eval_episodes: 10
eval_freq: 50000
# training
steps: 10_000_000
batch_size: 256
reward_coef: 0.1
value_coef: 0.1
consistency_coef: 20
rho: 0.5
lr: 3e-4
enc_lr_scale: 0.3
grad_clip_norm: 20
tau: 0.01
discount_denom: 5
discount_min: 0.95
discount_max: 0.995
buffer_size: 1_000_000
exp_name: default
data_dir: ???
# planning
mpc: true
iterations: 6
num_samples: 512
num_elites: 64
num_pi_trajs: 24
horizon: 3
min_std: 0.05
max_std: 2
temperature: 0.5
# actor
log_std_min: -10
log_std_max: 2
entropy_coef: 1e-4
# critic
num_bins: 101
vmin: -10
vmax: +10
# architecture
model_size: 5
num_enc_layers: 2
enc_dim: 256
mlp_dim: 512
latent_dim: 512
task_dim: 96
num_q: 5
dropout: 0.01
simnorm_dim: 8
# logging
wandb_project: ???
wandb_entity: ???
wandb_silent: false
disable_wandb: true
save_csv: true
# misc
save_video: true
save_agent: true
seed: 1
# convenience
work_dir: ???
task_title: ???
multitask: ???
tasks: ???
obs_shape: ???
action_dim: ???
episode_length: ???
obs_shapes: ???
action_dims: ???
episode_lengths: ???
seed_steps: ???
bin_size: ???

62
tdmpc2/envs/__init__.py Normal file
View File

@@ -0,0 +1,62 @@
from copy import deepcopy
import warnings
import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.tensor import TensorWrapper
from envs.dmcontrol import make_env as make_dm_control_env
from envs.maniskill import make_env as make_maniskill_env
from envs.metaworld import make_env as make_metaworld_env
from envs.myosuite import make_env as make_myosuite_env
from envs.exceptions import UnknownTaskError
warnings.filterwarnings('ignore', category=DeprecationWarning)
def make_multitask_env(cfg):
"""
Make a multi-task environment for TD-MPC2 experiments.
"""
print('Creating multi-task environment with tasks:', cfg.tasks)
envs = []
for task in cfg.tasks:
_cfg = deepcopy(cfg)
_cfg.task = task
_cfg.multitask = False
env = make_env(_cfg)
if env is None:
raise UnknownTaskError(task)
envs.append(env)
env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims
cfg.action_dims = env._action_dims
cfg.episode_lengths = env._episode_lengths
return env
def make_env(cfg):
"""
Make an environment for TD-MPC2 experiments.
"""
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
else:
env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
except UnknownTaskError:
pass
if env is None:
raise UnknownTaskError(cfg.task)
env = TensorWrapper(env)
try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box
cfg.obs_shape = {'state': env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length)
return env

200
tdmpc2/envs/dmcontrol.py Normal file
View File

@@ -0,0 +1,200 @@
from collections import deque, defaultdict
from typing import Any, NamedTuple
import dm_env
import numpy as np
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
from envs.exceptions import UnknownTaskError
import gym
class ExtendedTimeStep(NamedTuple):
step_type: Any
reward: Any
discount: Any
observation: Any
action: Any
def first(self):
return self.step_type == StepType.FIRST
def mid(self):
return self.step_type == StepType.MID
def last(self):
return self.step_type == StepType.LAST
class ActionRepeatWrapper(dm_env.Environment):
def __init__(self, env, num_repeats):
self._env = env
self._num_repeats = num_repeats
def step(self, action):
reward = 0.0
discount = 1.0
for i in range(self._num_repeats):
time_step = self._env.step(action)
reward += (time_step.reward or 0.0) * discount
discount *= time_step.discount
if time_step.last():
break
return time_step._replace(reward=reward, discount=discount)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ActionDTypeWrapper(dm_env.Environment):
def __init__(self, env, dtype):
self._env = env
wrapped_action_spec = env.action_spec()
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
dtype,
wrapped_action_spec.minimum,
wrapped_action_spec.maximum,
'action')
def step(self, action):
action = action.astype(self._env.action_spec().dtype)
return self._env.step(action)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._action_spec
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ExtendedTimeStepWrapper(dm_env.Environment):
def __init__(self, env):
self._env = env
def reset(self):
time_step = self._env.reset()
return self._augment_time_step(time_step)
def step(self, action):
time_step = self._env.step(action)
return self._augment_time_step(time_step, action)
def _augment_time_step(self, time_step, action=None):
if action is None:
action_spec = self.action_spec()
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
return ExtendedTimeStep(observation=time_step.observation,
step_type=time_step.step_type,
action=action,
reward=time_step.reward or 0.0,
discount=time_step.discount or 1.0)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def __getattr__(self, name):
return getattr(self._env, name)
class TimeStepToGymWrapper:
def __init__(self, env, domain, task):
obs_shp = []
for v in env.observation_spec().values():
try:
shp = np.prod(v.shape)
except:
shp = 1
obs_shp.append(shp)
obs_shp = (int(np.sum(obs_shp)),)
act_shp = env.action_spec().shape
self.observation_space = gym.spaces.Box(
low=np.full(
obs_shp,
-np.inf,
dtype=np.float32),
high=np.full(
obs_shp,
np.inf,
dtype=np.float32),
dtype=np.float32,
)
self.action_space = gym.spaces.Box(
low=np.full(act_shp, env.action_spec().minimum),
high=np.full(act_shp, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.env = env
self.domain = domain
self.task = task
self.max_episode_steps = 500
self.t = 0
@property
def unwrapped(self):
return self.env
@property
def reward_range(self):
return None
@property
def metadata(self):
return None
def _obs_to_array(self, obs):
return np.concatenate([v.flatten() for v in obs.values()])
def reset(self):
self.t = 0
return self._obs_to_array(self.env.reset().observation)
def step(self, action):
self.t += 1
time_step = self.env.step(action)
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float)
def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
camera_id = dict(quadruped=2).get(self.domain, camera_id)
return self.env.physics.render(height, width, camera_id)
def make_env(cfg):
"""
Make DMControl environment.
Adapted from https://github.com/facebookresearch/drqv2
"""
domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS:
raise UnknownTaskError(cfg.task)
env = suite.load(domain,
task,
task_kwargs={'random': cfg.seed},
visualize_reward=False)
env = ActionDTypeWrapper(env, np.float32)
env = ActionRepeatWrapper(env, 2)
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env)
env = TimeStepToGymWrapper(env, domain, task)
return env

View File

@@ -0,0 +1,4 @@
class UnknownTaskError(Exception):
def __init__(self, task):
super().__init__(f'Unknown task: {task}')

79
tdmpc2/envs/maniskill.py Normal file
View File

@@ -0,0 +1,79 @@
import gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
import mani_skill2.envs
MANISKILL_TASKS = {
'lift-cube': dict(
env='LiftCube-v0',
control_mode='pd_ee_delta_pos',
),
'pick-cube': dict(
env='PickCube-v0',
control_mode='pd_ee_delta_pos',
),
'stack-cube': dict(
env='StackCube-v0',
control_mode='pd_ee_delta_pos',
),
'pick-ycb': dict(
env='PickSingleYCB-v0',
control_mode='pd_ee_delta_pose',
),
'turn-faucet': dict(
env='TurnFaucet-v0',
control_mode='pd_ee_delta_pose',
),
}
class ManiSkillWrapper(gym.Wrapper):
def __init__(self, env, cfg):
super().__init__(env)
self.env = env
self.cfg = cfg
self.observation_space = self.env.observation_space
self.action_space = gym.spaces.Box(
low=np.full(self.env.action_space.shape, self.env.action_space.low.min()),
high=np.full(self.env.action_space.shape, self.env.action_space.high.max()),
dtype=self.env.action_space.dtype,
)
def reset(self):
return self.env.reset()
def step(self, action):
reward = 0
for _ in range(2):
obs, r, _, info = self.env.step(action)
reward += r
return obs, reward, False, info
@property
def unwrapped(self):
return self.env.unwrapped
def render(self, args, **kwargs):
return self.env.render(mode='cameras')
def make_env(cfg):
"""
Make ManiSkill2 environment.
"""
if cfg.task not in MANISKILL_TASKS:
raise UnknownTaskError(cfg.task)
task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make(
task_cfg['env'],
obs_mode='state',
control_mode=task_cfg['control_mode'],
render_camera_cfgs=dict(width=384, height=384),
)
env = ManiSkillWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

52
tdmpc2/envs/metaworld.py Normal file
View File

@@ -0,0 +1,52 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
class MetaWorldWrapper(gym.Wrapper):
def __init__(self, env, cfg):
super().__init__(env)
self.env = env
self.cfg = cfg
self.camera_name = "corner2"
self.env.model.cam_pos[2] = [0.75, 0.075, 0.7]
self.env._freeze_rand_vec = False
def reset(self, **kwargs):
obs = super().reset(**kwargs).astype(np.float32)
self.env.step(np.zeros(self.env.action_space.shape))
return obs
def step(self, action):
reward = 0
for _ in range(2):
obs, r, _, info = self.env.step(action.copy())
reward += r
obs = obs.astype(np.float32)
return obs, reward, False, info
@property
def unwrapped(self):
return self.env.unwrapped
def render(self, *args, **kwargs):
return self.env.render(
offscreen=True, resolution=(384, 384), camera_name=self.camera_name
).copy()
def make_env(cfg):
"""
Make Meta-World environment.
"""
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
raise UnknownTaskError(cfg.task)
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

59
tdmpc2/envs/myosuite.py Normal file
View File

@@ -0,0 +1,59 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
MYOSUITE_TASKS = {
'myo-finger-reach': 'myoFingerReachFixed-v0',
'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
'myo-finger-pose': 'myoFingerPoseFixed-v0',
'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
'myo-hand-reach': 'myoHandReachFixed-v0',
'myo-hand-reach-hard': 'myoHandReachRandom-v0',
'myo-hand-pose': 'myoHandPoseFixed-v0',
'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
class MyoSuiteWrapper(gym.Wrapper):
def __init__(self, env, cfg):
super().__init__(env)
self.env = env
self.cfg = cfg
self.camera_id = 'hand_side_inter'
def step(self, action):
obs, reward, _, info = self.env.step(action.copy())
obs = obs.astype(np.float32)
info['success'] = info['solved']
return obs, reward, False, info
@property
def unwrapped(self):
return self.env.unwrapped
def render(self, *args, **kwargs):
return self.env.sim.renderer.render_offscreen(
width=384, height=384, camera_id=self.camera_id
).copy()
def make_env(cfg):
"""
Make Myosuite environment.
"""
if not cfg.task in MYOSUITE_TASKS:
raise UnknownTaskError(cfg.task)
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

View File

View File

@@ -0,0 +1,99 @@
import collections
import os
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import ball_in_cup
from dm_control.suite import common
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_DIST_TARGET = 0.5
_TARGET_SPEED = 6.
_DEFAULT_TIME_LIMIT = 20 # (seconds)
_CONTROL_TIMESTEP = .02 # (seconds)
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'ball_in_cup.xml')), common.ASSETS
@ball_in_cup.SUITE.add('custom')
def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Ball-in-Cup Spin task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomBallInCup(random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
**environment_kwargs)
class Physics(mujoco.Physics):
"""Physics with additional features for the Ball-in-Cup domain."""
def ball_to_target(self):
"""Returns the vector from the ball to the target."""
target = self.named.data.site_xpos['target', ['x', 'z']]
ball = self.named.data.xpos['ball', ['x', 'z']]
return target - ball
def in_target(self):
"""Returns 1 if the ball is in the target, 0 otherwise."""
ball_to_target = abs(self.ball_to_target())
target_size = self.named.model.site_size['target', [0, 2]]
ball_size = self.named.model.geom_size['ball', 0]
return float(all(ball_to_target < target_size - ball_size))
class CustomBallInCup(ball_in_cup.BallInCup):
"""Custom Ball-in-Cup tasks."""
def initialize_episode(self, physics):
# Find a collision-free random initial position of the ball.
penetrating = True
valid_pos = False
init_out_of_target = self.random.uniform() < 0.1
while penetrating or not valid_pos:
# Assign a random ball position.
physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
valid_pos = bool(physics.in_target()) or init_out_of_target
base.Task.initialize_episode(self, physics)
def get_observation(self, physics):
"""Returns an observation of the state."""
obs = collections.OrderedDict()
obs['position'] = physics.position()
obs['velocity'] = physics.velocity()
return obs
def get_reward(self, physics):
dist = np.linalg.norm(physics.ball_to_target())
ball_vel_x = abs(physics.named.data.qvel['ball_x'])
ball_vel_z = abs(physics.named.data.qvel['ball_z'])
ball_vel = np.linalg.norm([ball_vel_x, ball_vel_z])
# reward: spin around target (maximize distance to target + ball velocity)
dist_reward = rewards.tolerance(dist,
bounds=(_DIST_TARGET, float('inf')),
margin=_DIST_TARGET/2,
value_at_margin=0.5,
sigmoid='linear')
not_in_target = 1 - physics.in_target()
vel_reward = rewards.tolerance(ball_vel,
bounds=(_TARGET_SPEED, float('inf')),
margin=_TARGET_SPEED/2,
value_at_margin=0.5,
sigmoid='linear')
spin_reward = not_in_target * (dist_reward + 2*vel_reward) / 3
return spin_reward

View File

@@ -0,0 +1,53 @@
<mujoco model="ball in cup">
<include file="./common/visual.xml"/>
<include file="./common/skybox.xml"/>
<include file="./common/materials.xml"/>
<default>
<motor ctrllimited="true" ctrlrange="-1 1" gear="5"/>
<default class="cup">
<joint type="slide" damping="3" stiffness="20"/>
<geom type="capsule" size=".008" material="self"/>
</default>
</default>
<worldbody>
<light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 2" specular=".3 .3 .3"/>
<geom name="ground" type="plane" pos="0 0 0" size=".6 .2 10" material="grid"/>
<camera name="cam0" pos="0 -1 .8" xyaxes="1 0 0 0 1 2"/>
<camera name="cam1" pos="0 -1 .4" xyaxes="1 0 0 0 0 1" />
<body name="cup" pos="0 0 .6" childclass="cup">
<joint name="cup_x" axis="1 0 0"/>
<joint name="cup_z" axis="0 0 1"/>
<geom name="cup_part_0" fromto="-.05 0 0 -.05 0 -.075" />
<geom name="cup_part_1" fromto="-.05 0 -.075 -.025 0 -.1" />
<geom name="cup_part_2" fromto="-.025 0 -.1 .025 0 -.1" />
<geom name="cup_part_3" fromto=".025 0 -.1 .05 0 -.075" />
<geom name="cup_part_4" fromto=".05 0 -.075 .05 0 0" />
<site name="cup" pos="0 0 -.108" size=".005"/>
<site name="target" type="box" pos="0 0 -.05" size=".05 .006 .05" group="4"/>
</body>
<body name="ball" pos="0 0 .2">
<joint name="ball_x" type="slide" axis="1 0 0"/>
<joint name="ball_z" type="slide" axis="0 0 1"/>
<geom name="ball" type="sphere" size=".025" material="effector"/>
<site name="ball" size=".005"/>
</body>
</worldbody>
<actuator>
<motor name="x" joint="cup_x"/>
<motor name="z" joint="cup_z"/>
</actuator>
<tendon>
<spatial name="string" limited="true" range="0 0.3" width="0.003">
<site site="ball"/>
<site site="cup"/>
</spatial>
</tendon>
</mujoco>

View File

@@ -0,0 +1,268 @@
import os
from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import cheetah
from dm_control.utils import rewards
from dm_control.utils import io as resources
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_CHEETAH_JUMP_HEIGHT = 1.2
_CHEETAH_LIE_HEIGHT = 0.25
_CHEETAH_SPIN_SPEED = 8
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'cheetah.xml')), common.ASSETS
@cheetah.SUITE.add('custom')
def run_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run Backwards task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='run-backwards', move_speed=cheetah._RUN_SPEED*0.8, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def stand_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Stand Front task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='stand-front', move_speed=0.5, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def stand_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Stand Back task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='stand-back', move_speed=0.5, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def jump(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Jump task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='jump', move_speed=0.5, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def run_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run Front task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='run-front', move_speed=cheetah._RUN_SPEED*0.6, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def run_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run Back task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='run-back', move_speed=cheetah._RUN_SPEED*0.6, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def lie_down(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Lie Down task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='lie-down', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def legs_up(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Legs Up task."""
physics = cheetah.Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='legs-up', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def flip(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Flip task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='flip', move_speed=cheetah._RUN_SPEED, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
@cheetah.SUITE.add('custom')
def flip_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Flip Backwards task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomCheetah(goal='flip-backwards', move_speed=cheetah._RUN_SPEED*0.8, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(physics, task, time_limit=time_limit,
**environment_kwargs)
class Physics(cheetah.Physics):
"""Physics simulation with additional features for the Cheetah domain."""
def angmomentum(self):
"""Returns the angular momentum of torso of the Cheetah about Y axis."""
return self.named.data.subtree_angmom['torso'][1]
class CustomCheetah(cheetah.Cheetah):
"""Custom Cheetah tasks."""
def __init__(self, goal='run-backwards', move_speed=0, random=None):
super().__init__(random)
self._goal = goal
self._move_speed = move_speed
def _run_backwards_reward(self, physics):
return rewards.tolerance(physics.speed(),
bounds=(-float('inf'), -self._move_speed),
margin=self._move_speed,
value_at_margin=0,
sigmoid='linear')
def _stand_one_foot_reward(self, physics, foot):
"""Note: `foot` is the foot that is *not* on the ground."""
torso_height = physics.named.data.xpos['torso', 'z']
foot_height = physics.named.data.xpos[foot, 'z']
height_reward = rewards.tolerance((torso_height + foot_height)/2,
bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
margin=_CHEETAH_JUMP_HEIGHT/2)
horizontal_speed_reward = rewards.tolerance(physics.speed(),
bounds=(-self._move_speed, self._move_speed),
margin=self._move_speed,
value_at_margin=0,
sigmoid='linear')
stand_reward = (5*height_reward + horizontal_speed_reward) / 6
return stand_reward
def _stand_front_reward(self, physics):
return self._stand_one_foot_reward(physics, 'bfoot')
def _stand_back_reward(self, physics):
return self._stand_one_foot_reward(physics, 'ffoot')
def _jump_reward(self, physics):
front_reward = self._stand_front_reward(physics)
back_reward = self._stand_back_reward(physics)
jump_reward = (front_reward + back_reward) / 2
return jump_reward
def _run_one_foot_reward(self, physics, foot):
"""Note: `foot` is the foot that is *not* on the ground."""
torso_height = physics.named.data.xpos['torso', 'z']
foot_height = physics.named.data.xpos[foot, 'z']
torso_up = rewards.tolerance(torso_height,
bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
margin=_CHEETAH_JUMP_HEIGHT/2)
foot_up = rewards.tolerance(foot_height,
bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
margin=_CHEETAH_JUMP_HEIGHT/2)
up_reward = (3*foot_up + 2*torso_up) / 5
if self._move_speed == 0:
return up_reward
horizontal_speed_reward = rewards.tolerance(physics.speed(),
bounds=(self._move_speed, float('inf')),
margin=self._move_speed,
value_at_margin=0,
sigmoid='linear')
return up_reward * (5*horizontal_speed_reward + 1) / 6
def _run_front_reward(self, physics):
return self._run_one_foot_reward(physics, 'bfoot')
def _run_back_reward(self, physics):
return self._run_one_foot_reward(physics, 'ffoot')
def _lie_down_reward(self, physics):
torso_height = physics.named.data.xpos['torso', 'z']
feet_height = (physics.named.data.xpos['ffoot', 'z'] + physics.named.data.xpos['bfoot', 'z']) / 2
torso_down = rewards.tolerance(torso_height,
bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
margin=_CHEETAH_LIE_HEIGHT,
value_at_margin=0,
sigmoid='linear')
feet_down = rewards.tolerance(feet_height,
bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
margin=_CHEETAH_LIE_HEIGHT,
value_at_margin=0,
sigmoid='linear')
lie_down_reward = (3*torso_down + feet_down) / 4
return lie_down_reward
def _legs_up_reward(self, physics):
torso_height = physics.named.data.xpos['torso', 'z']
torso_down = rewards.tolerance(torso_height,
bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
margin=_CHEETAH_LIE_HEIGHT/2)
get_up = self._run_one_foot_reward(physics, 'bfoot')
legs_up_reward = (5*torso_down + get_up) / 6
return legs_up_reward
def _flip_reward(self, physics, forward=True):
spin_reward = rewards.tolerance(
(1. if forward else -1.) * physics.angmomentum(),
bounds=(_CHEETAH_SPIN_SPEED, float('inf')),
margin=_CHEETAH_SPIN_SPEED,
value_at_margin=0,
sigmoid='linear')
horizontal_speed_reward = rewards.tolerance(
(1. if forward else -1.) * physics.speed(),
bounds=(self._move_speed, float('inf')),
margin=self._move_speed,
value_at_margin=0,
sigmoid='linear')
flip_reward = (2*spin_reward + horizontal_speed_reward) / 3
return flip_reward
def get_reward(self, physics):
if self._goal == 'run-backwards':
return self._run_backwards_reward(physics)
elif self._goal == 'stand-front':
return self._stand_front_reward(physics)
elif self._goal == 'stand-back':
return self._stand_back_reward(physics)
elif self._goal == 'jump':
return self._jump_reward(physics)
elif self._goal == 'run-front':
return self._run_front_reward(physics)
elif self._goal == 'run-back':
return self._run_back_reward(physics)
elif self._goal == 'lie-down':
return self._lie_down_reward(physics)
elif self._goal == 'legs-up':
return self._legs_up_reward(physics)
elif self._goal == 'flip':
return self._flip_reward(physics, forward=True)
elif self._goal == 'flip-backwards':
return self._flip_reward(physics, forward=False)
else:
raise NotImplementedError(f'Goal {self._goal} is not implemented.')
if __name__ == '__main__':
env = jump()
obs = env.reset()
import numpy as np
next_obs, reward, done, info = env.step(np.zeros(6))
print(reward)

View File

@@ -0,0 +1,73 @@
<mujoco model="cheetah">
<include file="./common/skybox.xml"/>
<include file="./common/visual.xml"/>
<include file="./common/materials.xml"/>
<compiler settotalmass="14"/>
<default>
<default class="cheetah">
<joint limited="true" damping=".01" armature=".1" stiffness="8" type="hinge" axis="0 1 0"/>
<geom contype="1" conaffinity="1" condim="3" friction=".4 .1 .1" material="self"/>
</default>
<default class="free">
<joint limited="false" damping="0" armature="0" stiffness="0"/>
</default>
<motor ctrllimited="true" ctrlrange="-1 1"/>
</default>
<statistic center="0 0 .7" extent="2"/>
<option timestep="0.01"/>
<worldbody>
<geom name="ground" type="plane" conaffinity="1" pos="98 0 0" size="200 .8 .5" material="grid"/>
<body name="torso" pos="0 0 .7" childclass="cheetah">
<light name="light" pos="0 0 2" mode="trackcom"/>
<camera name="side" pos="0 -3 0" quat="0.707 0.707 0 0" mode="trackcom"/>
<camera name="back" pos="-1.8 -1.3 0.8" xyaxes="0.45 -0.9 0 0.3 0.15 0.94" mode="trackcom"/>
<joint name="rootx" type="slide" axis="1 0 0" class="free"/>
<joint name="rootz" type="slide" axis="0 0 1" class="free"/>
<joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
<geom name="torso" type="capsule" fromto="-.5 0 0 .5 0 0" size="0.046"/>
<geom name="head" type="capsule" pos=".6 0 .1" euler="0 50 0" size="0.046 .15"/>
<body name="bthigh" pos="-.5 0 0">
<joint name="bthigh" range="-30 60" stiffness="240" damping="6"/>
<geom name="bthigh" type="capsule" pos=".1 0 -.13" euler="0 -218 0" size="0.046 .145"/>
<body name="bshin" pos=".16 0 -.25">
<joint name="bshin" range="-50 50" stiffness="180" damping="4.5"/>
<geom name="bshin" type="capsule" pos="-.14 0 -.07" euler="0 -116 0" size="0.046 .15"/>
<body name="bfoot" pos="-.28 0 -.14">
<joint name="bfoot" range="-230 50" stiffness="120" damping="3"/>
<geom name="bfoot" type="capsule" pos=".03 0 -.097" euler="0 -15 0" size="0.046 .094"/>
</body>
</body>
</body>
<body name="fthigh" pos=".5 0 0">
<joint name="fthigh" range="-57 .40" stiffness="180" damping="4.5"/>
<geom name="fthigh" type="capsule" pos="-.07 0 -.12" euler="0 30 0" size="0.046 .133"/>
<body name="fshin" pos="-.14 0 -.24">
<joint name="fshin" range="-70 50" stiffness="120" damping="3"/>
<geom name="fshin" type="capsule" pos=".065 0 -.09" euler="0 -34 0" size="0.046 .106"/>
<body name="ffoot" pos=".13 0 -.18">
<joint name="ffoot" range="-28 28" stiffness="60" damping="1.5"/>
<geom name="ffoot" type="capsule" pos=".045 0 -.07" euler="0 -34 0" size="0.046 .07"/>
</body>
</body>
</body>
</body>
</worldbody>
<sensor>
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
</sensor>
<actuator>
<motor name="bthigh" joint="bthigh" gear="120" />
<motor name="bshin" joint="bshin" gear="90" />
<motor name="bfoot" joint="bfoot" gear="60" />
<motor name="fthigh" joint="fthigh" gear="90" />
<motor name="fshin" joint="fshin" gear="60" />
<motor name="ffoot" joint="ffoot" gear="30" />
</actuator>
</mujoco>

79
tdmpc2/envs/tasks/fish.py Normal file
View File

@@ -0,0 +1,79 @@
import collections
import os
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite import fish
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_DEFAULT_TIME_LIMIT = 40
_CONTROL_TIMESTEP = .04
_JOINTS = ['tail1',
'tail_twist',
'tail2',
'finright_roll',
'finright_pitch',
'finleft_roll',
'finleft_pitch']
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'fish.xml')), common.ASSETS
@fish.SUITE.add('custom')
def obstacles(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Fish Obstacles task."""
physics = fish.Physics.from_xml_string(*get_model_and_assets())
task = Obstacles(random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
**environment_kwargs)
class Obstacles(fish.Swim):
"""A custom Fish Obstacles task."""
def __init__(self, random=None):
super().__init__(random=random)
def in_wall(self, physics, name, min_distance=0.08):
"""Returns True if the given body is too close to a wall."""
for wall in ['wall0', 'wall1', 'wall2', 'wall3']:
l1_dist = np.min(np.abs(physics.named.data.geom_xpos[name][:2] - physics.named.data.geom_xpos[wall][:2]))
if l1_dist < min_distance:
return True
return False
def initialize_episode(self, physics):
in_wall = True
while in_wall:
# Randomize fish position.
quat = self.random.randn(4)
physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
for joint in _JOINTS:
physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
# Randomize target position.
physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
# Make sure target is not too close to a wall.
physics.after_reset()
in_wall = self.in_wall(physics, 'target')
base.Task.initialize_episode(self, physics)
def get_reward(self, physics):
radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
in_target = rewards.tolerance(np.linalg.norm(physics.mouth_to_target()),
bounds=(0, radii), margin=2*radii)
is_upright = 0.5 * (physics.upright() + 1)
is_not_in_wall = 1. - self.in_wall(physics, 'torso', min_distance=0.06)
return is_not_in_wall * (7*in_target + is_upright) / 8

View File

@@ -0,0 +1,93 @@
<mujoco model="fish">
<include file="./common/visual.xml"/>
<include file="./common/materials.xml"/>
<asset>
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>
</asset>
<option timestep="0.004" density="5000">
<flag gravity="disable" constraint="disable"/>
</option>
<default>
<general ctrllimited="true"/>
<default class="fish">
<joint type="hinge" limited="false" range="-60 60" damping="2e-5" solreflimit=".1 1" solimplimit="0 .8 .1"/>
<geom material="self"/>
</default>
<default class="wall">
<geom type="box" material="self"/>
</default>
</default>
<worldbody>
<camera name="tracking_top" pos="0 0 1" xyaxes="1 0 0 0 1 0" mode="trackcom"/>
<camera name="tracking_x" pos="-.3 0 .2" xyaxes="0 -1 0 0.342 0 0.940" fovy="60" mode="trackcom"/>
<camera name="tracking_y" pos="0 -.3 .2" xyaxes="1 0 0 0 0.342 0.940" fovy="60" mode="trackcom"/>
<camera name="fixed_top" pos="0 0 5.5" fovy="10"/>
<geom name="ground" type="plane" size=".5 .5 .1" material="grid"/>
<geom name="wall0" class="wall" pos="-.15 -.15 .1" size=".05 .05 .1"/>
<geom name="wall1" class="wall" pos=".15 -.15 .1" size=".05 .05 .1"/>
<geom name="wall2" class="wall" pos=".15 .15 .1" size=".05 .05 .1"/>
<geom name="wall3" class="wall" pos="-.15 .15 .1" size=".05 .05 .1"/>
<geom name="target" type="sphere" pos="0 .4 .1" size=".04" material="target"/>
<body name="torso" pos="0 0 .1" childclass="fish">
<light name="light" diffuse=".6 .6 .6" pos="0 0 0.5" dir="0 0 -1" specular=".3 .3 .3" mode="track"/>
<joint name="root" type="free" damping="0" limited="false"/>
<site name="torso" size=".01" rgba="0 0 0 0"/>
<geom name="eye" type="ellipsoid" pos="0 .055 .015" size=".008 .012 .008" euler="-10 0 0" material="eye" mass="0"/>
<camera name="eye" pos="0 .06 .02" xyaxes="1 0 0 0 0 1"/>
<geom name="mouth" type="capsule" fromto="0 .079 0 0 .07 0" size=".005" material="effector" mass="0"/>
<geom name="lower_mouth" type="capsule" fromto="0 .079 -.004 0 .07 -.003" size=".0045" material="effector" mass="0"/>
<geom name="torso" type="ellipsoid" size=".01 .08 .04" mass="0"/>
<geom name="back_fin" type="ellipsoid" size=".001 .03 .015" pos="0 -.03 .03" material="effector" mass="0"/>
<geom name="torso_massive" type="box" size=".002 .06 .03" group="4"/>
<body name="tail1" pos="0 -.09 0">
<joint name="tail1" axis="0 0 1" pos="0 .01 0"/>
<joint name="tail_twist" axis="0 1 0" pos="0 .01 0" range="-30 30"/>
<geom name="tail1" type="ellipsoid" size=".001 .008 .016"/>
<body name="tail2" pos="0 -.028 0">
<joint name="tail2" axis="0 0 1" pos="0 .02 0" stiffness="8e-5"/>
<geom name="tail2" type="ellipsoid" size=".001 .018 .035"/>
</body>
</body>
<body name="finright" pos=".01 0 0">
<joint name="finright_roll" axis="0 1 0"/>
<joint name="finright_pitch" axis="1 0 0" pos="0 .005 0"/>
<geom name="finright" type="ellipsoid" pos=".015 0 0" size=".02 .015 .001" />
</body>
<body name="finleft" pos="-.01 0 0">
<joint name="finleft_roll" axis="0 1 0"/>
<joint name="finleft_pitch" axis="1 0 0" pos="0 .005 0"/>
<geom name="finleft" type="ellipsoid" pos="-.015 0 0" size=".02 .015 .001"/>
</body>
</body>
</worldbody>
<tendon>
<fixed name="fins_flap">
<joint joint="finleft_roll" coef="-.5"/>
<joint joint="finright_roll" coef=".5"/>
</fixed>
<fixed name="fins_sym" stiffness="1e-4">
<joint joint="finleft_roll" coef=".5"/>
<joint joint="finright_roll" coef=".5"/>
</fixed>
</tendon>
<actuator>
<position name="tail" joint="tail1" ctrlrange="-1 1" kp="5e-4"/>
<position name="tail_twist" joint="tail_twist" ctrlrange="-1 1" kp="1e-4"/>
<position name="fins_flap" tendon="fins_flap" ctrlrange="-1 1" kp="3e-4"/>
<position name="finleft_pitch" joint="finleft_pitch" ctrlrange="-1 1" kp="1e-4"/>
<position name="finright_pitch" joint="finright_pitch" ctrlrange="-1 1" kp="1e-4"/>
</actuator>
<sensor>
<velocimeter name="velocimeter" site="torso"/>
<gyro name="gyro" site="torso"/>
</sensor>
</mujoco>

114
tdmpc2/envs/tasks/hopper.py Normal file
View File

@@ -0,0 +1,114 @@
import os
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import hopper
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_CONTROL_TIMESTEP = .02 # (Seconds)
# Default duration of an episode, in seconds.
_DEFAULT_TIME_LIMIT = 20
# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 0.6
# Hopping speed above which hop reward is 1.
_HOP_SPEED = 2
# Angular momentum above which reward is 1.
_SPIN_SPEED = 5
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'hopper.xml')), common.ASSETS
@hopper.SUITE.add('custom')
def hop_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Hop Backwards task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomHopper(goal='hop-backwards', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
**environment_kwargs)
@hopper.SUITE.add('custom')
def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Flip task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomHopper(goal='flip', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
**environment_kwargs)
@hopper.SUITE.add('custom')
def flip_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Flip Backwards task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = CustomHopper(goal='flip-backwards', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
**environment_kwargs)
class Physics(hopper.Physics):
def angmomentum(self):
"""Returns the angular momentum of torso of the Cheetah about Y axis."""
return self.named.data.subtree_angmom['torso'][1]
class CustomHopper(hopper.Hopper):
"""Custom Hopper tasks."""
def __init__(self, goal='hop-backwards', random=None):
super().__init__(None, random)
self._goal = goal
def _hop_backwards_reward(self, physics):
standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
hopping = rewards.tolerance(physics.speed(),
bounds=(-float('inf'), -_HOP_SPEED/2),
margin=_HOP_SPEED/4,
value_at_margin=0.5,
sigmoid='linear')
return standing * hopping
def _flip_reward(self, physics, forward=True):
reward = rewards.tolerance((1. if forward else -1.) * physics.angmomentum(),
bounds=(_SPIN_SPEED, float('inf')),
margin=_SPIN_SPEED/2,
value_at_margin=0,
sigmoid='linear')
return reward
def get_reward(self, physics):
if self._goal == 'hop-backwards':
return self._hop_backwards_reward(physics)
elif self._goal == 'flip':
return self._flip_reward(physics, forward=True)
elif self._goal == 'flip-backwards':
return self._flip_reward(physics, forward=False)
else:
raise NotImplementedError(f'Goal {self._goal} is not implemented.')
if __name__ == '__main__':
env = hop_backwards()
obs = env.reset()
import numpy as np
next_obs, reward, done, info = env.step(np.zeros(2))
print(reward)

View File

@@ -0,0 +1,66 @@
<mujoco model="planar hopper">
<include file="./common/skybox.xml"/>
<include file="./common/visual.xml"/>
<include file="./common/materials.xml"/>
<statistic extent="2" center="0 0 .5"/>
<default>
<default class="hopper">
<joint type="hinge" axis="0 1 0" limited="true" damping=".05" armature=".2"/>
<geom type="capsule" material="self"/>
<site type="sphere" size="0.05" group="3"/>
</default>
<default class="free">
<joint limited="false" damping="0" armature="0" stiffness="0"/>
</default>
<motor ctrlrange="-1 1" ctrllimited="true"/>
</default>
<option timestep="0.005"/>
<worldbody>
<camera name="cam0" pos="0 -2.8 0.8" euler="90 0 0" mode="trackcom"/>
<camera name="back" pos="-2 -.2 1.2" xyaxes="0.2 -1 0 .5 0 2" mode="trackcom"/>
<geom name="floor" type="plane" conaffinity="1" pos="48 0 0" size="50 1 .2" material="grid"/>
<body name="torso" pos="0 0 1" childclass="hopper">
<light name="top" pos="0 0 2" mode="trackcom"/>
<joint name="rootx" type="slide" axis="1 0 0" class="free"/>
<joint name="rootz" type="slide" axis="0 0 1" class="free"/>
<joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
<geom name="torso" fromto="0 0 -.05 0 0 .2" size="0.0653"/>
<geom name="nose" fromto=".08 0 .13 .15 0 .14" size="0.03"/>
<body name="pelvis" pos="0 0 -.05">
<joint name="waist" range="-30 30"/>
<geom name="pelvis" fromto="0 0 0 0 0 -.15" size="0.065"/>
<body name="thigh" pos="0 0 -.2">
<joint name="hip" range="-170 10"/>
<geom name="thigh" fromto="0 0 0 0 0 -.33" size="0.04"/>
<body name="calf" pos="0 0 -.33">
<joint name="knee" range="5 150"/>
<geom name="calf" fromto="0 0 0 0 0 -.32" size="0.03"/>
<body name="foot" pos="0 0 -.32">
<joint name="ankle" range="-45 45"/>
<geom name="foot" fromto="-.08 0 0 .17 0 0" size="0.04"/>
<site name="touch_toe" pos=".17 0 0"/>
<site name="touch_heel" pos="-.08 0 0"/>
</body>
</body>
</body>
</body>
</body>
</worldbody>
<sensor>
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
<touch name="touch_toe" site="touch_toe"/>
<touch name="touch_heel" site="touch_heel"/>
</sensor>
<actuator>
<motor name="waist" joint="waist" gear="30"/>
<motor name="hip" joint="hip" gear="40"/>
<motor name="knee" joint="knee" gear="30"/>
<motor name="ankle" joint="ankle" gear="10"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,43 @@
import os
from dm_control.rl import control
from dm_control.suite import pendulum
from dm_control.suite import common
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_DEFAULT_TIME_LIMIT = 20
_TARGET_SPEED = 9.
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'pendulum.xml')), common.ASSETS
@pendulum.SUITE.add('custom')
def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None,
environment_kwargs=None):
"""Returns pendulum spin task."""
physics = pendulum.Physics.from_xml_string(*get_model_and_assets())
task = Spin(random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, **environment_kwargs)
class Spin(pendulum.SwingUp):
"""A custom Pendulum Spin task."""
def __init__(self, random=None):
super().__init__(random=random)
def get_reward(self, physics):
return rewards.tolerance(np.linalg.norm(physics.angular_velocity()),
bounds=(_TARGET_SPEED, float('inf')),
margin=_TARGET_SPEED/2,
value_at_margin=0.5,
sigmoid='linear')

View File

@@ -0,0 +1,26 @@
<mujoco model="pendulum">
<include file="./common/visual.xml"/>
<include file="./common/skybox.xml"/>
<include file="./common/materials.xml"/>
<option timestep="0.02">
<flag contact="disable" energy="enable"/>
</option>
<worldbody>
<light name="light" pos="0 0 2"/>
<geom name="floor" size="2 2 .2" type="plane" material="grid"/>
<camera name="fixed" pos="0 -1.5 2" xyaxes='1 0 0 0 1 1'/>
<camera name="lookat" mode="targetbodycom" target="pole" pos="0 -2 1"/>
<body name="pole" pos="0 0 .6">
<joint name="hinge" type="hinge" axis="0 1 0" damping="0.1"/>
<geom name="base" material="decoration" type="cylinder" fromto="0 -.03 0 0 .03 0" size="0.021" mass="0"/>
<geom name="pole" material="self" type="capsule" fromto="0 0 0 0 0 0.5" size="0.02" mass="0"/>
<geom name="mass" material="effector" type="sphere" pos="0 0 0.5" size="0.05" mass="1"/>
</body>
</worldbody>
<actuator>
<motor name="torque" joint="hinge" gear="1" ctrlrange="-1 1" ctrllimited="true"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,89 @@
import collections
import os
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import reacher
from dm_control.utils import io as resources
import numpy as np
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_DEFAULT_TIME_LIMIT = 20
_BIG_TARGET = .05
_SMALL_TARGET = .015
def get_model_and_assets(links):
"""Returns a tuple containing the model XML string and a dict of assets."""
assert links in {3, 4}, 'Only 3 or 4 links are supported.'
fn = 'reacher_three_links.xml' if links == 3 else 'reacher_four_links.xml'
return resources.GetResource(os.path.join(_TASKS_DIR, fn)), common.ASSETS
@reacher.SUITE.add('custom')
def three_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns three-link reacher with sparse reward with 5e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets(links=3))
task = CustomThreeLinkReacher(target_size=_BIG_TARGET, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, **environment_kwargs)
@reacher.SUITE.add('custom')
def three_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns three-link reacher with sparse reward with 1e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets(links=3))
task = CustomThreeLinkReacher(target_size=_SMALL_TARGET, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, **environment_kwargs)
@reacher.SUITE.add('custom')
def four_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns three-link reacher with sparse reward with 5e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets(links=4))
task = CustomThreeLinkReacher(target_size=_BIG_TARGET, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, **environment_kwargs)
@reacher.SUITE.add('custom')
def four_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns three-link reacher with sparse reward with 1e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets(links=4))
task = CustomThreeLinkReacher(target_size=_SMALL_TARGET, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, **environment_kwargs)
class Physics(mujoco.Physics):
"""Physics simulation with additional features for the Reacher domain."""
def finger_to_target(self):
"""Returns the vector from target to finger in global coordinates."""
return (self.named.data.geom_xpos['target', :2] -
self.named.data.geom_xpos['finger', :2])
def finger_to_target_dist(self):
"""Returns the signed distance between the finger and target surface."""
return np.linalg.norm(self.finger_to_target())
class CustomThreeLinkReacher(reacher.Reacher):
"""Custom Reacher tasks."""
def __init__(self, target_size, random=None):
super().__init__(target_size, random)
def get_observation(self, physics):
obs = collections.OrderedDict()
obs['position'] = physics.position()
obs['to_target'] = physics.finger_to_target()
obs['velocity'] = physics.velocity()
return obs

View File

@@ -0,0 +1,57 @@
<mujoco model="two-link planar reacher">
<include file="./common/skybox.xml"/>
<include file="./common/visual.xml"/>
<include file="./common/materials.xml"/>
<option timestep="0.02">
<flag contact="disable"/>
</option>
<default>
<joint type="hinge" axis="0 0 1" damping="0.01"/>
<motor gear=".05" ctrlrange="-1 1" ctrllimited="true"/>
</default>
<worldbody>
<light name="light" pos="0 0 1"/>
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
<!-- Arena -->
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 10" material="grid"/>
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
<!-- Arm -->
<geom name="root" type="cylinder" fromto="0 0 0 0 0 0.02" size=".011" material="decoration"/>
<body name="arm0" pos="0 0 .01">
<geom name="arm0" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
<joint name="shoulder0"/>
<body name="arm1" pos=".06 0 0">
<geom name="arm1" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
<joint name="shoulder1" limited="true" range="-80 80"/>
<body name="arm2" pos=".06 0 0">
<geom name="arm2" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
<joint name="shoulder2" limited="true" range="-80 80"/>
<body name="hand" pos=".06 0 0">
<geom name="hand" type="capsule" fromto="0 0 0 0.1 0 0" size=".01" material="self"/>
<joint name="wrist" limited="true" range="-80 80"/>
<body name="finger" pos=".06 0 0">
<camera name="hand" pos="0 0 .2" mode="track"/>
<geom name="finger" type="sphere" size=".01" material="effector"/>
</body>
</body>
</body>
</body>
</body>
<!-- Target -->
<geom name="target" pos="0 0 .01" material="target" type="sphere" size=".05"/>
</worldbody>
<actuator>
<motor name="shoulder0" joint="shoulder0"/>
<motor name="shoulder1" joint="shoulder1"/>
<motor name="shoulder2" joint="shoulder2"/>
<motor name="wrist" joint="wrist"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,52 @@
<mujoco model="two-link planar reacher">
<include file="./common/skybox.xml"/>
<include file="./common/visual.xml"/>
<include file="./common/materials.xml"/>
<option timestep="0.02">
<flag contact="disable"/>
</option>
<default>
<joint type="hinge" axis="0 0 1" damping="0.01"/>
<motor gear=".05" ctrlrange="-1 1" ctrllimited="true"/>
</default>
<worldbody>
<light name="light" pos="0 0 1"/>
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
<!-- Arena -->
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 10" material="grid"/>
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
<!-- Arm -->
<geom name="root" type="cylinder" fromto="0 0 0 0 0 0.02" size=".011" material="decoration"/>
<body name="arm0" pos="0 0 .01">
<geom name="arm0" type="capsule" fromto="0 0 0 0.09 0 0" size=".01" material="self"/>
<joint name="shoulder0"/>
<body name="arm1" pos=".09 0 0">
<geom name="arm1" type="capsule" fromto="0 0 0 0.09 0 0" size=".01" material="self"/>
<joint name="shoulder1" limited="true" range="-80 80"/>
<body name="hand" pos=".09 0 0">
<geom name="hand" type="capsule" fromto="0 0 0 0.1 0 0" size=".01" material="self"/>
<joint name="wrist" limited="true" range="-80 80"/>
<body name="finger" pos=".09 0 0">
<camera name="hand" pos="0 0 .2" mode="track"/>
<geom name="finger" type="sphere" size=".01" material="effector"/>
</body>
</body>
</body>
</body>
<!-- Target -->
<geom name="target" pos="0 0 .01" material="target" type="sphere" size=".05"/>
</worldbody>
<actuator>
<motor name="shoulder0" joint="shoulder0"/>
<motor name="shoulder1" joint="shoulder1"/>
<motor name="wrist" joint="wrist"/>
</actuator>
</mujoco>

223
tdmpc2/envs/tasks/walker.py Normal file
View File

@@ -0,0 +1,223 @@
import os
from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import walker
from dm_control.utils import rewards
from dm_control.utils import io as resources
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
_YOGA_STAND_HEIGHT = 1.0
_YOGA_LIE_DOWN_HEIGHT = 0.08
_YOGA_LEGS_UP_HEIGHT = 1.1
def get_model_and_assets():
"""Returns a tuple containing the model XML string and a dict of assets."""
return resources.GetResource(os.path.join(_TASKS_DIR, 'walker.xml')), common.ASSETS
@walker.SUITE.add('custom')
def walk_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Walk Backwards task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = BackwardsPlanarWalker(move_speed=walker._WALK_SPEED, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def run_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run Backwards task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = BackwardsPlanarWalker(move_speed=walker._RUN_SPEED, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def arabesque(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Arabesque task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='arabesque', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def lie_down(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Lie Down task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='lie_down', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def legs_up(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Legs Up task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='legs_up', random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def headstand(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Headstand task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='flip', move_speed=0, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def flip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Flip task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='flip', move_speed=walker._RUN_SPEED*0.75, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
@walker.SUITE.add('custom')
def backflip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Backflip task."""
physics = walker.Physics.from_xml_string(*get_model_and_assets())
task = YogaPlanarWalker(goal='flip', move_speed=-walker._RUN_SPEED*0.75, random=random)
environment_kwargs = environment_kwargs or {}
return control.Environment(
physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
**environment_kwargs)
class BackwardsPlanarWalker(walker.PlanarWalker):
"""Backwards PlanarWalker task."""
def __init__(self, move_speed, random=None):
super().__init__(move_speed, random)
def get_reward(self, physics):
standing = rewards.tolerance(physics.torso_height(),
bounds=(walker._STAND_HEIGHT, float('inf')),
margin=walker._STAND_HEIGHT/2)
upright = (1 + physics.torso_upright()) / 2
stand_reward = (3*standing + upright) / 4
if self._move_speed == 0:
return stand_reward
else:
move_reward = rewards.tolerance(physics.horizontal_velocity(),
bounds=(-float('inf'), -self._move_speed),
margin=self._move_speed/2,
value_at_margin=0.5,
sigmoid='linear')
return stand_reward * (5*move_reward + 1) / 6
class YogaPlanarWalker(walker.PlanarWalker):
"""Yoga PlanarWalker tasks."""
def __init__(self, goal='arabesque', move_speed=0, random=None):
super().__init__(0, random)
self._goal = goal
self._move_speed = move_speed
def _arabesque_reward(self, physics):
standing = rewards.tolerance(physics.torso_height(),
bounds=(_YOGA_STAND_HEIGHT, float('inf')),
margin=_YOGA_STAND_HEIGHT/2)
left_foot_height = physics.named.data.xpos['left_foot', 'z']
right_foot_height = physics.named.data.xpos['right_foot', 'z']
left_foot_down = rewards.tolerance(left_foot_height,
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_STAND_HEIGHT/2)
right_foot_up = rewards.tolerance(right_foot_height,
bounds=(_YOGA_STAND_HEIGHT, float('inf')),
margin=_YOGA_STAND_HEIGHT/2)
upright = (1 - physics.torso_upright()) / 2
arabesque_reward = (3*standing + left_foot_down + right_foot_up + upright) / 6
return arabesque_reward
def _lie_down_reward(self, physics):
torso_down = rewards.tolerance(physics.torso_height(),
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_LIE_DOWN_HEIGHT/2)
thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
thigh_down = rewards.tolerance(thigh_height,
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_LIE_DOWN_HEIGHT/2)
feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
feet_down = rewards.tolerance(feet_height,
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_LIE_DOWN_HEIGHT/2)
upright = (1 - physics.torso_upright()) / 2
lie_down_reward = (3*torso_down + thigh_down + upright) / 5
return lie_down_reward
def _legs_up_reward(self, physics):
torso_down = rewards.tolerance(physics.torso_height(),
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_LIE_DOWN_HEIGHT/2)
thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
thigh_down = rewards.tolerance(thigh_height,
bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
margin=_YOGA_LIE_DOWN_HEIGHT/2)
feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
legs_up = rewards.tolerance(feet_height,
bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
margin=_YOGA_LEGS_UP_HEIGHT/2)
upright = (1 - physics.torso_upright()) / 2
legs_up_reward = (3*torso_down + 2*legs_up + thigh_down + upright) / 7
return legs_up_reward
def _flip_reward(self, physics):
thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
thigh_up = rewards.tolerance(thigh_height,
bounds=(_YOGA_STAND_HEIGHT, float('inf')),
margin=_YOGA_STAND_HEIGHT/2)
feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
legs_up = rewards.tolerance(feet_height,
bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
margin=_YOGA_LEGS_UP_HEIGHT/2)
upside_down_reward = (3*legs_up + 2*thigh_up) / 5
if self._move_speed == 0:
return upside_down_reward
move_reward = rewards.tolerance(physics.horizontal_velocity(),
bounds=(self._move_speed, float('inf')) if self._move_speed > 0 else (-float('inf'), self._move_speed),
margin=abs(self._move_speed)/2,
value_at_margin=0.5,
sigmoid='linear')
return upside_down_reward * (5*move_reward + 1) / 6
def get_reward(self, physics):
if self._goal == 'arabesque':
return self._arabesque_reward(physics)
elif self._goal == 'lie_down':
return self._lie_down_reward(physics)
elif self._goal == 'legs_up':
return self._legs_up_reward(physics)
elif self._goal == 'flip':
return self._flip_reward(physics)
else:
raise NotImplementedError(f'Goal {self._goal} is not implemented.')
if __name__ == '__main__':
env = legs_up()
obs = env.reset()
import numpy as np
next_obs, reward, done, info = env.step(np.zeros(6))

View File

@@ -0,0 +1,70 @@
<mujoco model="planar walker">
<include file="./common/visual.xml"/>
<include file="./common/skybox.xml"/>
<include file="./common/materials.xml"/>
<option timestep="0.0025"/>
<statistic extent="2" center="0 0 1"/>
<default>
<joint damping=".1" armature="0.01" limited="true" solimplimit="0 .99 .01"/>
<geom contype="1" conaffinity="0" friction=".7 .1 .1"/>
<motor ctrlrange="-1 1" ctrllimited="true"/>
<site size="0.01"/>
<default class="walker">
<geom material="self" type="capsule"/>
<joint axis="0 -1 0"/>
</default>
</default>
<worldbody>
<geom name="floor" type="plane" conaffinity="1" pos="248 0 0" size="500 .8 .2" material="grid" zaxis="0 0 1"/>
<body name="torso" pos="0 0 1.3" childclass="walker">
<light name="light" pos="0 0 2" mode="trackcom"/>
<camera name="side" pos="0 -2 .7" euler="60 0 0" mode="trackcom"/>
<camera name="back" pos="-2 0 .5" xyaxes="0 -1 0 1 0 3" mode="trackcom"/>
<joint name="rootz" axis="0 0 1" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rootx" axis="1 0 0" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rooty" axis="0 1 0" type="hinge" limited="false" armature="0" damping="0"/>
<geom name="torso" size="0.07 0.3"/>
<body name="right_thigh" pos="0 -.05 -0.3">
<joint name="right_hip" range="-20 100"/>
<geom name="right_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
<body name="right_leg" pos="0 0 -0.7">
<joint name="right_knee" pos="0 0 0.25" range="-150 0"/>
<geom name="right_leg" size="0.04 0.25"/>
<body name="right_foot" pos="0.06 0 -0.25">
<joint name="right_ankle" pos="-0.06 0 0" range="-45 45"/>
<geom name="right_foot" zaxis="1 0 0" size="0.05 0.1"/>
</body>
</body>
</body>
<body name="left_thigh" pos="0 .05 -0.3" >
<joint name="left_hip" range="-20 100"/>
<geom name="left_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
<body name="left_leg" pos="0 0 -0.7">
<joint name="left_knee" pos="0 0 0.25" range="-150 0"/>
<geom name="left_leg" size="0.04 0.25"/>
<body name="left_foot" pos="0.06 0 -0.25">
<joint name="left_ankle" pos="-0.06 0 0" range="-45 45"/>
<geom name="left_foot" zaxis="1 0 0" size="0.05 0.1"/>
</body>
</body>
</body>
</body>
</worldbody>
<sensor>
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
</sensor>
<actuator>
<motor name="right_hip" joint="right_hip" gear="100"/>
<motor name="right_knee" joint="right_knee" gear="50"/>
<motor name="right_ankle" joint="right_ankle" gear="20"/>
<motor name="left_hip" joint="left_hip" gear="100"/>
<motor name="left_knee" joint="left_knee" gear="50"/>
<motor name="left_ankle" joint="left_ankle" gear="20"/>
</actuator>
</mujoco>

View File

View File

@@ -0,0 +1,57 @@
import gym
import numpy as np
import torch
class MultitaskWrapper(gym.Wrapper):
"""
Wrapper for multi-task environments.
"""
def __init__(self, cfg, envs):
super().__init__(envs[0])
self.cfg = cfg
self.envs = envs
self._task = cfg.tasks[0]
self._task_idx = 0
self._obs_dims = [env.observation_space.shape[0] for env in self.envs]
self._action_dims = [env.action_space.shape[0] for env in self.envs]
self._episode_lengths = [env.max_episode_steps for env in self.envs]
self._obs_shape = (max(self._obs_dims),)
self._action_dim = max(self._action_dims)
self.observation_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=self._obs_shape, dtype=np.float32
)
self.action_space = gym.spaces.Box(
low=-1, high=1, shape=(self._action_dim,), dtype=np.float32
)
@property
def task(self):
return self._task
@property
def task_idx(self):
return self._task_idx
@property
def _env(self):
return self.envs[self.task_idx]
def rand_act(self):
return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _pad_obs(self, obs):
if obs.shape != self._obs_shape:
obs = torch.cat((obs, torch.zeros(self._obs_shape[0]-obs.shape[0], dtype=obs.dtype, device=obs.device)))
return obs
def reset(self, task_idx=-1):
self._task_idx = task_idx
self._task = self.cfg.tasks[task_idx]
self.env = self._env
return self._pad_obs(self.env.reset())
def step(self, action):
obs, reward, done, info = self.env.step(action[:self.env.action_space.shape[0]])
return self._pad_obs(obs), reward, done, info

View File

@@ -0,0 +1,40 @@
from collections import defaultdict
import gym
import numpy as np
import torch
class TensorWrapper(gym.Wrapper):
"""
Wrapper for converting numpy arrays to torch tensors.
"""
def __init__(self, env):
super().__init__(env)
def rand_act(self):
return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x):
x = torch.from_numpy(x)
if x.dtype == torch.float64:
x = x.float()
return x
def _obs_to_tensor(self, obs):
if isinstance(obs, dict):
for k in obs.keys():
obs[k] = self._try_f32_tensor(obs[k])
else:
obs = self._try_f32_tensor(obs)
return obs
def reset(self, task_idx=None):
return self._obs_to_tensor(self.env.reset())
def step(self, action):
obs, reward, done, info = self.env.step(action.numpy())
info = defaultdict(float, info)
info['success'] = float(info['success'])
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -0,0 +1,72 @@
"""
Wrapper for limiting the time steps of an environment.
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
"""
from typing import Optional
import gym
class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
Example:
>>> from gym.envs.classic_control import CartPoleEnv
>>> from gym.wrappers import TimeLimit
>>> env = CartPoleEnv()
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
"""
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(self, action):
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
when truncated (the number of steps elapsed >= max episode steps) or
"TimeLimit.truncated"=False if the environment terminated
"""
observation, reward, done, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
done = True
return observation, reward, done, info
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
**kwargs: The kwargs to reset the environment with
Returns:
The reset environment
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)

103
tdmpc2/evaluate.py Executable file
View File

@@ -0,0 +1,103 @@
import os
os.environ['MUJOCO_GL'] = 'egl'
import warnings
warnings.filterwarnings('ignore')
import hydra
import imageio
import numpy as np
import torch
from termcolor import colored
from common.parser import parse_cfg
from common.seed import set_seed
from envs import make_env
from tdmpc2 import TDMPC2
torch.backends.cudnn.benchmark = True
@hydra.main(config_name='config', config_path='.')
def evaluate(cfg: dict):
"""
Script for evaluating a single-task / multi-task TD-MPC2 checkpoint.
Most relevant args:
`task`: task name (or mt30/mt80 for multi-task evaluation)
`model_size`: model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
`checkpoint`: path to model checkpoint to load
`eval_episodes`: number of episodes to evaluate on per task (default: 10)
`save_video`: whether to save a video of the evaluation (default: True)
`seed`: random seed (default: 1)
See config.yaml for a full list of args.
Example usage:
````
$ python evaluate.py task=mt80 model_size=48 checkpoint=/path/to/mt80-48M.pt
$ python evaluate.py task=mt30 model_size=317 checkpoint=/path/to/mt30-317M.pt
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
```
"""
assert torch.cuda.is_available()
assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold']))
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
print(colored('To evaluate a multi-task model, use task=mt80 or task=mt30.', 'red', attrs=['bold']))
# Make environment
env = make_env(cfg)
# Load agent
agent = TDMPC2(cfg)
assert os.path.exists(cfg.checkpoint), f'Checkpoint {cfg.checkpoint} not found! Must be a valid filepath.'
agent.load(cfg.checkpoint)
# Evaluate
if cfg.multitask:
print(colored(f'Evaluating agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
else:
print(colored(f'Evaluating agent on {cfg.task}:', 'yellow', attrs=['bold']))
if cfg.save_video:
video_dir = os.path.join(cfg.work_dir, 'videos')
os.makedirs(video_dir, exist_ok=True)
scores = []
tasks = cfg.tasks if cfg.multitask else [cfg.task]
for task_idx, task in enumerate(tasks):
if not cfg.multitask:
task_idx = None
ep_rewards, ep_successes = [], []
for i in range(cfg.eval_episodes):
obs, done, ep_reward, t = env.reset(task_idx=task_idx), False, 0, 0
if cfg.save_video:
frames = [env.render()]
while not done:
action = agent.act(obs, t0=t==0, task=task_idx)
obs, reward, done, info = env.step(action)
ep_reward += reward
t += 1
if cfg.save_video:
frames.append(env.render())
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if cfg.save_video:
imageio.mimsave(
os.path.join(video_dir, f'{task}-{i}.mp4'), frames, fps=15)
ep_rewards = np.mean(ep_rewards)
ep_successes = np.mean(ep_successes)
if cfg.multitask:
scores.append(ep_successes*100 if task.startswith('mw-') else ep_rewards/10)
print(colored(f' {task:<22}' \
f'\tR: {ep_rewards:.01f} ' \
f'\tS: {ep_successes:.02f}', 'yellow'))
if cfg.multitask:
print(colored(f'Normalized score: {np.mean(scores):.02f}', 'yellow', attrs=['bold']))
if __name__ == '__main__':
evaluate()

286
tdmpc2/tdmpc2.py Executable file
View File

@@ -0,0 +1,286 @@
import numpy as np
import torch
import torch.nn.functional as F
from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
class TDMPC2:
"""
TD-MPC2 agent. Implements training + inference.
Can be used for both single-task and multi-task experiments.
"""
def __init__(self, cfg):
self.cfg = cfg
self.device = torch.device('cuda')
self.model = WorldModel(cfg).to(self.device)
self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()},
{'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
], lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5)
self.model.eval()
self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
def _get_discount(self, episode_length):
"""
Returns discount factor for a given episode length.
Simple heuristic that scales discount linearly with episode length.
Default values should work well for most tasks, but can be changed as needed.
Args:
episode_length (int): Length of the episode. Assumes episodes are of fixed length.
Returns:
float: Discount factor for the task.
"""
frac = episode_length/self.cfg.discount_denom
return min(max((frac-1)/(frac), self.cfg.discount_min), self.cfg.discount_max)
def save(self, fp):
"""
Save state dict of the agent to filepath.
Args:
fp (str): Filepath to save state dict to.
"""
torch.save({"model": self.model.state_dict()}, fp)
def load(self, fp):
"""
Load a saved state dict from filepath (or dictionary) into current agent.
Args:
fp (str or dict): Filepath or state dict to load.
"""
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
self.model.load_state_dict(state_dict["model"])
@torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None):
"""
Select an action by planning in the latent space of the world model.
Args:
obs (torch.Tensor): Observation from the environment.
t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution.
task (int): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: Action to take in the environment.
"""
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
if task is not None:
task = torch.tensor([task], device=self.device)
z = self.model.encode(obs, task)
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
return a.cpu()
@torch.no_grad()
def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):
"""
Plan a sequence of actions using the learned world model.
Args:
z (torch.Tensor): Latent state from which to plan.
t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution.
task (Torch.Tensor): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: Action to take in the environment.
"""
# Sample policy trajectories
if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device)
if not t0:
mean[:-1] = self._prev_mean[1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI
for i in range(self.cfg.iterations):
# Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
.clamp(-1, 1)
if self.cfg.multitask:
actions = actions * self.model._action_masks[task]
# Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score /= score.sum(0)
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \
.clamp_(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask:
mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task]
# Select action
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
a, std = actions[0], std[0]
if not eval_mode:
a += std * torch.randn(self.cfg.action_dim, device=std.device)
return a.clamp_(-1, 1)
def update_pi(self, zs, task):
"""
Update policy using a sequence of latent states.
Args:
zs (torch.Tensor): Sequence of latent states.
task (torch.Tensor): Task index (only used for multi-task experiments).
Returns:
float: Loss of the policy update.
"""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
_, pis, log_pis, _ = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg')
self.scale.update(qs[0])
qs = self.scale(qs)
# Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step()
self.model.track_q_grad(True)
return pi_loss.item()
@torch.no_grad()
def _td_target(self, next_z, reward, task):
"""
Compute the TD-target from a reward and the observation at the following time step.
Args:
next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: TD-target.
"""
pi = self.model.pi(next_z, task)[1]
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
# Compute targets
with torch.no_grad():
next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task)
# Prepare for update
self.optim.zero_grad(set_to_none=True)
self.model.train()
# Latent rollout
zs = torch.empty(self.cfg.horizon+1, self.cfg.batch_size, self.cfg.latent_dim, device=self.device)
z = self.model.encode(obs[0], task)
zs[0] = z
consistency_loss = 0
for t in range(self.cfg.horizon):
z = self.model.next(z, action[t], task)
consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t
zs[t+1] = z
# Predictions
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
# Compute losses
reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.value_coef * value_loss
)
# Update model
total_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
self.optim.step()
# Update policy
pi_loss = self.update_pi(zs.detach(), task)
# Update target Q-functions
self.model.soft_update_target_Q()
# Return training statistics
self.model.eval()
return {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"value_loss": float(value_loss.mean().item()),
"pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()),
"grad_norm": float(grad_norm),
"pi_scale": float(self.scale.value),
}

61
tdmpc2/train.py Executable file
View File

@@ -0,0 +1,61 @@
import os
os.environ['MUJOCO_GL'] = 'egl'
import warnings
warnings.filterwarnings('ignore')
import torch
import hydra
from termcolor import colored
from common.parser import parse_cfg
from common.seed import set_seed
from common.buffer import Buffer
from envs import make_env
from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger
torch.backends.cudnn.benchmark = True
@hydra.main(config_name='config', config_path='.')
def train(cfg: dict):
"""
Script for training single-task / multi-task TD-MPC2 agents.
Most relevant args:
`task`: task name (or mt30/mt80 for multi-task training)
`model_size`: model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
`steps`: number of training/environment steps (default: 10M)
`seed`: random seed (default: 1)
See config.yaml for a full list of args.
Example usage:
```
$ python train.py task=mt80 model_size=48
$ python train.py task=mt30 model_size=317
$ python train.py task=dog-run steps=7000000
```
"""
assert torch.cuda.is_available()
assert cfg.steps > 0, 'Must train for at least 1 step.'
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls(
cfg=cfg,
env=make_env(cfg),
agent=TDMPC2(cfg),
buffer=Buffer(cfg),
logger=Logger(cfg),
)
trainer.train()
print('\nTraining completed successfully')
if __name__ == '__main__':
train()

View File

19
tdmpc2/trainer/base.py Executable file
View File

@@ -0,0 +1,19 @@
class Trainer:
"""Base trainer class for TD-MPC2."""
def __init__(self, cfg, env, agent, buffer, logger):
self.cfg = cfg
self.env = env
self.agent = agent
self.buffer = buffer
self.logger = logger
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
print('Architecture:', self.agent.model)
def eval(self):
"""Evaluate a TD-MPC2 agent."""
raise NotImplementedError
def train(self):
"""Train a TD-MPC2 agent."""
raise NotImplementedError

View File

@@ -0,0 +1,92 @@
import os
from copy import deepcopy
from time import time
from pathlib import Path
from glob import glob
import numpy as np
import torch
from tqdm import tqdm
from common.buffer import Buffer
from trainer.base import Trainer
class OfflineTrainer(Trainer):
"""Trainer class for multi-task offline TD-MPC2 training."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._start_time = time()
def eval(self):
"""Evaluate a TD-MPC2 agent."""
results = dict()
for task_idx in tqdm(range(len(self.cfg.tasks)), desc='Evaluating'):
ep_rewards, ep_successes = [], []
for _ in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
while not done:
action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx)
obs, reward, done, info = self.env.step(action)
ep_reward += reward
t += 1
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
results.update({
f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards),
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
return results
def train(self):
"""Train a TD-MPC2 agent."""
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
'Offline training only supports multitask training with mt30 or mt80 task sets.'
# Load data
assert self.cfg.task in self.cfg.data_dir, \
f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
f'please double-check your config.'
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}')
# Create buffer for sampling
_cfg = deepcopy(self.cfg)
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
_cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'):
td = torch.load(fp)
assert td.shape[1] == _cfg.episode_length, \
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
f'please double-check your config.'
for i in range(len(td)):
self.buffer.add(td[i])
assert self.buffer.num_eps == self.buffer.capacity, \
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {}
for i in range(self.cfg.steps):
# Update agent
train_metrics = self.agent.update(self.buffer)
# Evaluate agent periodically
if i % self.cfg.eval_freq == 0 or i == 10_000:
metrics = {
'iteration': i,
'total_time': time() - self._start_time,
}
metrics.update(train_metrics)
if i % self.cfg.eval_freq == 0:
metrics.update(self.eval())
self.logger.pprint_multitask(metrics, self.cfg)
if i > 0:
self.logger.save_agent(self.agent, identifier=f'{i}')
self.logger.log(metrics, 'pretrain')
self.logger.finish(self.agent)

117
tdmpc2/trainer/online_trainer.py Executable file
View File

@@ -0,0 +1,117 @@
from time import time
import numpy as np
import torch
from tensordict.tensordict import TensorDict
from trainer.base import Trainer
class OnlineTrainer(Trainer):
"""Trainer class for single-task online TD-MPC2 training."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._step = 0
self._ep_idx = 0
self._start_time = time()
def common_metrics(self):
"""Return a dictionary of current metrics."""
return dict(
step=self._step,
episode=self._ep_idx,
total_time=time() - self._start_time,
)
def eval(self):
"""Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], []
for i in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0))
while not done:
action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action)
ep_reward += reward
t += 1
if self.cfg.save_video:
self.logger.video.record(self.env)
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if self.cfg.save_video:
self.logger.video.save(self._step)
return dict(
episode_reward=np.nanmean(ep_rewards),
episode_success=np.nanmean(ep_successes),
)
def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode."""
if isinstance(obs, dict):
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
else:
obs = obs.unsqueeze(0).cpu()
if action is None:
action = torch.empty_like(self.env.rand_act())
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(dict(
obs=obs,
action=action.unsqueeze(0),
reward=reward.unsqueeze(0),
), batch_size=(1,))
return td
def train(self):
"""Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, True
while self._step <= self.cfg.steps:
# Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0:
eval_next = True
# Reset environment
if done:
if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
self.logger.log(eval_metrics, 'eval')
eval_next = False
if self._step > 0:
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset()
self._tds = [self.to_td(obs)]
# Collect experience
if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1)
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward))
# Update agent
if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps
print('Pretraining agent on seed data...')
else:
num_updates = 1
for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
self._step += 1
self.logger.finish(self.agent)