first commit
This commit is contained in:
0
tdmpc2/__init__.py
Executable file
0
tdmpc2/__init__.py
Executable file
60
tdmpc2/common/__init__.py
Normal file
60
tdmpc2/common/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Architecture presets keyed by model size, i.e. parameter count in millions.
# Only the 1M and 317M presets specify `num_q`; the other presets leave it unset.
MODEL_SIZE = {
    1: {
        'enc_dim': 256,
        'mlp_dim': 384,
        'latent_dim': 128,
        'num_enc_layers': 2,
        'num_q': 2,
    },
    5: {
        'enc_dim': 256,
        'mlp_dim': 512,
        'latent_dim': 512,
        'num_enc_layers': 2,
    },
    19: {
        'enc_dim': 1024,
        'mlp_dim': 1024,
        'latent_dim': 768,
        'num_enc_layers': 3,
    },
    48: {
        'enc_dim': 1792,
        'mlp_dim': 1792,
        'latent_dim': 768,
        'num_enc_layers': 4,
    },
    317: {
        'enc_dim': 4096,
        'mlp_dim': 4096,
        'latent_dim': 1376,
        'num_enc_layers': 5,
        'num_q': 8,
    },
}
|
||||
|
||||
# 19 original dmcontrol tasks followed by 11 custom dmcontrol tasks.
_DMCONTROL_TASKS = [
    # 19 original dmcontrol tasks
    'walker-stand', 'walker-walk', 'walker-run', 'cheetah-run', 'reacher-easy',
    'reacher-hard', 'acrobot-swingup', 'pendulum-swingup', 'cartpole-balance', 'cartpole-balance-sparse',
    'cartpole-swingup', 'cartpole-swingup-sparse', 'cup-catch', 'finger-spin', 'finger-turn-easy',
    'finger-turn-hard', 'fish-swim', 'hopper-stand', 'hopper-hop',
    # 11 custom dmcontrol tasks
    'walker-walk-backwards', 'walker-run-backwards', 'cheetah-run-backwards', 'cheetah-run-front', 'cheetah-run-back',
    'cheetah-jump', 'hopper-hop-backwards', 'reacher-three-easy', 'reacher-three-hard', 'cup-spin',
    'pendulum-spin',
]

# meta-world mt50
_METAWORLD_TASKS = [
    'mw-assembly', 'mw-basketball', 'mw-button-press-topdown', 'mw-button-press-topdown-wall', 'mw-button-press',
    'mw-button-press-wall', 'mw-coffee-button', 'mw-coffee-pull', 'mw-coffee-push', 'mw-dial-turn',
    'mw-disassemble', 'mw-door-open', 'mw-door-close', 'mw-drawer-close', 'mw-drawer-open',
    'mw-faucet-open', 'mw-faucet-close', 'mw-hammer', 'mw-handle-press-side', 'mw-handle-press',
    'mw-handle-pull-side', 'mw-handle-pull', 'mw-lever-pull', 'mw-peg-insert-side', 'mw-peg-unplug-side',
    'mw-pick-out-of-hole', 'mw-pick-place', 'mw-pick-place-wall', 'mw-plate-slide', 'mw-plate-slide-side',
    'mw-plate-slide-back', 'mw-plate-slide-back-side', 'mw-push-back', 'mw-push', 'mw-push-wall',
    'mw-reach', 'mw-reach-wall', 'mw-shelf-place', 'mw-soccer', 'mw-stick-push',
    'mw-stick-pull', 'mw-sweep-into', 'mw-sweep', 'mw-window-open', 'mw-window-close',
    'mw-bin-picking', 'mw-box-close', 'mw-door-lock', 'mw-door-unlock', 'mw-hand-insert',
]

# Multi-task benchmarks: `mt30` is the DMControl set, `mt80` appends Meta-World.
# Each entry is an independent list so callers may mutate one without
# affecting the other.
TASK_SET = {
    'mt30': list(_DMCONTROL_TASKS),
    'mt80': _DMCONTROL_TASKS + _METAWORLD_TASKS,
}
|
||||
115
tdmpc2/common/buffer.py
Normal file
115
tdmpc2/common/buffer.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||
from torchrl.data.replay_buffers.samplers import RandomSampler
|
||||
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
||||
|
||||
from common.logger import make_dir
|
||||
|
||||
|
||||
class DataPrepTransform(Transform):
    """
    Preprocesses data for TD-MPC2 training.
    Replay data is expected to be a TensorDict with the following keys:
        obs: observations
        action: actions
        reward: rewards
        task: task IDs (optional)
    A TensorDict with T time steps has T+1 observations and T actions and rewards.
    The first actions and rewards in each TensorDict are dummies and should be ignored.
    """

    def __init__(self):
        super().__init__([])

    def forward(self, td):
        """
        Swap the first two dimensions of `td` and unpack it into a tuple
        `(obs, action, reward, task)`.

        The leading (dummy) action and reward are dropped, the reward gains a
        trailing singleton dimension, and `task` is the first entry of
        `td['task']` or `None` when that key is absent.
        """
        td = td.permute(1, 0)
        task = td['task'][0] if 'task' in td.keys() else None
        return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), task
|
||||
|
||||
|
||||
class Buffer():
    """
    Create a replay buffer for TD-MPC2 training.
    Uses CUDA memory if available, and CPU memory otherwise.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Sampled batches are moved to this device; fall back to the CPU on
        # hosts without CUDA instead of failing on the first transfer.
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Capacity is counted in episodes, not in environment steps.
        self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
        self._num_eps = 0

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes in the buffer."""
        return self._num_eps

    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.
        Uses the RandomSampler to sample trajectories,
        and the RandomCropTensorDict transform to crop trajectories to the desired length.
        DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=RandomSampler(),
            # Pinned host memory is only requested when a CUDA device exists.
            pin_memory=torch.cuda.is_available(),
            prefetch=1,
            transform=Compose(
                RandomCropTensorDict(self.cfg.horizon+1, -1),
                DataPrepTransform(),
            ),
            batch_size=self.cfg.batch_size,
        )

    @staticmethod
    def _nbytes(v):
        """Return the byte size of a tensor, or of the direct values of a nested TensorDict."""
        if isinstance(v, TensorDict):
            return sum(x.numel()*x.element_size() for x in v.values())
        return v.numel()*v.element_size()

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        bytes_per_ep = sum(self._nbytes(v) for v in tds.values())
        print(f'Bytes per episode: {bytes_per_ep:,}')
        total_bytes = bytes_per_ep*self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: decide whether to use CUDA or CPU memory
        cuda_available = torch.cuda.is_available()
        mem_free = torch.cuda.mem_get_info()[0] if cuda_available else 0
        if not cuda_available or 2.5*total_bytes > mem_free:  # Insufficient CUDA memory
            print('Using CPU memory for storage.')
            return self._reserve_buffer(
                LazyTensorStorage(self._capacity, device=torch.device('cpu'))
            )
        # Sufficient CUDA memory
        print('Using CUDA memory for storage.')
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=torch.device('cuda'))
        )

    def add(self, tds):
        """Add an episode to the buffer. All episodes are expected to have the same length."""
        if self._num_eps == 0:
            # The underlying buffer is created lazily from the first episode.
            self._buffer = self._init(tds)
        self._buffer.add(tds)
        self._num_eps += 1
        return self._num_eps

    def sample(self):
        """Sample a batch of sub-trajectories from the buffer."""
        obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
        if task is not None:
            task = task.to(self._device, non_blocking=True)
        return obs.to(self._device, non_blocking=True), \
            action.to(self._device, non_blocking=True), \
            reward.to(self._device, non_blocking=True), \
            task

    def save(self):
        """Save the buffer to disk. Useful for storing offline datasets."""
        # NOTE(review): reaches into private attributes of the replay buffer's
        # storage; confirm this still holds when upgrading torchrl.
        td = self._buffer._storage._storage.cpu()
        fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
        torch.save(td, fp)
|
||||
22
tdmpc2/common/init.py
Normal file
22
tdmpc2/common/init.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def weight_init(m):
    """
    Custom weight initialization for TD-MPC2.

    - `nn.Linear`: truncated-normal weights (std 0.02), zero bias.
    - `nn.Embedding`: weights uniform in [-0.02, 0.02].
    - `nn.ParameterList`: every 3-dimensional entry is treated as a stacked
      linear weight and the entry right after it as its bias.
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        nn.init.uniform_(m.weight, -0.02, 0.02)
    elif isinstance(m, nn.ParameterList):
        for i, p in enumerate(m):
            if p.dim() == 3:  # Linear
                nn.init.trunc_normal_(p, std=0.02)  # Weight
                # Assumes a bias parameter directly follows every weight.
                nn.init.constant_(m[i+1], 0)  # Bias
|
||||
|
||||
|
||||
def zero_(params):
    """Initialize parameters to zero (in place, through each tensor's `.data`)."""
    for p in params:
        p.data.fill_(0)
|
||||
97
tdmpc2/common/layers.py
Normal file
97
tdmpc2/common/layers.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functorch import combine_state_for_ensemble
|
||||
|
||||
|
||||
class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.

    The member modules are combined with `combine_state_for_ensemble` and
    evaluated through `torch.vmap`; their stacked parameters are stored in
    `self.params`.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        # Map over the leading (ensemble) dimension of params and buffers and
        # share the input across members.
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        # Keep a printable description of the original (un-stacked) modules.
        self._repr = str(modules)

    def modules(self):
        # NOTE(review): overrides `nn.Module.modules` and relies on functorch
        # internals (`__wrapped__.stateless_model`); confirm against the
        # installed functorch version.
        return self.vmap.__wrapped__.stateless_model

    def forward(self, *args, **kwargs):
        # The empty tuple is passed in the buffers position.
        return self.vmap(list(self.params), (), *args, **kwargs)

    def __repr__(self):
        return 'Vectorized ' + self._repr
|
||||
|
||||
|
||||
class SimNorm(nn.Module):
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.

    The last dimension is split into groups of `cfg.simnorm_dim` entries and a
    softmax is applied within each group.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dim = cfg.simnorm_dim

    def forward(self, x):
        shp = x.shape
        # Group the trailing dimension, normalise each group, restore the shape.
        grouped = x.view(*shp[:-1], -1, self.dim)
        grouped = F.softmax(grouped, dim=-1)
        return grouped.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"
|
||||
|
||||
|
||||
class NormedLinear(nn.Linear):
    """
    Linear layer with LayerNorm, activation, and optionally dropout.

    Positional and remaining keyword arguments are forwarded to `nn.Linear`.
    Dropout, when enabled, is applied to the linear output before LayerNorm.
    """

    # NOTE(review): the default activation instance is created once at class
    # definition time and therefore shared by all layers that use the default.
    def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout is not None else ""
        return f"NormedLinear(in_features={self.in_features}, "\
            f"out_features={self.out_features}, "\
            f"bias={self.bias is not None}{repr_dropout}, "\
            f"act={self.act.__class__.__name__})"
|
||||
|
||||
|
||||
def enc(cfg, out=None):
    """
    Returns a dictionary of encoders for each observation in the dict.

    Args:
        cfg: config providing `obs_shape`, `task_dim`, `num_enc_layers`,
            `enc_dim` and `latent_dim` (plus whatever `SimNorm` reads).
        out: optional dict to add the encoders to; a fresh dict is used when
            omitted, so repeated calls do not share state.

    Returns:
        An `nn.ModuleDict` built from `out`.
    """
    if out is None:
        out = {}
    for k in cfg.obs_shape.keys():
        # Only state observations are supported.
        assert k == 'state'
        out[k] = mlp(
            cfg.obs_shape[k][0] + cfg.task_dim,
            max(cfg.num_enc_layers-1, 1)*[cfg.enc_dim],
            cfg.latent_dim,
            act=SimNorm(cfg),
        )
    return nn.ModuleDict(out)
|
||||
|
||||
|
||||
def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.):
    """
    Basic building block of TD-MPC2.
    MLP with LayerNorm, Mish activations, and optionally dropout.

    Args:
        in_dim: input feature size.
        mlp_dims: hidden size(s); a single int is treated as one hidden layer.
        out_dim: output feature size.
        act: optional activation for the output layer. When given, the output
            layer is a `NormedLinear` using it; otherwise a plain `nn.Linear`.
        dropout: dropout probability, applied to the first hidden layer only.

    Returns:
        An `nn.Sequential` of the layers.
    """
    if isinstance(mlp_dims, int):
        mlp_dims = [mlp_dims]
    dims = [in_dim] + mlp_dims + [out_dim]
    layers = nn.ModuleList()
    for i in range(len(dims) - 2):
        # `dropout*(i==0)` keeps dropout on the first layer and zeroes it elsewhere.
        layers.append(NormedLinear(dims[i], dims[i+1], dropout=dropout*(i==0)))
    layers.append(NormedLinear(dims[-2], dims[-1], act=act) if act else nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*layers)
|
||||
238
tdmpc2/common/logger.py
Executable file
238
tdmpc2/common/logger.py
Executable file
@@ -0,0 +1,238 @@
|
||||
import os
|
||||
import datetime
|
||||
import re
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from termcolor import colored
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from common import TASK_SET
|
||||
|
||||
|
||||
# Console columns as (metric key, display label, format type) triples.
CONSOLE_FORMAT = [
    ("iteration", "I", "int"),
    ("episode", "E", "int"),
    ("step", "I", "int"),
    ("episode_reward", "R", "float"),
    ("episode_success", "S", "float"),
    ("total_time", "T", "time"),
]

# Terminal color used for each logging category.
CAT_TO_COLOR = {
    "pretrain": "yellow",
    "train": "blue",
    "eval": "green",
}
|
||||
|
||||
|
||||
def make_dir(dir_path):
    """
    Create directory if it does not already exist.

    An existing directory is left untouched. Other failures (for example a
    permission error, or a regular file already occupying the path) are
    raised as `OSError` instead of being silently ignored.

    Returns:
        `dir_path`, unchanged, so calls can be chained.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
|
||||
|
||||
|
||||
def print_run(cfg):
    """
    Pretty-printing of current run information.
    Logger calls this method at initialization.

    Reads `task_title`, `steps`, `obs_shape`, `action_dim` and `exp_name`
    from `cfg` and prints them between two divider lines.
    """
    prefix, color, attrs = " ", "green", ["bold"]

    def _limstr(s, maxlen=36):
        # Convert before slicing so long non-string values are truncated
        # rather than raising; short values are returned unchanged.
        text = str(s)
        return text[:maxlen] + "..." if len(text) > maxlen else s

    def _pprint(k, v):
        print(
            prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
        )

    obs_dim = cfg.obs_shape['state'][0] if 'state' in cfg.obs_shape else cfg.obs_shape[0]
    kvs = [
        ("task", cfg.task_title),
        ("steps", f"{int(cfg.steps):,}"),
        ("observations", obs_dim),
        ("actions", cfg.action_dim),
        ("experiment", cfg.exp_name),
    ]
    # Divider width follows the longest (possibly truncated) value.
    w = max(len(_limstr(str(kv[1]))) for kv in kvs) + 25
    div = "-" * w
    print(div)
    for k, v in kvs:
        _pprint(k, v)
    print(div)
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
    """
    Return a wandb-safe group name for logging.
    Optionally returns group name as list.

    The name is built from `cfg.task` and `cfg.exp_name`, where every run of
    non-alphanumeric characters in the experiment name is replaced by "-".
    """
    lst = [cfg.task, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
    return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
class VideoRecorder:
    """Utility class for logging evaluation videos."""

    def __init__(self, cfg, wandb, fps=15):
        self.cfg = cfg
        self._save_dir = make_dir(cfg.work_dir / 'eval_video')
        self._wandb = wandb
        self.fps = fps
        self.frames = []
        self.enabled = False

    def init(self, env, enabled=True):
        """Reset the frame list, decide whether to record, and capture a first frame."""
        self.frames = []
        # Recording needs a save directory and a wandb handle.
        self.enabled = self._save_dir and self._wandb and enabled
        self.record(env)

    def record(self, env):
        """Append `env.render()` to the frame list when recording is enabled."""
        if self.enabled:
            self.frames.append(env.render())

    def save(self, step, key='videos/eval_video'):
        """Log the recorded frames to wandb as an mp4 video under `key`."""
        if self.enabled and len(self.frames) > 0:
            frames = np.stack(self.frames)
            # The stacked frames are transposed with axes (0, 3, 1, 2) before
            # being handed to wandb.Video.
            return self._wandb.log(
                {key: self._wandb.Video(frames.transpose(0, 3, 1, 2), fps=self.fps, format='mp4')}, step=step
            )
|
||||
|
||||
|
||||
class Logger:
    """Primary logging object. Logs either locally or using wandb."""

    def __init__(self, cfg):
        self._log_dir = make_dir(cfg.work_dir)
        self._model_dir = make_dir(self._log_dir / "models")
        self._save_csv = cfg.save_csv
        self._save_agent = cfg.save_agent
        self._group = cfg_to_group(cfg)
        self._seed = cfg.seed
        self._eval = []
        print_run(cfg)
        self.project = cfg.get("wandb_project", "none")
        self.entity = cfg.get("wandb_entity", "none")
        if cfg.disable_wandb or self.project == "none" or self.entity == "none":
            print(colored("Wandb disabled.", "blue", attrs=["bold"]))
            # NOTE(review): `self._save_agent` was copied from `cfg` above, so
            # resetting `cfg.save_agent` here does not change it; confirm
            # whether local checkpointing is meant to stay enabled.
            cfg.save_agent = False
            cfg.save_video = False
            self._wandb = None
            self._video = None
            return
        os.environ["WANDB_SILENT"] = "true" if cfg.wandb_silent else "false"
        # Imported lazily so wandb is only required when it is actually used.
        import wandb

        wandb.init(
            project=self.project,
            entity=self.entity,
            name=str(cfg.seed),
            group=self._group,
            tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
            dir=self._log_dir,
            config=OmegaConf.to_container(cfg, resolve=True),
        )
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        self._wandb = wandb
        self._video = (
            VideoRecorder(cfg, self._wandb)
            if self._wandb and cfg.save_video
            else None
        )

    @property
    def video(self):
        return self._video

    @property
    def model_dir(self):
        return self._model_dir

    def save_agent(self, agent=None, identifier='final'):
        """Save `agent` under the model directory and, with wandb, upload it as an artifact."""
        if self._save_agent and agent:
            fp = self._model_dir / f'{str(identifier)}.pt'
            agent.save(fp)
            if self._wandb:
                artifact = self._wandb.Artifact(
                    self._group + '-' + str(self._seed) + '-' + str(identifier),
                    type='model',
                )
                artifact.add_file(fp)
                self._wandb.log_artifact(artifact)

    def finish(self, agent=None):
        """Save the final agent (best effort) and close the wandb run."""
        try:
            self.save_agent(agent)
        except Exception as e:
            # A failed checkpoint must not prevent the wandb run from closing.
            print(colored(f"Failed to save model: {e}", "red"))
        if self._wandb:
            self._wandb.finish()

    def _format(self, key, value, ty):
        """
        Format one console entry.

        Raises:
            ValueError: if `ty` is not one of "int", "float" or "time".
        """
        if ty == "int":
            return f'{colored(key+":", "blue")} {int(value):,}'
        elif ty == "float":
            return f'{colored(key+":", "blue")} {value:.01f}'
        elif ty == "time":
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{colored(key+":", "blue")} {value}'
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError(f"invalid log format type: {ty}")

    def _print(self, d, category):
        """Print the metrics of `d` that appear in CONSOLE_FORMAT on one line."""
        category = colored(category, CAT_TO_COLOR[category])
        pieces = [f" {category:<14}"]
        for k, disp_k, ty in CONSOLE_FORMAT:
            if k in d:
                pieces.append(f"{self._format(disp_k, d[k], ty):<22}")
        print(" ".join(pieces))

    def pprint_multitask(self, d, cfg):
        """Pretty-print evaluation metrics for multi-task training."""
        print(colored(f'Evaluated agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
        dmcontrol_reward = []
        metaworld_reward = []
        metaworld_success = []
        for k, v in d.items():
            # Per-task metrics are keyed as "<metric>+<task>".
            if '+' not in k:
                continue
            task = k.split('+')[1]
            if task in TASK_SET['mt30'] and k.startswith('episode_reward'):  # DMControl
                dmcontrol_reward.append(v)
                print(colored(f' {task:<22}\tR: {v:.01f}', 'yellow'))
            elif task in TASK_SET['mt80'] and task not in TASK_SET['mt30']:  # Meta-World
                if k.startswith('episode_reward'):
                    metaworld_reward.append(v)
                elif k.startswith('episode_success'):
                    metaworld_success.append(v)
                    print(colored(f' {task:<22}\tS: {v:.02f}', 'yellow'))
        # Averages are written back into `d` after the iteration has finished.
        dmcontrol_reward = np.nanmean(dmcontrol_reward)
        d['episode_reward+avg_dmcontrol'] = dmcontrol_reward
        print(colored(f' {"dmcontrol":<22}\tR: {dmcontrol_reward:.01f}', 'yellow', attrs=['bold']))
        if cfg.task == 'mt80':
            metaworld_reward = np.nanmean(metaworld_reward)
            metaworld_success = np.nanmean(metaworld_success)
            d['episode_reward+avg_metaworld'] = metaworld_reward
            d['episode_success+avg_metaworld'] = metaworld_success
            print(colored(f' {"metaworld":<22}\tR: {metaworld_reward:.01f}', 'yellow', attrs=['bold']))
            print(colored(f' {"metaworld":<22}\tS: {metaworld_success:.02f}', 'yellow', attrs=['bold']))

    def log(self, d, category="train"):
        """Log metrics `d` to wandb (if enabled), to eval.csv (eval only), and to the console."""
        assert category in CAT_TO_COLOR.keys(), f"invalid category: {category}"
        if self._wandb:
            # The x-axis is the iteration for pretraining and the step otherwise.
            if category in {"train", "eval"}:
                xkey = "step"
            elif category == "pretrain":
                xkey = "iteration"
            for k, v in d.items():
                self._wandb.log({category + "/" + k: v}, step=d[xkey])
        if category == "eval" and self._save_csv:
            keys = ["step", "episode_reward"]
            self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
            # The whole evaluation history is rewritten on every call.
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.csv", header=keys, index=None
            )
        self._print(d, category)
|
||||
95
tdmpc2/common/math.py
Normal file
95
tdmpc2/common/math.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def soft_ce(pred, target, cfg):
    """
    Computes the cross entropy loss between predictions and soft targets.

    `pred` is passed through a log-softmax over its last dimension and
    `target` is encoded with `two_hot` before the two are combined. The last
    dimension of the result is kept with size 1.
    """
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)
|
||||
|
||||
|
||||
@torch.jit.script
def log_std(x, low, dif):
    # Maps x through tanh into the interval [low, low + dif].
    return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_residual(eps, log_std):
|
||||
return -0.5 * eps.pow(2) - log_std
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_logprob(residual):
|
||||
return residual - 0.5 * torch.log(2 * torch.pi)
|
||||
|
||||
|
||||
def gaussian_logprob(eps, log_std, size=None):
    """
    Compute Gaussian log probability.

    Args:
        eps: noise sample.
        log_std: log standard deviation, combined element-wise with `eps`.
        size: multiplier applied to the result; defaults to the size of the
            last dimension of `eps`.

    Returns:
        A tensor whose last dimension is kept with size 1.
    """
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    # NOTE(review): the value returned by `_gaussian_logprob` is multiplied by
    # `size` as a whole; verify this matches the intended log-density.
    return _gaussian_logprob(residual) * size
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _squash(pi):
|
||||
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
    """
    Apply squashing function.

    `mu` and `pi` are passed through tanh; `log_pi` is corrected by the
    summed `_squash` term of the squashed `pi`. Note that `log_pi` is updated
    in place.

    Returns:
        The tuple `(mu, pi, log_pi)`.
    """
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi
|
||||
|
||||
|
||||
@torch.jit.script
def symlog(x):
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))
|
||||
|
||||
|
||||
@torch.jit.script
def symexp(x):
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
|
||||
|
||||
|
||||
def two_hot(x, cfg):
    """
    Converts a batch of scalars to soft two-hot encoded targets for discrete regression.

    With `cfg.num_bins == 0` the input is returned unchanged and with
    `cfg.num_bins == 1` only `symlog` is applied. Otherwise the symlog-scaled
    value is clamped to [`cfg.vmin`, `cfg.vmax`] and its weight is split
    between the bin it falls in and the following bin.
    """
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
    # Fractional bin position of every value.
    position = (x - cfg.vmin) / cfg.bin_size
    bin_idx = torch.floor(position).long()
    bin_offset = (position - bin_idx.float()).unsqueeze(-1)
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
    # The modulo wraps the upper neighbour of the last bin back to index 0.
    soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot
|
||||
|
||||
|
||||
# Cached bin centres used by `two_hot_inv`; built on first use.
DREG_BINS = None


def two_hot_inv(x, cfg):
    """
    Converts a batch of soft two-hot encoded vectors to scalars.

    With `cfg.num_bins == 0` the input is returned unchanged and with
    `cfg.num_bins == 1` only `symexp` is applied. Otherwise the softmax of
    `x` over the last dimension weights the bin centres and the result is
    passed through `symexp`.
    """
    global DREG_BINS
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symexp(x)
    # Rebuild the cache when it does not fit the current device or bin count;
    # a cache from an earlier call would otherwise be reused blindly.
    stale = (
        DREG_BINS is None
        or DREG_BINS.device != x.device
        or DREG_BINS.numel() != cfg.num_bins
    )
    if stale:
        DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
    x = F.softmax(x, dim=-1)
    x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
    return symexp(x)
|
||||
60
tdmpc2/common/parser.py
Executable file
60
tdmpc2/common/parser.py
Executable file
@@ -0,0 +1,60 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from common import MODEL_SIZE, TASK_SET
|
||||
|
||||
|
||||
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
    """
    Parses a Hydra config. Mostly for convenience.

    Evaluates simple "<int><op><int>" string values, derives `work_dir`,
    `task_title` and `bin_size`, copies the selected `MODEL_SIZE` preset into
    the config and fills in the multi-task fields (`multitask`, `task_dim`,
    `tasks`).

    Returns:
        The same `cfg` object, modified in place.
    """
    import operator

    # Logic
    # NOTE(review): this pass only rebinds a local and never writes back to
    # `cfg`, so it currently has no effect; it presumably was meant to turn
    # `None` flags into `True` -- confirm intent before changing it.
    for k in cfg.keys():
        try:
            v = cfg[k]
            if v is None:
                v = True
        except Exception:
            pass

    # Algebraic expressions
    ops = {
        '+': operator.add,
        '-': operator.sub,
        '*': operator.mul,
        '/': operator.truediv,
    }
    for k in cfg.keys():
        try:
            v = cfg[k]
            if isinstance(v, str):
                # The whole string must be an expression; a prefix match would
                # also rewrite values such as "2023-10-01".
                match = re.fullmatch(r"(\d+)([+\-*/])(\d+)", v)
                if match:
                    lhs, op, rhs = match.groups()
                    cfg[k] = ops[op](int(lhs), int(rhs))
                    if isinstance(cfg[k], float) and cfg[k].is_integer():
                        cfg[k] = int(cfg[k])
        except Exception:
            # Best effort: unreadable keys and failed evaluations (e.g.
            # division by zero) leave the value untouched.
            pass

    # Convenience
    cfg.work_dir = Path(hydra.utils.get_original_cwd()) / 'logs' / cfg.task / str(cfg.seed) / cfg.exp_name
    cfg.task_title = cfg.task.replace("-", " ").title()
    # Bin size for discrete regression; the divisor is floored at 1 so that
    # `num_bins == 1` does not divide by zero.
    cfg.bin_size = (cfg.vmax - cfg.vmin) / max(cfg.num_bins-1, 1)

    # Model size
    assert cfg.model_size in MODEL_SIZE.keys(), \
        f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}'
    for k, v in MODEL_SIZE[cfg.model_size].items():
        cfg[k] = v
    if cfg.task == 'mt30' and cfg.model_size == 19:
        cfg.latent_dim = 512  # This checkpoint is slightly smaller

    # Multi-task
    cfg.multitask = cfg.task in TASK_SET.keys()
    if cfg.multitask:
        cfg.task_title = cfg.task.upper()
        # Account for slight inconsistency in task_dim for the mt30 experiments
        cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64
    else:
        cfg.task_dim = 0
    cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])

    return cfg
|
||||
48
tdmpc2/common/scale.py
Normal file
48
tdmpc2/common/scale.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
|
||||
|
||||
class RunningScale:
    """
    Running trimmed scale estimator.

    Tracks the spread between the 5th and 95th percentile of the values it is
    updated with, floored at 1 and smoothed with factor `cfg.tau`.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Use CUDA when present; fall back to the CPU so construction does not
        # fail on hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._value = torch.ones(1, dtype=torch.float32, device=device)
        self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=device)

    def state_dict(self):
        return dict(value=self._value, percentiles=self._percentiles)

    def load_state_dict(self, state_dict):
        self._value.data.copy_(state_dict['value'])
        self._percentiles.data.copy_(state_dict['percentiles'])

    @property
    def value(self):
        """Current scale as a Python float."""
        return self._value.cpu().item()

    def _percentile(self, x):
        """Linearly interpolated percentiles of `x` along its first dimension."""
        x_dtype, x_shape = x.dtype, x.shape
        x = x.view(x.shape[0], -1)
        in_sorted, _ = torch.sort(x, dim=0)
        # Fractional row positions of the requested percentiles.
        positions = self._percentiles * (x.shape[0]-1) / 100
        floored = torch.floor(positions)
        ceiled = floored + 1
        # Keep the upper neighbour inside the valid index range.
        ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
        weight_ceiled = positions-floored
        weight_floored = 1.0 - weight_ceiled
        d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
        d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
        return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype)

    def update(self, x):
        """Move the running scale towards the current percentile spread of `x`."""
        percentiles = self._percentile(x.detach())
        value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
        self._value.data.lerp_(value, self.cfg.tau)

    def __call__(self, x, update=False):
        """Divide `x` by the current scale, optionally updating the scale first."""
        if update:
            self.update(x)
        return x * (1/self.value)

    def __repr__(self):
        return f'RunningScale(S: {self.value})'
|
||||
12
tdmpc2/common/seed.py
Normal file
12
tdmpc2/common/seed.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def set_seed(seed):
    """Set seed for reproducibility (Python `random`, NumPy, and PyTorch on CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
174
tdmpc2/common/world_model.py
Normal file
174
tdmpc2/common/world_model.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from common import layers, math, init
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
"""
|
||||
TD-MPC2 implicit world model architecture.
|
||||
Can be used for both single-task and multi-task experiments.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
if cfg.multitask:
|
||||
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
||||
for i in range(len(cfg.tasks)):
|
||||
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
|
||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
||||
self.log_std_min = torch.tensor(cfg.log_std_min)
|
||||
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
|
||||
|
||||
@property
|
||||
def total_params(self):
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Overriding `to` method to also move additional tensors to device.
|
||||
"""
|
||||
super().to(*args, **kwargs)
|
||||
if self.cfg.multitask:
|
||||
self._action_masks = self._action_masks.to(*args, **kwargs)
|
||||
self.log_std_min = self.log_std_min.to(*args, **kwargs)
|
||||
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def train(self, mode=True):
|
||||
"""
|
||||
Overriding `train` method to keep target Q-networks in eval mode.
|
||||
"""
|
||||
super().train(mode)
|
||||
self._target_Qs.train(False)
|
||||
return self
|
||||
|
||||
def track_q_grad(self, mode=True):
|
||||
"""
|
||||
Enables/disables gradient tracking of Q-networks.
|
||||
Avoids unnecessary computation during policy optimization.
|
||||
This method also enables/disables gradients for task embeddings,
|
||||
and sets the dropout probability to 0 if `mode` is False.
|
||||
"""
|
||||
for p in self._Qs.parameters():
|
||||
p.requires_grad_(mode)
|
||||
if self.cfg.multitask:
|
||||
for p in self._task_emb.parameters():
|
||||
p.requires_grad_(mode)
|
||||
for m in self._Qs.modules():
|
||||
if isinstance(m, nn.Dropout):
|
||||
m.p = self.cfg.dropout if mode else 0
|
||||
|
||||
def soft_update_target_Q(self):
|
||||
"""
|
||||
Soft-update target Q-networks using Polyak averaging.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
|
||||
p_target.data.lerp_(p.data, self.cfg.tau)
|
||||
|
||||
def task_emb(self, x, task):
|
||||
"""
|
||||
Continuous task embedding for multi-task experiments.
|
||||
Retrieves the task embedding for a given task ID `task`
|
||||
and concatenates it to the input `x`.
|
||||
"""
|
||||
if isinstance(task, int):
|
||||
task = torch.tensor([task], device=x.device)
|
||||
emb = self._task_emb(task.long())
|
||||
if x.ndim == 3:
|
||||
emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
||||
elif emb.shape[0] == 1:
|
||||
emb = emb.repeat(x.shape[0], 1)
|
||||
return torch.cat([x, emb], dim=-1)
|
||||
|
||||
def encode(self, obs, task):
|
||||
"""
|
||||
Encodes an observation into its latent representation.
|
||||
This implementation assumes a single state-based observation.
|
||||
"""
|
||||
if self.cfg.multitask:
|
||||
obs = self.task_emb(obs, task)
|
||||
return self._encoder['state'](obs)
|
||||
|
||||
def next(self, z, a, task):
    """
    Predict the next latent state from the current latent state `z` and
    action `a` with the dynamics network. In the multi-task setting the
    task embedding is appended to `z` before the action is concatenated.
    """
    latent = self.task_emb(z, task) if self.cfg.multitask else z
    dynamics_input = torch.cat([latent, a], dim=-1)
    return self._dynamics(dynamics_input)
|
||||
|
||||
def reward(self, z, a, task):
    """
    Predict the instantaneous (single-step) reward for latent state `z`
    and action `a` with the reward network. In the multi-task setting the
    task embedding is appended to `z` before the action is concatenated.
    """
    latent = self.task_emb(z, task) if self.cfg.multitask else z
    reward_input = torch.cat([latent, a], dim=-1)
    return self._reward(reward_input)
|
||||
|
||||
def pi(self, z, task):
    """
    Draw an action from the policy prior for latent state `z`.
    The prior is a Gaussian whose mean and (log) std are the two halves of
    the policy network output. In the multi-task setting, action dimensions
    not used by `task` are masked out.
    Returns the squashed mean, the squashed sampled action, its
    log-probability, and the log std.
    """
    if self.cfg.multitask:
        z = self.task_emb(z, task)

    # Gaussian policy prior
    mean, raw_log_std = self._pi(z).chunk(2, dim=-1)
    log_std = math.log_std(raw_log_std, self.log_std_min, self.log_std_dif)
    noise = torch.randn_like(mean)

    action_dims = None  # no masking in the single-task setting
    if self.cfg.multitask:
        # Mask out unused action dimensions
        mask = self._action_masks[task]
        mean = mean * mask
        log_std = log_std * mask
        noise = noise * mask
        action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)

    log_pi = math.gaussian_logprob(noise, log_std, size=action_dims)
    sample = mean + noise * log_std.exp()
    mean, sample, log_pi = math.squash(mean, sample, log_pi)

    return mean, sample, log_pi, log_std
|
||||
|
||||
def Q(self, z, a, task, return_type='min', target=False):
    """
    Predict the state-action value of latent state `z` and action `a`.
    `return_type` selects what is returned:
        - `min`: the minimum of two randomly subsampled Q-values.
        - `avg`: the average of two randomly subsampled Q-values.
        - `all`: the raw output of the Q-network ensemble.
    When `target` is True the target Q-networks are used instead of the
    online ones.
    """
    assert return_type in {'min', 'avg', 'all'}

    if self.cfg.multitask:
        z = self.task_emb(z, task)

    q_net = self._target_Qs if target else self._Qs
    out = q_net(torch.cat([z, a], dim=-1))
    if return_type == 'all':
        return out

    # Pick two distinct ensemble members at random.
    picked = np.random.choice(self.cfg.num_q, 2, replace=False)
    first, second = out[picked]
    first = math.two_hot_inv(first, self.cfg)
    second = math.two_hot_inv(second, self.cfg)
    if return_type == 'min':
        return torch.min(first, second)
    return (first + second) / 2
|
||||
86
tdmpc2/config.yaml
Executable file
86
tdmpc2/config.yaml
Executable file
@@ -0,0 +1,86 @@
|
||||
defaults:
|
||||
- override hydra/launcher: submitit_local
|
||||
|
||||
# environment
|
||||
task: dog-run
|
||||
|
||||
# evaluation
|
||||
checkpoint: ???
|
||||
eval_episodes: 10
|
||||
eval_freq: 50000
|
||||
|
||||
# training
|
||||
steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
lr: 3e-4
|
||||
enc_lr_scale: 0.3
|
||||
grad_clip_norm: 20
|
||||
tau: 0.01
|
||||
discount_denom: 5
|
||||
discount_min: 0.95
|
||||
discount_max: 0.995
|
||||
buffer_size: 1_000_000
|
||||
exp_name: default
|
||||
data_dir: ???
|
||||
|
||||
# planning
|
||||
mpc: true
|
||||
iterations: 6
|
||||
num_samples: 512
|
||||
num_elites: 64
|
||||
num_pi_trajs: 24
|
||||
horizon: 3
|
||||
min_std: 0.05
|
||||
max_std: 2
|
||||
temperature: 0.5
|
||||
|
||||
# actor
|
||||
log_std_min: -10
|
||||
log_std_max: 2
|
||||
entropy_coef: 1e-4
|
||||
|
||||
# critic
|
||||
num_bins: 101
|
||||
vmin: -10
|
||||
vmax: +10
|
||||
|
||||
# architecture
|
||||
model_size: 5
|
||||
num_enc_layers: 2
|
||||
enc_dim: 256
|
||||
mlp_dim: 512
|
||||
latent_dim: 512
|
||||
task_dim: 96
|
||||
num_q: 5
|
||||
dropout: 0.01
|
||||
simnorm_dim: 8
|
||||
|
||||
# logging
|
||||
wandb_project: ???
|
||||
wandb_entity: ???
|
||||
wandb_silent: false
|
||||
disable_wandb: true
|
||||
save_csv: true
|
||||
|
||||
# misc
|
||||
save_video: true
|
||||
save_agent: true
|
||||
seed: 1
|
||||
|
||||
# convenience
|
||||
work_dir: ???
|
||||
task_title: ???
|
||||
multitask: ???
|
||||
tasks: ???
|
||||
obs_shape: ???
|
||||
action_dim: ???
|
||||
episode_length: ???
|
||||
obs_shapes: ???
|
||||
action_dims: ???
|
||||
episode_lengths: ???
|
||||
seed_steps: ???
|
||||
bin_size: ???
|
||||
62
tdmpc2/envs/__init__.py
Normal file
62
tdmpc2/envs/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from copy import deepcopy
|
||||
import warnings
|
||||
|
||||
import gym
|
||||
|
||||
from envs.wrappers.multitask import MultitaskWrapper
|
||||
from envs.wrappers.tensor import TensorWrapper
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
from envs.maniskill import make_env as make_maniskill_env
|
||||
from envs.metaworld import make_env as make_metaworld_env
|
||||
from envs.myosuite import make_env as make_myosuite_env
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
|
||||
|
||||
def make_multitask_env(cfg):
    """
    Build one single-task environment per entry of `cfg.tasks`, combine
    them in a `MultitaskWrapper`, and record the wrapper's per-task
    observation dims, action dims and episode lengths on `cfg`.
    """
    print('Creating multi-task environment with tasks:', cfg.tasks)
    task_envs = []
    for task_name in cfg.tasks:
        # Each sub-environment is created from its own single-task config copy.
        task_cfg = deepcopy(cfg)
        task_cfg.task = task_name
        task_cfg.multitask = False
        task_env = make_env(task_cfg)
        if task_env is None:
            raise UnknownTaskError(task_name)
        task_envs.append(task_env)
    env = MultitaskWrapper(cfg, task_envs)
    cfg.obs_shapes = env._obs_dims
    cfg.action_dims = env._action_dims
    cfg.episode_lengths = env._episode_lengths
    return env
|
||||
|
||||
|
||||
def make_env(cfg):
    """
    Make an environment for TD-MPC2 experiments.

    For a multi-task config the environment comes from `make_multitask_env`.
    Otherwise each suite factory is tried in turn and the first one that
    recognises `cfg.task` is used; `UnknownTaskError` is raised when none does.
    The observation shape(s), action dimension, episode length and number
    of seed steps are then written back onto `cfg`.
    """
    gym.logger.set_level(40)
    if cfg.multitask:
        env = make_multitask_env(cfg)
    else:
        env = None
        for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
            try:
                env = fn(cfg)
            except UnknownTaskError:
                continue
            if env is not None:
                # Stop at the first suite that knows the task.
                break
        if env is None:
            raise UnknownTaskError(cfg.task)
        # NOTE(review): wrapping is applied on the single-task path only, on the
        # assumption that multi-task sub-environments were already wrapped by the
        # recursive `make_env` calls -- confirm against the original indentation.
        env = TensorWrapper(env)
    try:  # Dict
        cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
    except AttributeError:  # Box (no `.spaces` mapping)
        cfg.obs_shape = {'state': env.observation_space.shape}
    cfg.action_dim = env.action_space.shape[0]
    cfg.episode_length = env.max_episode_steps
    cfg.seed_steps = max(1000, 5*cfg.episode_length)
    return env
|
||||
200
tdmpc2/envs/dmcontrol.py
Normal file
200
tdmpc2/envs/dmcontrol.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from collections import deque, defaultdict
|
||||
from typing import Any, NamedTuple
|
||||
import dm_env
|
||||
import numpy as np
|
||||
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
||||
from dm_control import suite
|
||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||
from dm_control.suite.wrappers import action_scale
|
||||
from dm_env import StepType, specs
|
||||
from envs.exceptions import UnknownTaskError
|
||||
import gym
|
||||
|
||||
|
||||
class ExtendedTimeStep(NamedTuple):
    """Time step record that also carries the action alongside the usual fields."""

    step_type: Any
    reward: Any
    discount: Any
    observation: Any
    action: Any

    def first(self):
        """Return whether this is a `StepType.FIRST` step."""
        return self.step_type == StepType.FIRST

    def mid(self):
        """Return whether this is a `StepType.MID` step."""
        return self.step_type == StepType.MID

    def last(self):
        """Return whether this is a `StepType.LAST` step."""
        return self.step_type == StepType.LAST
|
||||
|
||||
|
||||
class ActionRepeatWrapper(dm_env.Environment):
    """Environment wrapper that applies each action up to `num_repeats` times."""

    def __init__(self, env, num_repeats):
        self._env = env
        self._num_repeats = num_repeats

    def step(self, action):
        """
        Step the wrapped environment repeatedly with the same action,
        stopping early on a last step. The returned time step is the final
        one, with the discounted sum of rewards and the product of discounts.
        """
        total_reward, running_discount = 0.0, 1.0
        for _ in range(self._num_repeats):
            time_step = self._env.step(action)
            # A missing (None) reward counts as zero.
            total_reward += (time_step.reward or 0.0) * running_discount
            running_discount *= time_step.discount
            if time_step.last():
                break
        return time_step._replace(reward=total_reward, discount=running_discount)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._env.action_spec()

    def reset(self):
        return self._env.reset()

    def __getattr__(self, name):
        # Anything not defined here is looked up on the wrapped environment.
        return getattr(self._env, name)
|
||||
|
||||
|
||||
class ActionDTypeWrapper(dm_env.Environment):
    """
    Environment wrapper that advertises an action spec with the given
    `dtype` and casts incoming actions to the wrapped environment's dtype.
    """

    def __init__(self, env, dtype):
        self._env = env
        inner_spec = env.action_spec()
        # Same shape and bounds as the wrapped spec, only the dtype differs.
        self._action_spec = specs.BoundedArray(
            inner_spec.shape,
            dtype,
            inner_spec.minimum,
            inner_spec.maximum,
            'action',
        )

    def step(self, action):
        inner_dtype = self._env.action_spec().dtype
        return self._env.step(action.astype(inner_dtype))

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._action_spec

    def reset(self):
        return self._env.reset()

    def __getattr__(self, name):
        # Anything not defined here is looked up on the wrapped environment.
        return getattr(self._env, name)
|
||||
|
||||
|
||||
class ExtendedTimeStepWrapper(dm_env.Environment):
    """
    Environment wrapper whose `reset` and `step` return `ExtendedTimeStep`
    records, i.e. time steps that additionally carry the action.
    """

    def __init__(self, env):
        self._env = env

    def reset(self):
        time_step = self._env.reset()
        return self._augment_time_step(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._augment_time_step(time_step, action)

    def _augment_time_step(self, time_step, action=None):
        """
        Convert `time_step` into an `ExtendedTimeStep`.
        When no action is given (as on reset) an all-zero action matching
        the action spec is used. A missing reward becomes 0.0 and a missing
        discount becomes 1.0.
        """
        if action is None:
            action_spec = self.action_spec()
            action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
        # Only a *missing* discount defaults to 1.0. `discount or 1.0` would
        # also overwrite a genuine discount of 0.0.
        discount = 1.0 if time_step.discount is None else time_step.discount
        return ExtendedTimeStep(observation=time_step.observation,
                                step_type=time_step.step_type,
                                action=action,
                                # A falsy reward is already equivalent to 0.0.
                                reward=time_step.reward or 0.0,
                                discount=discount)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined here is looked up on the wrapped environment.
        return getattr(self._env, name)
|
||||
|
||||
|
||||
class TimeStepToGymWrapper:
    """
    Exposes a dm_env-style environment through a gym-like interface.
    Observations are flattened and concatenated into one float32 vector,
    and `step` returns an `(obs, reward, done, info)` tuple.
    """

    def __init__(self, env, domain, task, max_episode_steps=500):
        """
        `env` is the wrapped environment; `domain` and `task` are stored as
        given. `max_episode_steps` is the number of steps after which `step`
        reports `done` (defaults to the previously hard-coded 500).
        """
        obs_shp = []
        for v in env.observation_spec().values():
            try:
                shp = np.prod(v.shape)
            except Exception:
                # Entries whose size cannot be computed count as one element.
                shp = 1
            obs_shp.append(shp)
        obs_shp = (int(np.sum(obs_shp)),)
        act_shp = env.action_spec().shape
        self.observation_space = gym.spaces.Box(
            low=np.full(
                obs_shp,
                -np.inf,
                dtype=np.float32),
            high=np.full(
                obs_shp,
                np.inf,
                dtype=np.float32),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            low=np.full(act_shp, env.action_spec().minimum),
            high=np.full(act_shp, env.action_spec().maximum),
            dtype=env.action_spec().dtype)
        self.env = env
        self.domain = domain
        self.task = task
        self.max_episode_steps = max_episode_steps
        self.t = 0  # steps taken since the last reset

    @property
    def unwrapped(self):
        return self.env

    @property
    def reward_range(self):
        return None

    @property
    def metadata(self):
        return None

    def _obs_to_array(self, obs):
        """Flatten every entry of the observation dict and concatenate them."""
        return np.concatenate([v.flatten() for v in obs.values()])

    def reset(self):
        self.t = 0
        return self._obs_to_array(self.env.reset().observation)

    def step(self, action):
        """
        Step the wrapped environment. `done` is True on a last time step or
        once `max_episode_steps` steps have been taken; `info` is a fresh
        `defaultdict(float)`.
        """
        self.t += 1
        time_step = self.env.step(action)
        done = time_step.last() or self.t == self.max_episode_steps
        return self._obs_to_array(time_step.observation), time_step.reward, done, defaultdict(float)

    def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
        # The quadruped domain always renders from camera 2.
        camera_id = dict(quadruped=2).get(self.domain, camera_id)
        return self.env.physics.render(height, width, camera_id)
|
||||
|
||||
|
||||
def make_env(cfg):
    """
    Make DMControl environment.
    Adapted from https://github.com/facebookresearch/drqv2

    `cfg.task` is split at its first `-`/`_` into a domain and a task name.
    Raises `UnknownTaskError` when the resulting pair is not a registered
    suite task, including when the name contains no separator at all.
    """
    # `partition` never fails to unpack: a name without a separator yields an
    # empty task, which is then rejected below as an unknown task.
    domain, _, task = cfg.task.replace('-', '_').partition('_')
    domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
    if (domain, task) not in suite.ALL_TASKS:
        raise UnknownTaskError(cfg.task)
    env = suite.load(domain,
                     task,
                     task_kwargs={'random': cfg.seed},
                     visualize_reward=False)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, 2)
    env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
    env = ExtendedTimeStepWrapper(env)
    env = TimeStepToGymWrapper(env, domain, task)
    return env
|
||||
4
tdmpc2/envs/exceptions.py
Normal file
4
tdmpc2/envs/exceptions.py
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
class UnknownTaskError(Exception):
    """Error raised for a task name that is not recognised."""

    def __init__(self, task):
        message = f'Unknown task: {task}'
        super().__init__(message)
|
||||
79
tdmpc2/envs/maniskill.py
Normal file
79
tdmpc2/envs/maniskill.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
import mani_skill2.envs
|
||||
|
||||
|
||||
# Maps each supported task name to the environment id and control mode
# that are passed to `gym.make`.
MANISKILL_TASKS = {
    'lift-cube': {'env': 'LiftCube-v0', 'control_mode': 'pd_ee_delta_pos'},
    'pick-cube': {'env': 'PickCube-v0', 'control_mode': 'pd_ee_delta_pos'},
    'stack-cube': {'env': 'StackCube-v0', 'control_mode': 'pd_ee_delta_pos'},
    'pick-ycb': {'env': 'PickSingleYCB-v0', 'control_mode': 'pd_ee_delta_pose'},
    'turn-faucet': {'env': 'TurnFaucet-v0', 'control_mode': 'pd_ee_delta_pose'},
}
|
||||
|
||||
|
||||
class ManiSkillWrapper(gym.Wrapper):
    """
    Wrapper for ManiSkill environments: repeats every action twice, never
    reports termination itself, and exposes a uniformly bounded action space.
    """

    def __init__(self, env, cfg):
        super().__init__(env)
        self.env = env
        self.cfg = cfg
        self.observation_space = self.env.observation_space
        # Every action dimension shares the widest bounds found in the
        # wrapped action space.
        self.action_space = gym.spaces.Box(
            low=np.full(self.env.action_space.shape, self.env.action_space.low.min()),
            high=np.full(self.env.action_space.shape, self.env.action_space.high.max()),
            dtype=self.env.action_space.dtype,
        )

    def reset(self):
        return self.env.reset()

    def step(self, action):
        """
        Apply `action` twice and return the last observation and info, the
        summed reward, and `done=False` (the wrapped `done` flag is ignored).
        """
        reward = 0
        for _ in range(2):
            obs, r, _, info = self.env.step(action)
            reward += r
        return obs, reward, False, info

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def render(self, *args, **kwargs):
        """
        Render with the wrapped environment's `'cameras'` mode.
        Any arguments are accepted and ignored, so this can be called with
        or without a mode, like the other environment wrappers.
        """
        return self.env.render(mode='cameras')
|
||||
|
||||
|
||||
def make_env(cfg):
    """
    Make ManiSkill2 environment.
    Raises `UnknownTaskError` when `cfg.task` is not a key of
    `MANISKILL_TASKS`. Episodes are limited to 100 steps.
    """
    task_cfg = MANISKILL_TASKS.get(cfg.task)
    if task_cfg is None:
        raise UnknownTaskError(cfg.task)
    camera_cfgs = dict(width=384, height=384)
    env = gym.make(
        task_cfg['env'],
        obs_mode='state',
        control_mode=task_cfg['control_mode'],
        render_camera_cfgs=camera_cfgs,
    )
    env = TimeLimit(ManiSkillWrapper(env, cfg), max_episode_steps=100)
    # Expose the limit under the public name as well.
    env.max_episode_steps = env._max_episode_steps
    return env
|
||||
52
tdmpc2/envs/metaworld.py
Normal file
52
tdmpc2/envs/metaworld.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
||||
|
||||
|
||||
class MetaWorldWrapper(gym.Wrapper):
    """
    Wrapper for Meta-World environments: returns float32 observations,
    repeats every action twice, never reports termination itself, and
    renders from the `corner2` camera.
    """

    def __init__(self, env, cfg):
        super().__init__(env)
        self.env = env
        self.cfg = cfg
        self.camera_name = "corner2"
        # Reposition the camera at index 2 of the model.
        self.env.model.cam_pos[2] = [0.75, 0.075, 0.7]
        self.env._freeze_rand_vec = False

    def reset(self, **kwargs):
        """
        Reset the environment and return the reset observation as float32.
        One all-zero action is applied to the wrapped environment before
        returning; its result is discarded.
        """
        first_obs = super().reset(**kwargs).astype(np.float32)
        zero_action = np.zeros(self.env.action_space.shape)
        self.env.step(zero_action)
        return first_obs

    def step(self, action):
        """
        Apply `action` twice and return the last observation (float32) and
        info, the summed reward, and `done=False`.
        """
        total_reward = 0
        for _ in range(2):
            obs, step_reward, _done, info = self.env.step(action.copy())
            total_reward += step_reward
        return obs.astype(np.float32), total_reward, False, info

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def render(self, *args, **kwargs):
        # Arguments are ignored: always a 384x384 offscreen render.
        frame = self.env.render(
            offscreen=True, resolution=(384, 384), camera_name=self.camera_name
        )
        return frame.copy()
|
||||
|
||||
|
||||
def make_env(cfg):
    """
    Make Meta-World environment.
    `cfg.task` must start with `mw-`; the remainder of the name plus
    `-v2-goal-observable` must be a known environment id, otherwise
    `UnknownTaskError` is raised. Episodes are limited to 100 steps.
    """
    env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
    is_metaworld_task = (
        cfg.task.startswith('mw-') and env_id in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
    )
    if not is_metaworld_task:
        raise UnknownTaskError(cfg.task)
    env_cls = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]
    env = MetaWorldWrapper(env_cls(seed=cfg.seed), cfg)
    env = TimeLimit(env, max_episode_steps=100)
    # Expose the limit under the public name as well.
    env.max_episode_steps = env._max_episode_steps
    return env
|
||||
59
tdmpc2/envs/myosuite.py
Normal file
59
tdmpc2/envs/myosuite.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
|
||||
# Maps each supported task name to the environment id passed to `gym.make`.
# Every task has a `Fixed` variant and a `-hard` counterpart that uses the
# `Random` variant.
MYOSUITE_TASKS = {
    # finger
    'myo-finger-reach': 'myoFingerReachFixed-v0',
    'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
    'myo-finger-pose': 'myoFingerPoseFixed-v0',
    'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
    # hand
    'myo-hand-reach': 'myoHandReachFixed-v0',
    'myo-hand-reach-hard': 'myoHandReachRandom-v0',
    'myo-hand-pose': 'myoHandPoseFixed-v0',
    'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
    'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
    'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
    'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
    'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
    'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
    'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
|
||||
|
||||
|
||||
class MyoSuiteWrapper(gym.Wrapper):
    """
    Wrapper for MyoSuite environments: returns float32 observations, never
    reports termination itself, copies `info['solved']` to `info['success']`,
    and renders from the `hand_side_inter` camera.
    """

    def __init__(self, env, cfg):
        super().__init__(env)
        self.env = env
        self.cfg = cfg
        self.camera_id = 'hand_side_inter'

    def step(self, action):
        """Step once with a copy of `action`; the wrapped `done` flag is ignored."""
        raw_obs, reward, _done, info = self.env.step(action.copy())
        info['success'] = info['solved']
        return raw_obs.astype(np.float32), reward, False, info

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def render(self, *args, **kwargs):
        # Arguments are ignored: always a 384x384 offscreen render.
        frame = self.env.sim.renderer.render_offscreen(
            width=384, height=384, camera_id=self.camera_id
        )
        return frame.copy()
|
||||
|
||||
|
||||
def make_env(cfg):
    """
    Make Myosuite environment.
    Raises `UnknownTaskError` when `cfg.task` is not a key of
    `MYOSUITE_TASKS`. Episodes are limited to 100 steps.
    """
    if cfg.task not in MYOSUITE_TASKS:
        raise UnknownTaskError(cfg.task)
    # Imported only once the task is known to belong to this suite.
    import myosuite
    env_id = MYOSUITE_TASKS[cfg.task]
    env = MyoSuiteWrapper(gym.make(env_id), cfg)
    env = TimeLimit(env, max_episode_steps=100)
    # Expose the limit under the public name as well.
    env.max_episode_steps = env._max_episode_steps
    return env
|
||||
0
tdmpc2/envs/tasks/__init__.py
Normal file
0
tdmpc2/envs/tasks/__init__.py
Normal file
99
tdmpc2/envs/tasks/ball_in_cup.py
Normal file
99
tdmpc2/envs/tasks/ball_in_cup.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import collections
|
||||
import os
|
||||
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
from dm_control.suite import ball_in_cup
|
||||
from dm_control.suite import common
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
import numpy as np
|
||||
|
||||
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
|
||||
|
||||
_DIST_TARGET = 0.5
|
||||
_TARGET_SPEED = 6.
|
||||
|
||||
_DEFAULT_TIME_LIMIT = 20 # (seconds)
|
||||
_CONTROL_TIMESTEP = .02 # (seconds)
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Returns the contents of `ball_in_cup.xml` from the tasks directory and the shared assets dict."""
    xml_path = os.path.join(_TASKS_DIR, 'ball_in_cup.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@ball_in_cup.SUITE.add('custom')
def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Ball-in-Cup Spin task."""
    model_xml, assets = get_model_and_assets()
    physics = Physics.from_xml_string(model_xml, assets)
    task = CustomBallInCup(random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics,
        task,
        time_limit=time_limit,
        control_timestep=_CONTROL_TIMESTEP,
        **extra_kwargs)
|
||||
|
||||
|
||||
class Physics(mujoco.Physics):
    """Ball-in-Cup physics with helpers relating the ball to the target site."""

    def ball_to_target(self):
        """Returns the x/z vector pointing from the ball to the target."""
        target_xz = self.named.data.site_xpos['target', ['x', 'z']]
        ball_xz = self.named.data.xpos['ball', ['x', 'z']]
        return target_xz - ball_xz

    def in_target(self):
        """Returns 1.0 when the ball lies inside the target, otherwise 0.0."""
        offset = abs(self.ball_to_target())
        target_size = self.named.model.site_size['target', [0, 2]]
        ball_size = self.named.model.geom_size['ball', 0]
        # Inside only if both components fit within the target shrunk by the ball size.
        fits = all(offset < target_size - ball_size)
        return float(fits)
|
||||
|
||||
|
||||
class CustomBallInCup(ball_in_cup.BallInCup):
    """Ball-in-Cup task variants with a custom reset, observation and reward."""

    def initialize_episode(self, physics):
        """
        Sample ball positions until one is collision-free and valid.
        A position is valid when the ball is in the target, except in
        episodes (chosen with probability 0.1) where any position is valid.
        """
        init_out_of_target = self.random.uniform() < 0.1
        while True:
            # Assign a random ball position.
            physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
            physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
            # Check for collisions.
            physics.after_reset()
            penetrating = physics.data.ncon > 0
            valid_pos = bool(physics.in_target()) or init_out_of_target
            if valid_pos and not penetrating:
                break
        base.Task.initialize_episode(self, physics)

    def get_observation(self, physics):
        """Returns an ordered dict with the physics position and velocity."""
        return collections.OrderedDict(
            [
                ('position', physics.position()),
                ('velocity', physics.velocity()),
            ]
        )

    def get_reward(self, physics):
        """
        Returns the spin reward: a weighted mix of a distance-to-target term
        and a (double-weighted) ball-speed term, zeroed while the ball is
        in the target.
        """
        qvel = physics.named.data.qvel
        dist = np.linalg.norm(physics.ball_to_target())
        ball_vel = np.linalg.norm([abs(qvel['ball_x']), abs(qvel['ball_z'])])

        # reward: spin around target (maximize distance to target + ball velocity)
        dist_reward = rewards.tolerance(
            dist,
            bounds=(_DIST_TARGET, float('inf')),
            margin=_DIST_TARGET/2,
            value_at_margin=0.5,
            sigmoid='linear')
        vel_reward = rewards.tolerance(
            ball_vel,
            bounds=(_TARGET_SPEED, float('inf')),
            margin=_TARGET_SPEED/2,
            value_at_margin=0.5,
            sigmoid='linear')
        not_in_target = 1 - physics.in_target()
        return not_in_target * (dist_reward + 2*vel_reward) / 3
|
||||
53
tdmpc2/envs/tasks/ball_in_cup.xml
Normal file
53
tdmpc2/envs/tasks/ball_in_cup.xml
Normal file
@@ -0,0 +1,53 @@
|
||||
<mujoco model="ball in cup">
|
||||
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<default>
|
||||
<motor ctrllimited="true" ctrlrange="-1 1" gear="5"/>
|
||||
<default class="cup">
|
||||
<joint type="slide" damping="3" stiffness="20"/>
|
||||
<geom type="capsule" size=".008" material="self"/>
|
||||
</default>
|
||||
</default>
|
||||
|
||||
<worldbody>
|
||||
<light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 2" specular=".3 .3 .3"/>
|
||||
<geom name="ground" type="plane" pos="0 0 0" size=".6 .2 10" material="grid"/>
|
||||
<camera name="cam0" pos="0 -1 .8" xyaxes="1 0 0 0 1 2"/>
|
||||
<camera name="cam1" pos="0 -1 .4" xyaxes="1 0 0 0 0 1" />
|
||||
|
||||
<body name="cup" pos="0 0 .6" childclass="cup">
|
||||
<joint name="cup_x" axis="1 0 0"/>
|
||||
<joint name="cup_z" axis="0 0 1"/>
|
||||
<geom name="cup_part_0" fromto="-.05 0 0 -.05 0 -.075" />
|
||||
<geom name="cup_part_1" fromto="-.05 0 -.075 -.025 0 -.1" />
|
||||
<geom name="cup_part_2" fromto="-.025 0 -.1 .025 0 -.1" />
|
||||
<geom name="cup_part_3" fromto=".025 0 -.1 .05 0 -.075" />
|
||||
<geom name="cup_part_4" fromto=".05 0 -.075 .05 0 0" />
|
||||
<site name="cup" pos="0 0 -.108" size=".005"/>
|
||||
<site name="target" type="box" pos="0 0 -.05" size=".05 .006 .05" group="4"/>
|
||||
</body>
|
||||
|
||||
<body name="ball" pos="0 0 .2">
|
||||
<joint name="ball_x" type="slide" axis="1 0 0"/>
|
||||
<joint name="ball_z" type="slide" axis="0 0 1"/>
|
||||
<geom name="ball" type="sphere" size=".025" material="effector"/>
|
||||
<site name="ball" size=".005"/>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<motor name="x" joint="cup_x"/>
|
||||
<motor name="z" joint="cup_z"/>
|
||||
</actuator>
|
||||
|
||||
<tendon>
|
||||
<spatial name="string" limited="true" range="0 0.3" width="0.003">
|
||||
<site site="ball"/>
|
||||
<site site="cup"/>
|
||||
</spatial>
|
||||
</tendon>
|
||||
|
||||
</mujoco>
|
||||
268
tdmpc2/envs/tasks/cheetah.py
Normal file
268
tdmpc2/envs/tasks/cheetah.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import os
|
||||
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import common
|
||||
from dm_control.suite import cheetah
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
|
||||
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
|
||||
|
||||
_CHEETAH_JUMP_HEIGHT = 1.2
|
||||
_CHEETAH_LIE_HEIGHT = 0.25
|
||||
_CHEETAH_SPIN_SPEED = 8
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Returns the contents of `cheetah.xml` from the tasks directory and the shared assets dict."""
    xml_path = os.path.join(_TASKS_DIR, 'cheetah.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def run_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Run Backwards task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(
        goal='run-backwards', move_speed=cheetah._RUN_SPEED*0.8, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def stand_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Stand Front task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='stand-front', move_speed=0.5, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def stand_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Stand Back task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='stand-back', move_speed=0.5, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def jump(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Jump task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='jump', move_speed=0.5, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def run_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Run Front task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(
        goal='run-front', move_speed=cheetah._RUN_SPEED*0.6, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def run_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Run Back task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(
        goal='run-back', move_speed=cheetah._RUN_SPEED*0.6, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def lie_down(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Lie Down task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='lie-down', random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def legs_up(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Legs Up task."""
    model_xml, assets = get_model_and_assets()
    physics = cheetah.Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='legs-up', random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def flip(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Flip task."""
    model_xml, assets = get_model_and_assets()
    # Uses this module's `Physics` rather than `cheetah.Physics`.
    physics = Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(goal='flip', move_speed=cheetah._RUN_SPEED, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
@cheetah.SUITE.add('custom')
def flip_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds and returns the environment for the Flip Backwards task."""
    model_xml, assets = get_model_and_assets()
    # Uses this module's `Physics` rather than `cheetah.Physics`.
    physics = Physics.from_xml_string(model_xml, assets)
    task = CustomCheetah(
        goal='flip-backwards', move_speed=cheetah._RUN_SPEED*0.8, random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        physics, task, time_limit=time_limit, **extra_kwargs)
|
||||
|
||||
|
||||
class Physics(cheetah.Physics):
    """Cheetah physics extended with an angular momentum accessor."""

    def angmomentum(self):
        """Returns the second (index 1) component of the torso's subtree angular momentum."""
        torso_angmom = self.named.data.subtree_angmom['torso']
        return torso_angmom[1]
|
||||
|
||||
|
||||
class CustomCheetah(cheetah.Cheetah):
    """Cheetah task whose reward function is selected by a goal string.

    Recognised goals: 'run-backwards', 'stand-front', 'stand-back', 'jump',
    'run-front', 'run-back', 'lie-down', 'legs-up', 'flip' and
    'flip-backwards'. Any other goal makes `get_reward` raise
    `NotImplementedError`.
    """

    # (goal, reward method name, extra keyword arguments) triples that
    # `get_reward` scans in order.
    _GOAL_REWARDS = (
        ('run-backwards', '_run_backwards_reward', {}),
        ('stand-front', '_stand_front_reward', {}),
        ('stand-back', '_stand_back_reward', {}),
        ('jump', '_jump_reward', {}),
        ('run-front', '_run_front_reward', {}),
        ('run-back', '_run_back_reward', {}),
        ('lie-down', '_lie_down_reward', {}),
        ('legs-up', '_legs_up_reward', {}),
        ('flip', '_flip_reward', {'forward': True}),
        ('flip-backwards', '_flip_reward', {'forward': False}),
    )

    def __init__(self, goal='run-backwards', move_speed=0, random=None):
        """Stores the goal name and the speed used by the speed-based rewards.

        Args:
            goal: Name of the reward to use (see the class docstring).
            move_speed: Speed threshold used as bound and margin by the
                speed-dependent reward terms.
            random: Passed positionally to the parent task constructor.
        """
        super().__init__(random)
        self._move_speed = move_speed
        self._goal = goal

    def _run_backwards_reward(self, physics):
        """Rewards `physics.speed()` at or below `-move_speed` (linear margin)."""
        target = self._move_speed
        return rewards.tolerance(
            physics.speed(),
            bounds=(-float('inf'), -target),
            margin=target,
            value_at_margin=0,
            sigmoid='linear')

    def _stand_one_foot_reward(self, physics, foot):
        """Rewards a raised torso/foot pair while keeping horizontal speed low.

        Note: `foot` is the foot that is *not* on the ground.

        The height term uses the mean z of the torso and `foot` bodies; the
        speed term is 1 while the speed stays within +/- `move_speed`. They
        are blended 5:1.
        """
        xpos = physics.named.data.xpos
        mean_height = (xpos['torso', 'z'] + xpos[foot, 'z'])/2
        height_reward = rewards.tolerance(
            mean_height,
            bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
            margin=_CHEETAH_JUMP_HEIGHT/2)
        speed_limit = self._move_speed
        still_reward = rewards.tolerance(
            physics.speed(),
            bounds=(-speed_limit, speed_limit),
            margin=speed_limit,
            value_at_margin=0,
            sigmoid='linear')
        return (5*height_reward + still_reward) / 6

    def _stand_front_reward(self, physics):
        """One-foot stand reward with 'bfoot' as the raised foot."""
        return self._stand_one_foot_reward(physics, 'bfoot')

    def _stand_back_reward(self, physics):
        """One-foot stand reward with 'ffoot' as the raised foot."""
        return self._stand_one_foot_reward(physics, 'ffoot')

    def _jump_reward(self, physics):
        """Mean of the front-stand and back-stand rewards."""
        front = self._stand_front_reward(physics)
        back = self._stand_back_reward(physics)
        return (front + back) / 2

    def _run_one_foot_reward(self, physics, foot):
        """Rewards a raised torso and `foot`, optionally scaled by forward speed.

        Note: `foot` is the foot that is *not* on the ground.

        When `move_speed` is 0 only the height-based term is returned;
        otherwise it is multiplied by a speed factor in [1/6, 1].
        """
        xpos = physics.named.data.xpos
        above_jump_height = dict(
            bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
            margin=_CHEETAH_JUMP_HEIGHT/2)
        torso_up = rewards.tolerance(xpos['torso', 'z'], **above_jump_height)
        foot_up = rewards.tolerance(xpos[foot, 'z'], **above_jump_height)
        up_reward = (3*foot_up + 2*torso_up) / 5
        if self._move_speed == 0:
            return up_reward
        target = self._move_speed
        speed_reward = rewards.tolerance(
            physics.speed(),
            bounds=(target, float('inf')),
            margin=target,
            value_at_margin=0,
            sigmoid='linear')
        return up_reward * (5*speed_reward + 1) / 6

    def _run_front_reward(self, physics):
        """One-foot run reward with 'bfoot' as the raised foot."""
        return self._run_one_foot_reward(physics, 'bfoot')

    def _run_back_reward(self, physics):
        """One-foot run reward with 'ffoot' as the raised foot."""
        return self._run_one_foot_reward(physics, 'ffoot')

    def _lie_down_reward(self, physics):
        """Rewards torso and mean foot height at or below `_CHEETAH_LIE_HEIGHT`.

        The torso term and the feet term are blended 3:1.
        """
        xpos = physics.named.data.xpos
        torso_height = xpos['torso', 'z']
        feet_height = (xpos['ffoot', 'z'] + xpos['bfoot', 'z']) / 2
        below_lie_height = dict(
            bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
            margin=_CHEETAH_LIE_HEIGHT,
            value_at_margin=0,
            sigmoid='linear')
        torso_down = rewards.tolerance(torso_height, **below_lie_height)
        feet_down = rewards.tolerance(feet_height, **below_lie_height)
        return (3*torso_down + feet_down) / 4

    def _legs_up_reward(self, physics):
        """Blends a low-torso term (weight 5) with the 'bfoot' run reward (weight 1)."""
        torso_down = rewards.tolerance(
            physics.named.data.xpos['torso', 'z'],
            bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
            margin=_CHEETAH_LIE_HEIGHT/2)
        get_up = self._run_one_foot_reward(physics, 'bfoot')
        return (5*torso_down + get_up) / 6

    def _flip_reward(self, physics, forward=True):
        """Rewards angular momentum and horizontal speed in one direction.

        Args:
            physics: Physics instance providing `angmomentum()` and `speed()`.
            forward: When False both quantities are negated before scoring.

        Returns:
            The spin term and the speed term blended 2:1.
        """
        linear_to_zero = dict(value_at_margin=0, sigmoid='linear')
        spin_reward = rewards.tolerance(
            (1. if forward else -1.) * physics.angmomentum(),
            bounds=(_CHEETAH_SPIN_SPEED, float('inf')),
            margin=_CHEETAH_SPIN_SPEED,
            **linear_to_zero)
        speed_reward = rewards.tolerance(
            (1. if forward else -1.) * physics.speed(),
            bounds=(self._move_speed, float('inf')),
            margin=self._move_speed,
            **linear_to_zero)
        return (2*spin_reward + speed_reward) / 3

    def get_reward(self, physics):
        """Returns the reward for the configured goal.

        Raises:
            NotImplementedError: If the goal matches none of the known names.
        """
        for goal, method_name, extra in self._GOAL_REWARDS:
            if self._goal == goal:
                return getattr(self, method_name)(physics, **extra)
        raise NotImplementedError(f'Goal {self._goal} is not implemented.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the jump task, take one zero-action step and print
    # the resulting reward.
    import numpy as np

    env = jump()
    env.reset()
    # Size the action from the environment's own spec instead of hard-coding it.
    action = np.zeros(env.action_spec().shape)
    # `env.step` returns a dm_env TimeStep (step_type, reward, discount,
    # observation), not a gym-style (obs, reward, done, info) tuple, so read
    # the reward by field name.
    time_step = env.step(action)
    print(time_step.reward)
|
||||
73
tdmpc2/envs/tasks/cheetah.xml
Normal file
73
tdmpc2/envs/tasks/cheetah.xml
Normal file
@@ -0,0 +1,73 @@
|
||||
<mujoco model="cheetah">
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<compiler settotalmass="14"/>
|
||||
|
||||
<default>
|
||||
<default class="cheetah">
|
||||
<joint limited="true" damping=".01" armature=".1" stiffness="8" type="hinge" axis="0 1 0"/>
|
||||
<geom contype="1" conaffinity="1" condim="3" friction=".4 .1 .1" material="self"/>
|
||||
</default>
|
||||
<default class="free">
|
||||
<joint limited="false" damping="0" armature="0" stiffness="0"/>
|
||||
</default>
|
||||
<motor ctrllimited="true" ctrlrange="-1 1"/>
|
||||
</default>
|
||||
|
||||
<statistic center="0 0 .7" extent="2"/>
|
||||
|
||||
<option timestep="0.01"/>
|
||||
|
||||
<worldbody>
|
||||
<geom name="ground" type="plane" conaffinity="1" pos="98 0 0" size="200 .8 .5" material="grid"/>
|
||||
<body name="torso" pos="0 0 .7" childclass="cheetah">
|
||||
<light name="light" pos="0 0 2" mode="trackcom"/>
|
||||
<camera name="side" pos="0 -3 0" quat="0.707 0.707 0 0" mode="trackcom"/>
|
||||
<camera name="back" pos="-1.8 -1.3 0.8" xyaxes="0.45 -0.9 0 0.3 0.15 0.94" mode="trackcom"/>
|
||||
<joint name="rootx" type="slide" axis="1 0 0" class="free"/>
|
||||
<joint name="rootz" type="slide" axis="0 0 1" class="free"/>
|
||||
<joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
|
||||
<geom name="torso" type="capsule" fromto="-.5 0 0 .5 0 0" size="0.046"/>
|
||||
<geom name="head" type="capsule" pos=".6 0 .1" euler="0 50 0" size="0.046 .15"/>
|
||||
<body name="bthigh" pos="-.5 0 0">
|
||||
<joint name="bthigh" range="-30 60" stiffness="240" damping="6"/>
|
||||
<geom name="bthigh" type="capsule" pos=".1 0 -.13" euler="0 -218 0" size="0.046 .145"/>
|
||||
<body name="bshin" pos=".16 0 -.25">
|
||||
<joint name="bshin" range="-50 50" stiffness="180" damping="4.5"/>
|
||||
<geom name="bshin" type="capsule" pos="-.14 0 -.07" euler="0 -116 0" size="0.046 .15"/>
|
||||
<body name="bfoot" pos="-.28 0 -.14">
|
||||
<joint name="bfoot" range="-230 50" stiffness="120" damping="3"/>
|
||||
<geom name="bfoot" type="capsule" pos=".03 0 -.097" euler="0 -15 0" size="0.046 .094"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="fthigh" pos=".5 0 0">
|
||||
<joint name="fthigh" range="-57 .40" stiffness="180" damping="4.5"/>
|
||||
<geom name="fthigh" type="capsule" pos="-.07 0 -.12" euler="0 30 0" size="0.046 .133"/>
|
||||
<body name="fshin" pos="-.14 0 -.24">
|
||||
<joint name="fshin" range="-70 50" stiffness="120" damping="3"/>
|
||||
<geom name="fshin" type="capsule" pos=".065 0 -.09" euler="0 -34 0" size="0.046 .106"/>
|
||||
<body name="ffoot" pos=".13 0 -.18">
|
||||
<joint name="ffoot" range="-28 28" stiffness="60" damping="1.5"/>
|
||||
<geom name="ffoot" type="capsule" pos=".045 0 -.07" euler="0 -34 0" size="0.046 .07"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<sensor>
|
||||
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
|
||||
</sensor>
|
||||
|
||||
<actuator>
|
||||
<motor name="bthigh" joint="bthigh" gear="120" />
|
||||
<motor name="bshin" joint="bshin" gear="90" />
|
||||
<motor name="bfoot" joint="bfoot" gear="60" />
|
||||
<motor name="fthigh" joint="fthigh" gear="90" />
|
||||
<motor name="fshin" joint="fshin" gear="60" />
|
||||
<motor name="ffoot" joint="ffoot" gear="30" />
|
||||
</actuator>
|
||||
</mujoco>
|
||||
79
tdmpc2/envs/tasks/fish.py
Normal file
79
tdmpc2/envs/tasks/fish.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import collections
|
||||
import os
|
||||
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
from dm_control.suite import common
|
||||
from dm_control.suite import fish
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
import numpy as np
|
||||
|
||||
# Directory holding the task XML files loaded by this module.
_TASKS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'tasks',
)

# Default episode time limit handed to `control.Environment`.
_DEFAULT_TIME_LIMIT = 40
# Control timestep handed to `control.Environment`.
_CONTROL_TIMESTEP = 0.04
# Joints whose positions are randomised at the start of every episode.
_JOINTS = [
    'tail1',
    'tail_twist',
    'tail2',
    'finright_roll',
    'finright_pitch',
    'finleft_roll',
    'finleft_pitch',
]
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Loads 'fish.xml' from the tasks directory.

    Returns:
        A `(model_xml, assets)` tuple: the contents of the XML resource and
        the shared `common.ASSETS` dict.
    """
    xml_path = os.path.join(_TASKS_DIR, 'fish.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@fish.SUITE.add('custom')
def obstacles(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the custom Fish 'obstacles' environment.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the `Obstacles` task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` using `fish.Physics`, an `Obstacles` task and
        a control timestep of `_CONTROL_TIMESTEP`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        fish.Physics.from_xml_string(*get_model_and_assets()),
        Obstacles(random=random),
        control_timestep=_CONTROL_TIMESTEP,
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
class Obstacles(fish.Swim):
    """Fish swim task played in an arena that contains four wall geoms."""

    def __init__(self, random=None):
        """Forwards `random` to the parent swim task."""
        super().__init__(random=random)

    def in_wall(self, physics, name, min_distance=0.08):
        """Reports whether geom `name` is too close to any of the four walls.

        For each of 'wall0'..'wall3' the absolute x and y offsets between the
        two geom positions are computed, and the *smaller* of the two is
        compared with `min_distance`.

        NOTE(review): because the minimum of the per-axis offsets is used,
        being aligned with a wall along a single axis is enough to count as
        "in wall" - confirm this is the intended test.

        Returns:
            True if the check fires for at least one wall, else False.
        """
        return any(
            np.min(np.abs(physics.named.data.geom_xpos[name][:2]
                          - physics.named.data.geom_xpos[wall][:2])) < min_distance
            for wall in ('wall0', 'wall1', 'wall2', 'wall3'))

    def initialize_episode(self, physics):
        """Randomises the fish pose and the target, resampling near walls.

        Everything is redrawn until the 'target' geom passes the `in_wall`
        check; then `base.Task.initialize_episode` is invoked directly.
        """
        target_ranges = (('x', -.4, .4), ('y', -.4, .4), ('z', .1, .3))
        while True:
            # Random unit quaternion for the root orientation.
            quat = self.random.randn(4)
            physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
            for joint in _JOINTS:
                physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
            # Random target position, drawn in x, y, z order.
            for axis, low, high in target_ranges:
                physics.named.model.geom_pos['target', axis] = self.random.uniform(low, high)
            # Refresh derived quantities so the wall check sees the new target.
            physics.after_reset()
            if not self.in_wall(physics, 'target'):
                break
        base.Task.initialize_episode(self, physics)

    def get_reward(self, physics):
        """Returns the swim reward, zeroed while the torso geom is near a wall.

        The proximity term (weight 7) and the upright term (weight 1) are
        averaged and multiplied by 0 or 1 depending on the wall check, which
        uses a `min_distance` of 0.06 for the 'torso' geom.
        """
        radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
        mouth_distance = np.linalg.norm(physics.mouth_to_target())
        in_target = rewards.tolerance(mouth_distance, bounds=(0, radii), margin=2*radii)
        is_upright = 0.5 * (physics.upright() + 1)
        clear_of_walls = 1. - self.in_wall(physics, 'torso', min_distance=0.06)
        return clear_of_walls * (7*in_target + is_upright) / 8
|
||||
93
tdmpc2/envs/tasks/fish.xml
Normal file
93
tdmpc2/envs/tasks/fish.xml
Normal file
@@ -0,0 +1,93 @@
|
||||
<mujoco model="fish">
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
<asset>
|
||||
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>
|
||||
</asset>
|
||||
|
||||
|
||||
<option timestep="0.004" density="5000">
|
||||
<flag gravity="disable" constraint="disable"/>
|
||||
</option>
|
||||
|
||||
<default>
|
||||
<general ctrllimited="true"/>
|
||||
<default class="fish">
|
||||
<joint type="hinge" limited="false" range="-60 60" damping="2e-5" solreflimit=".1 1" solimplimit="0 .8 .1"/>
|
||||
<geom material="self"/>
|
||||
</default>
|
||||
<default class="wall">
|
||||
<geom type="box" material="self"/>
|
||||
</default>
|
||||
</default>
|
||||
|
||||
<worldbody>
|
||||
<camera name="tracking_top" pos="0 0 1" xyaxes="1 0 0 0 1 0" mode="trackcom"/>
|
||||
<camera name="tracking_x" pos="-.3 0 .2" xyaxes="0 -1 0 0.342 0 0.940" fovy="60" mode="trackcom"/>
|
||||
<camera name="tracking_y" pos="0 -.3 .2" xyaxes="1 0 0 0 0.342 0.940" fovy="60" mode="trackcom"/>
|
||||
<camera name="fixed_top" pos="0 0 5.5" fovy="10"/>
|
||||
<geom name="ground" type="plane" size=".5 .5 .1" material="grid"/>
|
||||
|
||||
<geom name="wall0" class="wall" pos="-.15 -.15 .1" size=".05 .05 .1"/>
|
||||
<geom name="wall1" class="wall" pos=".15 -.15 .1" size=".05 .05 .1"/>
|
||||
<geom name="wall2" class="wall" pos=".15 .15 .1" size=".05 .05 .1"/>
|
||||
<geom name="wall3" class="wall" pos="-.15 .15 .1" size=".05 .05 .1"/>
|
||||
|
||||
<geom name="target" type="sphere" pos="0 .4 .1" size=".04" material="target"/>
|
||||
<body name="torso" pos="0 0 .1" childclass="fish">
|
||||
<light name="light" diffuse=".6 .6 .6" pos="0 0 0.5" dir="0 0 -1" specular=".3 .3 .3" mode="track"/>
|
||||
<joint name="root" type="free" damping="0" limited="false"/>
|
||||
<site name="torso" size=".01" rgba="0 0 0 0"/>
|
||||
<geom name="eye" type="ellipsoid" pos="0 .055 .015" size=".008 .012 .008" euler="-10 0 0" material="eye" mass="0"/>
|
||||
<camera name="eye" pos="0 .06 .02" xyaxes="1 0 0 0 0 1"/>
|
||||
<geom name="mouth" type="capsule" fromto="0 .079 0 0 .07 0" size=".005" material="effector" mass="0"/>
|
||||
<geom name="lower_mouth" type="capsule" fromto="0 .079 -.004 0 .07 -.003" size=".0045" material="effector" mass="0"/>
|
||||
<geom name="torso" type="ellipsoid" size=".01 .08 .04" mass="0"/>
|
||||
<geom name="back_fin" type="ellipsoid" size=".001 .03 .015" pos="0 -.03 .03" material="effector" mass="0"/>
|
||||
<geom name="torso_massive" type="box" size=".002 .06 .03" group="4"/>
|
||||
<body name="tail1" pos="0 -.09 0">
|
||||
<joint name="tail1" axis="0 0 1" pos="0 .01 0"/>
|
||||
<joint name="tail_twist" axis="0 1 0" pos="0 .01 0" range="-30 30"/>
|
||||
<geom name="tail1" type="ellipsoid" size=".001 .008 .016"/>
|
||||
<body name="tail2" pos="0 -.028 0">
|
||||
<joint name="tail2" axis="0 0 1" pos="0 .02 0" stiffness="8e-5"/>
|
||||
<geom name="tail2" type="ellipsoid" size=".001 .018 .035"/>
|
||||
</body>
|
||||
</body>
|
||||
<body name="finright" pos=".01 0 0">
|
||||
<joint name="finright_roll" axis="0 1 0"/>
|
||||
<joint name="finright_pitch" axis="1 0 0" pos="0 .005 0"/>
|
||||
<geom name="finright" type="ellipsoid" pos=".015 0 0" size=".02 .015 .001" />
|
||||
</body>
|
||||
<body name="finleft" pos="-.01 0 0">
|
||||
<joint name="finleft_roll" axis="0 1 0"/>
|
||||
<joint name="finleft_pitch" axis="1 0 0" pos="0 .005 0"/>
|
||||
<geom name="finleft" type="ellipsoid" pos="-.015 0 0" size=".02 .015 .001"/>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<tendon>
|
||||
<fixed name="fins_flap">
|
||||
<joint joint="finleft_roll" coef="-.5"/>
|
||||
<joint joint="finright_roll" coef=".5"/>
|
||||
</fixed>
|
||||
<fixed name="fins_sym" stiffness="1e-4">
|
||||
<joint joint="finleft_roll" coef=".5"/>
|
||||
<joint joint="finright_roll" coef=".5"/>
|
||||
</fixed>
|
||||
</tendon>
|
||||
|
||||
<actuator>
|
||||
<position name="tail" joint="tail1" ctrlrange="-1 1" kp="5e-4"/>
|
||||
<position name="tail_twist" joint="tail_twist" ctrlrange="-1 1" kp="1e-4"/>
|
||||
<position name="fins_flap" tendon="fins_flap" ctrlrange="-1 1" kp="3e-4"/>
|
||||
<position name="finleft_pitch" joint="finleft_pitch" ctrlrange="-1 1" kp="1e-4"/>
|
||||
<position name="finright_pitch" joint="finright_pitch" ctrlrange="-1 1" kp="1e-4"/>
|
||||
</actuator>
|
||||
|
||||
<sensor>
|
||||
<velocimeter name="velocimeter" site="torso"/>
|
||||
<gyro name="gyro" site="torso"/>
|
||||
</sensor>
|
||||
</mujoco>
|
||||
114
tdmpc2/envs/tasks/hopper.py
Normal file
114
tdmpc2/envs/tasks/hopper.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import common
|
||||
from dm_control.suite import hopper
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
import numpy as np
|
||||
|
||||
# Directory holding the task XML files loaded by this module.
_TASKS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'tasks',
)

# Control timestep, in seconds.
_CONTROL_TIMESTEP = 0.02

# Default episode duration, in seconds.
_DEFAULT_TIME_LIMIT = 20

# Minimal height of torso over foot above which the stand reward is 1.
_STAND_HEIGHT = 0.6

# Hopping speed above which the hop reward is 1.
_HOP_SPEED = 2

# Angular momentum above which the flip reward is 1.
_SPIN_SPEED = 5
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Loads 'hopper.xml' from the tasks directory.

    Returns:
        A `(model_xml, assets)` tuple: the contents of the XML resource and
        the shared `common.ASSETS` dict.
    """
    xml_path = os.path.join(_TASKS_DIR, 'hopper.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@hopper.SUITE.add('custom')
def hop_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the custom Hopper 'hop-backwards' environment.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the `CustomHopper` task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` built on this module's `Physics`, a
        `CustomHopper` task with goal 'hop-backwards' and a control timestep
        of `_CONTROL_TIMESTEP`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets()),
        CustomHopper(goal='hop-backwards', random=random),
        time_limit=time_limit,
        control_timestep=_CONTROL_TIMESTEP,
        **extra_kwargs)
|
||||
|
||||
|
||||
@hopper.SUITE.add('custom')
def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the custom Hopper 'flip' environment.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the `CustomHopper` task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` built on this module's `Physics`, a
        `CustomHopper` task with goal 'flip' and a control timestep of
        `_CONTROL_TIMESTEP`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets()),
        CustomHopper(goal='flip', random=random),
        time_limit=time_limit,
        control_timestep=_CONTROL_TIMESTEP,
        **extra_kwargs)
|
||||
|
||||
|
||||
@hopper.SUITE.add('custom')
def flip_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the custom Hopper 'flip-backwards' environment.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the `CustomHopper` task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` built on this module's `Physics`, a
        `CustomHopper` task with goal 'flip-backwards' and a control timestep
        of `_CONTROL_TIMESTEP`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets()),
        CustomHopper(goal='flip-backwards', random=random),
        time_limit=time_limit,
        control_timestep=_CONTROL_TIMESTEP,
        **extra_kwargs)
|
||||
|
||||
|
||||
class Physics(hopper.Physics):
    """Hopper physics extended with an angular-momentum accessor."""

    def angmomentum(self):
        """Returns the Y component of the 'torso' subtree angular momentum.

        Reads the `subtree_angmom` row named 'torso' and returns its element
        at index 1.
        """
        torso_angmom = self.named.data.subtree_angmom['torso']
        return torso_angmom[1]
|
||||
|
||||
|
||||
class CustomHopper(hopper.Hopper):
    """Hopper task whose reward function is selected by a goal string.

    Recognised goals: 'hop-backwards', 'flip' and 'flip-backwards'. Any other
    goal makes `get_reward` raise `NotImplementedError`.
    """

    # (goal, reward method name, extra keyword arguments) triples that
    # `get_reward` scans in order.
    _GOAL_REWARDS = (
        ('hop-backwards', '_hop_backwards_reward', {}),
        ('flip', '_flip_reward', {'forward': True}),
        ('flip-backwards', '_flip_reward', {'forward': False}),
    )

    def __init__(self, goal='hop-backwards', random=None):
        """Stores the goal; the parent task receives `None` as its first argument."""
        super().__init__(None, random)
        self._goal = goal

    def _hop_backwards_reward(self, physics):
        """Product of a standing term and a backwards-speed term.

        The standing term scores `physics.height()` against the range
        (`_STAND_HEIGHT`, 2); the hopping term scores `physics.speed()` at or
        below `-_HOP_SPEED/2` with a linear margin.
        """
        standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
        hopping = rewards.tolerance(
            physics.speed(),
            bounds=(-float('inf'), -_HOP_SPEED/2),
            margin=_HOP_SPEED/4,
            value_at_margin=0.5,
            sigmoid='linear')
        return standing * hopping

    def _flip_reward(self, physics, forward=True):
        """Rewards angular momentum of at least `_SPIN_SPEED` in one direction.

        Args:
            physics: Physics instance providing `angmomentum()`.
            forward: When False the angular momentum is negated before scoring.
        """
        direction = 1. if forward else -1.
        return rewards.tolerance(
            direction * physics.angmomentum(),
            bounds=(_SPIN_SPEED, float('inf')),
            margin=_SPIN_SPEED/2,
            value_at_margin=0,
            sigmoid='linear')

    def get_reward(self, physics):
        """Returns the reward for the configured goal.

        Raises:
            NotImplementedError: If the goal matches none of the known names.
        """
        for goal, method_name, extra in self._GOAL_REWARDS:
            if self._goal == goal:
                return getattr(self, method_name)(physics, **extra)
        raise NotImplementedError(f'Goal {self._goal} is not implemented.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the hop-backwards task, take one zero-action step and
    # print the resulting reward. `np` is the module-level numpy import.
    env = hop_backwards()
    env.reset()
    # Size the action from the environment's own spec: a hard-coded length of
    # 2 does not match the four motors declared in hopper.xml.
    action = np.zeros(env.action_spec().shape)
    # `env.step` returns a dm_env TimeStep (step_type, reward, discount,
    # observation), not a gym-style (obs, reward, done, info) tuple, so read
    # the reward by field name.
    time_step = env.step(action)
    print(time_step.reward)
|
||||
66
tdmpc2/envs/tasks/hopper.xml
Normal file
66
tdmpc2/envs/tasks/hopper.xml
Normal file
@@ -0,0 +1,66 @@
|
||||
<mujoco model="planar hopper">
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<statistic extent="2" center="0 0 .5"/>
|
||||
|
||||
<default>
|
||||
<default class="hopper">
|
||||
<joint type="hinge" axis="0 1 0" limited="true" damping=".05" armature=".2"/>
|
||||
<geom type="capsule" material="self"/>
|
||||
<site type="sphere" size="0.05" group="3"/>
|
||||
</default>
|
||||
<default class="free">
|
||||
<joint limited="false" damping="0" armature="0" stiffness="0"/>
|
||||
</default>
|
||||
<motor ctrlrange="-1 1" ctrllimited="true"/>
|
||||
</default>
|
||||
|
||||
<option timestep="0.005"/>
|
||||
|
||||
<worldbody>
|
||||
<camera name="cam0" pos="0 -2.8 0.8" euler="90 0 0" mode="trackcom"/>
|
||||
<camera name="back" pos="-2 -.2 1.2" xyaxes="0.2 -1 0 .5 0 2" mode="trackcom"/>
|
||||
<geom name="floor" type="plane" conaffinity="1" pos="48 0 0" size="50 1 .2" material="grid"/>
|
||||
<body name="torso" pos="0 0 1" childclass="hopper">
|
||||
<light name="top" pos="0 0 2" mode="trackcom"/>
|
||||
<joint name="rootx" type="slide" axis="1 0 0" class="free"/>
|
||||
<joint name="rootz" type="slide" axis="0 0 1" class="free"/>
|
||||
<joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
|
||||
<geom name="torso" fromto="0 0 -.05 0 0 .2" size="0.0653"/>
|
||||
<geom name="nose" fromto=".08 0 .13 .15 0 .14" size="0.03"/>
|
||||
<body name="pelvis" pos="0 0 -.05">
|
||||
<joint name="waist" range="-30 30"/>
|
||||
<geom name="pelvis" fromto="0 0 0 0 0 -.15" size="0.065"/>
|
||||
<body name="thigh" pos="0 0 -.2">
|
||||
<joint name="hip" range="-170 10"/>
|
||||
<geom name="thigh" fromto="0 0 0 0 0 -.33" size="0.04"/>
|
||||
<body name="calf" pos="0 0 -.33">
|
||||
<joint name="knee" range="5 150"/>
|
||||
<geom name="calf" fromto="0 0 0 0 0 -.32" size="0.03"/>
|
||||
<body name="foot" pos="0 0 -.32">
|
||||
<joint name="ankle" range="-45 45"/>
|
||||
<geom name="foot" fromto="-.08 0 0 .17 0 0" size="0.04"/>
|
||||
<site name="touch_toe" pos=".17 0 0"/>
|
||||
<site name="touch_heel" pos="-.08 0 0"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<sensor>
|
||||
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
|
||||
<touch name="touch_toe" site="touch_toe"/>
|
||||
<touch name="touch_heel" site="touch_heel"/>
|
||||
</sensor>
|
||||
|
||||
<actuator>
|
||||
<motor name="waist" joint="waist" gear="30"/>
|
||||
<motor name="hip" joint="hip" gear="40"/>
|
||||
<motor name="knee" joint="knee" gear="30"/>
|
||||
<motor name="ankle" joint="ankle" gear="10"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
43
tdmpc2/envs/tasks/pendulum.py
Normal file
43
tdmpc2/envs/tasks/pendulum.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import pendulum
|
||||
from dm_control.suite import common
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
import numpy as np
|
||||
|
||||
# Directory holding the task XML files loaded by this module.
_TASKS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'tasks',
)

# Default episode time limit handed to `control.Environment`.
_DEFAULT_TIME_LIMIT = 20
# Angular speed at or above which the spin reward is 1.
_TARGET_SPEED = 9.0
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Loads 'pendulum.xml' from the tasks directory.

    Returns:
        A `(model_xml, assets)` tuple: the contents of the XML resource and
        the shared `common.ASSETS` dict.
    """
    xml_path = os.path.join(_TASKS_DIR, 'pendulum.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@pendulum.SUITE.add('custom')
def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None,
         environment_kwargs=None):
    """Builds the custom Pendulum 'spin' environment.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the `Spin` task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` pairing `pendulum.Physics` with a `Spin` task.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        pendulum.Physics.from_xml_string(*get_model_and_assets()),
        Spin(random=random),
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
class Spin(pendulum.SwingUp):
    """Pendulum task that rewards spinning fast instead of balancing."""

    def __init__(self, random=None):
        """Forwards `random` to the parent swing-up task."""
        super().__init__(random=random)

    def get_reward(self, physics):
        """Rewards the magnitude of the angular velocity.

        The reward is 1 once the norm of `physics.angular_velocity()` reaches
        `_TARGET_SPEED`, and falls off linearly below it (0.5 at a distance of
        `_TARGET_SPEED/2` from the bound).
        """
        spin_speed = np.linalg.norm(physics.angular_velocity())
        return rewards.tolerance(
            spin_speed,
            bounds=(_TARGET_SPEED, float('inf')),
            margin=_TARGET_SPEED/2,
            value_at_margin=0.5,
            sigmoid='linear')
|
||||
26
tdmpc2/envs/tasks/pendulum.xml
Normal file
26
tdmpc2/envs/tasks/pendulum.xml
Normal file
@@ -0,0 +1,26 @@
|
||||
<mujoco model="pendulum">
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<option timestep="0.02">
|
||||
<flag contact="disable" energy="enable"/>
|
||||
</option>
|
||||
|
||||
<worldbody>
|
||||
<light name="light" pos="0 0 2"/>
|
||||
<geom name="floor" size="2 2 .2" type="plane" material="grid"/>
|
||||
<camera name="fixed" pos="0 -1.5 2" xyaxes='1 0 0 0 1 1'/>
|
||||
<camera name="lookat" mode="targetbodycom" target="pole" pos="0 -2 1"/>
|
||||
<body name="pole" pos="0 0 .6">
|
||||
<joint name="hinge" type="hinge" axis="0 1 0" damping="0.1"/>
|
||||
<geom name="base" material="decoration" type="cylinder" fromto="0 -.03 0 0 .03 0" size="0.021" mass="0"/>
|
||||
<geom name="pole" material="self" type="capsule" fromto="0 0 0 0 0 0.5" size="0.02" mass="0"/>
|
||||
<geom name="mass" material="effector" type="sphere" pos="0 0 0.5" size="0.05" mass="1"/>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<motor name="torque" joint="hinge" gear="1" ctrlrange="-1 1" ctrllimited="true"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
89
tdmpc2/envs/tasks/reacher.py
Normal file
89
tdmpc2/envs/tasks/reacher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import collections
|
||||
import os
|
||||
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import common
|
||||
from dm_control.suite import reacher
|
||||
from dm_control.utils import io as resources
|
||||
import numpy as np
|
||||
|
||||
# Directory holding the task XML files loaded by this module.
_TASKS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'tasks',
)

# Default episode time limit handed to `control.Environment`.
_DEFAULT_TIME_LIMIT = 20
# Target size used by the "easy" tasks.
_BIG_TARGET = 0.05
# Target size used by the "hard" tasks.
_SMALL_TARGET = 0.015
|
||||
|
||||
|
||||
def get_model_and_assets(links):
    """Loads the three- or four-link reacher model from the tasks directory.

    Args:
        links: Number of arm links; must be 3 or 4.

    Returns:
        A `(model_xml, assets)` tuple: the contents of the selected XML
        resource and the shared `common.ASSETS` dict.

    Raises:
        AssertionError: If `links` is neither 3 nor 4.
    """
    assert links in {3, 4}, 'Only 3 or 4 links are supported.'
    if links == 3:
        filename = 'reacher_three_links.xml'
    else:
        filename = 'reacher_four_links.xml'
    xml_path = os.path.join(_TASKS_DIR, filename)
    return resources.GetResource(xml_path), common.ASSETS
|
||||
|
||||
|
||||
@reacher.SUITE.add('custom')
def three_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the three-link reacher environment with the big target.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` using the three-link model and a
        `CustomThreeLinkReacher` task with target size `_BIG_TARGET`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets(links=3)),
        CustomThreeLinkReacher(target_size=_BIG_TARGET, random=random),
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
@reacher.SUITE.add('custom')
def three_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the three-link reacher environment with the small target.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` using the three-link model and a
        `CustomThreeLinkReacher` task with target size `_SMALL_TARGET`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets(links=3)),
        CustomThreeLinkReacher(target_size=_SMALL_TARGET, random=random),
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
@reacher.SUITE.add('custom')
def four_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the four-link reacher environment with the big target.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` using the four-link model and a
        `CustomThreeLinkReacher` task (the same task class serves both arm
        lengths) with target size `_BIG_TARGET`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets(links=4)),
        CustomThreeLinkReacher(target_size=_BIG_TARGET, random=random),
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
@reacher.SUITE.add('custom')
def four_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the four-link reacher environment with the small target.

    Args:
        time_limit: Episode time limit forwarded to `control.Environment`.
        random: Forwarded unchanged to the task.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value means "no extras".

    Returns:
        A `control.Environment` using the four-link model and a
        `CustomThreeLinkReacher` task (the same task class serves both arm
        lengths) with target size `_SMALL_TARGET`.
    """
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        Physics.from_xml_string(*get_model_and_assets(links=4)),
        CustomThreeLinkReacher(target_size=_SMALL_TARGET, random=random),
        time_limit=time_limit,
        **extra_kwargs)
|
||||
|
||||
|
||||
class Physics(mujoco.Physics):
    """MuJoCo physics with finger/target helpers for the custom reacher tasks."""

    def finger_to_target(self):
        """Returns the x-y vector from the finger to the target.

        Computed as the 'target' geom position minus the 'finger' geom
        position, using only the first two global coordinates.
        """
        geom_xpos = self.named.data.geom_xpos
        target_xy = geom_xpos['target', :2]
        finger_xy = geom_xpos['finger', :2]
        return target_xy - finger_xy

    def finger_to_target_dist(self):
        """Returns the Euclidean norm of `finger_to_target()`.

        This is the unsigned distance between the two geom positions; geom
        sizes are not subtracted.
        """
        offset = self.finger_to_target()
        return np.linalg.norm(offset)
|
||||
|
||||
|
||||
class CustomThreeLinkReacher(reacher.Reacher):
    """Reacher task used by the custom multi-link reacher environments."""

    def __init__(self, target_size, random=None):
        """Forwards `target_size` and `random` positionally to the parent task."""
        super().__init__(target_size, random)

    def get_observation(self, physics):
        """Returns the observation as an ordered dict.

        Keys, in order: 'position' (`physics.position()`), 'to_target'
        (`physics.finger_to_target()`) and 'velocity' (`physics.velocity()`).
        """
        return collections.OrderedDict([
            ('position', physics.position()),
            ('to_target', physics.finger_to_target()),
            ('velocity', physics.velocity()),
        ])
|
||||
57
tdmpc2/envs/tasks/reacher_four_links.xml
Normal file
57
tdmpc2/envs/tasks/reacher_four_links.xml
Normal file
@@ -0,0 +1,57 @@
|
||||
<mujoco model="two-link planar reacher">
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<option timestep="0.02">
|
||||
<flag contact="disable"/>
|
||||
</option>
|
||||
|
||||
<default>
|
||||
<joint type="hinge" axis="0 0 1" damping="0.01"/>
|
||||
<motor gear=".05" ctrlrange="-1 1" ctrllimited="true"/>
|
||||
</default>
|
||||
|
||||
<worldbody>
|
||||
<light name="light" pos="0 0 1"/>
|
||||
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
|
||||
<!-- Arena -->
|
||||
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 10" material="grid"/>
|
||||
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
|
||||
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
|
||||
|
||||
<!-- Arm -->
|
||||
<geom name="root" type="cylinder" fromto="0 0 0 0 0 0.02" size=".011" material="decoration"/>
|
||||
<body name="arm0" pos="0 0 .01">
|
||||
<geom name="arm0" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
|
||||
<joint name="shoulder0"/>
|
||||
<body name="arm1" pos=".06 0 0">
|
||||
<geom name="arm1" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
|
||||
<joint name="shoulder1" limited="true" range="-80 80"/>
|
||||
<body name="arm2" pos=".06 0 0">
|
||||
<geom name="arm2" type="capsule" fromto="0 0 0 0.06 0 0" size=".01" material="self"/>
|
||||
<joint name="shoulder2" limited="true" range="-80 80"/>
|
||||
<body name="hand" pos=".06 0 0">
|
||||
<geom name="hand" type="capsule" fromto="0 0 0 0.1 0 0" size=".01" material="self"/>
|
||||
<joint name="wrist" limited="true" range="-80 80"/>
|
||||
<body name="finger" pos=".06 0 0">
|
||||
<camera name="hand" pos="0 0 .2" mode="track"/>
|
||||
<geom name="finger" type="sphere" size=".01" material="effector"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<!-- Target -->
|
||||
<geom name="target" pos="0 0 .01" material="target" type="sphere" size=".05"/>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<motor name="shoulder0" joint="shoulder0"/>
|
||||
<motor name="shoulder1" joint="shoulder1"/>
|
||||
<motor name="shoulder2" joint="shoulder2"/>
|
||||
<motor name="wrist" joint="wrist"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
52
tdmpc2/envs/tasks/reacher_three_links.xml
Normal file
52
tdmpc2/envs/tasks/reacher_three_links.xml
Normal file
@@ -0,0 +1,52 @@
|
||||
<mujoco model="two-link planar reacher">
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<option timestep="0.02">
|
||||
<flag contact="disable"/>
|
||||
</option>
|
||||
|
||||
<default>
|
||||
<joint type="hinge" axis="0 0 1" damping="0.01"/>
|
||||
<motor gear=".05" ctrlrange="-1 1" ctrllimited="true"/>
|
||||
</default>
|
||||
|
||||
<worldbody>
|
||||
<light name="light" pos="0 0 1"/>
|
||||
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
|
||||
<!-- Arena -->
|
||||
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 10" material="grid"/>
|
||||
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
|
||||
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
|
||||
|
||||
<!-- Arm -->
|
||||
<geom name="root" type="cylinder" fromto="0 0 0 0 0 0.02" size=".011" material="decoration"/>
|
||||
<body name="arm0" pos="0 0 .01">
|
||||
<geom name="arm0" type="capsule" fromto="0 0 0 0.09 0 0" size=".01" material="self"/>
|
||||
<joint name="shoulder0"/>
|
||||
<body name="arm1" pos=".09 0 0">
|
||||
<geom name="arm1" type="capsule" fromto="0 0 0 0.09 0 0" size=".01" material="self"/>
|
||||
<joint name="shoulder1" limited="true" range="-80 80"/>
|
||||
<body name="hand" pos=".09 0 0">
|
||||
<geom name="hand" type="capsule" fromto="0 0 0 0.1 0 0" size=".01" material="self"/>
|
||||
<joint name="wrist" limited="true" range="-80 80"/>
|
||||
<body name="finger" pos=".09 0 0">
|
||||
<camera name="hand" pos="0 0 .2" mode="track"/>
|
||||
<geom name="finger" type="sphere" size=".01" material="effector"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<!-- Target -->
|
||||
<geom name="target" pos="0 0 .01" material="target" type="sphere" size=".05"/>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<motor name="shoulder0" joint="shoulder0"/>
|
||||
<motor name="shoulder1" joint="shoulder1"/>
|
||||
<motor name="wrist" joint="wrist"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
223
tdmpc2/envs/tasks/walker.py
Normal file
223
tdmpc2/envs/tasks/walker.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import os
|
||||
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import common
|
||||
from dm_control.suite import walker
|
||||
from dm_control.utils import rewards
|
||||
from dm_control.utils import io as resources
|
||||
|
||||
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')
|
||||
|
||||
_YOGA_STAND_HEIGHT = 1.0
|
||||
_YOGA_LIE_DOWN_HEIGHT = 0.08
|
||||
_YOGA_LEGS_UP_HEIGHT = 1.1
|
||||
|
||||
|
||||
def get_model_and_assets():
    """Returns ``(xml, assets)`` for the custom walker model.

    The XML is read from ``walker.xml`` inside ``_TASKS_DIR``; the assets are
    ``common.ASSETS``.
    """
    xml_path = os.path.join(_TASKS_DIR, 'walker.xml')
    model_xml = resources.GetResource(xml_path)
    return model_xml, common.ASSETS
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def walk_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Walk Backwards task.

    Pairs the local walker model with a ``BackwardsPlanarWalker`` whose
    ``move_speed`` is ``walker._WALK_SPEED``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    walker_task = BackwardsPlanarWalker(move_speed=walker._WALK_SPEED, random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        walker_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def run_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Run Backwards task.

    Pairs the local walker model with a ``BackwardsPlanarWalker`` whose
    ``move_speed`` is ``walker._RUN_SPEED``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    walker_task = BackwardsPlanarWalker(move_speed=walker._RUN_SPEED, random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        walker_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def arabesque(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Arabesque task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='arabesque'``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='arabesque', random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def lie_down(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Lie Down task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='lie_down'``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='lie_down', random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def legs_up(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Legs Up task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='legs_up'``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='legs_up', random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def headstand(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Headstand task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='flip'`` and ``move_speed=0``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='flip', move_speed=0, random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def flip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Flip task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='flip'`` and a positive ``move_speed`` of ``walker._RUN_SPEED*0.75``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='flip', move_speed=walker._RUN_SPEED*0.75, random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
@walker.SUITE.add('custom')
def backflip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Backflip task.

    Pairs the local walker model with a ``YogaPlanarWalker`` configured with
    ``goal='flip'`` and a negative ``move_speed`` of ``-walker._RUN_SPEED*0.75``.

    Args:
        time_limit: Episode time limit forwarded to ``control.Environment``.
        random: Forwarded unchanged to the task constructor.
        environment_kwargs: Optional dict of extra ``control.Environment``
            keyword arguments; a falsy value means "no extras".
    """
    model_and_assets = get_model_and_assets()
    sim = walker.Physics.from_xml_string(*model_and_assets)
    yoga_task = YogaPlanarWalker(goal='flip', move_speed=-walker._RUN_SPEED*0.75, random=random)
    extra_kwargs = environment_kwargs or {}
    return control.Environment(
        sim,
        yoga_task,
        time_limit=time_limit,
        control_timestep=walker._CONTROL_TIMESTEP,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
class BackwardsPlanarWalker(walker.PlanarWalker):
    """PlanarWalker task rewarding horizontal velocity at or below ``-move_speed``."""

    def __init__(self, move_speed, random=None):
        """Forwards ``move_speed`` and ``random`` to ``walker.PlanarWalker``."""
        super().__init__(move_speed, random)

    def get_reward(self, physics):
        """Returns the standing reward, scaled by a movement term when moving.

        The standing reward mixes a torso-height tolerance (weight 3) with an
        upright term (weight 1). When ``self._move_speed`` is 0 it is returned
        as is; otherwise it is multiplied by ``(5*move_reward + 1) / 6``.
        """
        height_term = rewards.tolerance(
            physics.torso_height(),
            bounds=(walker._STAND_HEIGHT, float('inf')),
            margin=walker._STAND_HEIGHT/2)
        upright_term = (1 + physics.torso_upright()) / 2
        stand_reward = (3*height_term + upright_term) / 4
        if self._move_speed == 0:
            return stand_reward

        speed = self._move_speed
        # Full reward once the horizontal velocity is at most -speed.
        move_reward = rewards.tolerance(
            physics.horizontal_velocity(),
            bounds=(-float('inf'), -speed),
            margin=speed/2,
            value_at_margin=0.5,
            sigmoid='linear')
        return stand_reward * (5*move_reward + 1) / 6
|
||||
|
||||
|
||||
class YogaPlanarWalker(walker.PlanarWalker):
    """Yoga PlanarWalker tasks.

    The reward is selected by ``goal``, one of 'arabesque', 'lie_down',
    'legs_up' or 'flip'. ``move_speed`` is only read by the 'flip' reward.
    """

    def __init__(self, goal='arabesque', move_speed=0, random=None):
        """
        Args:
            goal: Name of the pose to reward; validated lazily in ``get_reward``.
            move_speed: Signed target horizontal speed used by the 'flip'
                reward; 0 disables the movement term.
            random: Forwarded to ``walker.PlanarWalker``.
        """
        # The parent is always initialised with speed 0; the requested speed
        # is then stored on the instance.
        super().__init__(0, random)
        self._goal = goal
        self._move_speed = move_speed

    @staticmethod
    def _mean_body_height(physics, first, second):
        """Returns the mean z position of the two named bodies."""
        xpos = physics.named.data.xpos
        return (xpos[first, 'z'] + xpos[second, 'z']) / 2

    def _arabesque_reward(self, physics):
        """Rewards a high torso, left foot low, right foot high, plus the inverted upright term."""
        standing = rewards.tolerance(physics.torso_height(),
                                     bounds=(_YOGA_STAND_HEIGHT, float('inf')),
                                     margin=_YOGA_STAND_HEIGHT/2)
        left_foot_height = physics.named.data.xpos['left_foot', 'z']
        right_foot_height = physics.named.data.xpos['right_foot', 'z']
        left_foot_down = rewards.tolerance(left_foot_height,
                                           bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
                                           margin=_YOGA_STAND_HEIGHT/2)
        right_foot_up = rewards.tolerance(right_foot_height,
                                          bounds=(_YOGA_STAND_HEIGHT, float('inf')),
                                          margin=_YOGA_STAND_HEIGHT/2)
        upright = (1 - physics.torso_upright()) / 2
        return (3*standing + left_foot_down + right_foot_up + upright) / 6

    def _lie_down_reward(self, physics):
        """Rewards a low torso and low thighs, plus the inverted upright term.

        Foot height is not part of this reward: the previous implementation
        computed a feet term but never used it, so that dead computation was
        dropped without changing the returned value.
        """
        torso_down = rewards.tolerance(physics.torso_height(),
                                       bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
                                       margin=_YOGA_LIE_DOWN_HEIGHT/2)
        thigh_height = self._mean_body_height(physics, 'left_thigh', 'right_thigh')
        thigh_down = rewards.tolerance(thigh_height,
                                       bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
                                       margin=_YOGA_LIE_DOWN_HEIGHT/2)
        upright = (1 - physics.torso_upright()) / 2
        return (3*torso_down + thigh_down + upright) / 5

    def _legs_up_reward(self, physics):
        """Rewards a low torso and thighs with the feet raised, plus the inverted upright term."""
        torso_down = rewards.tolerance(physics.torso_height(),
                                       bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
                                       margin=_YOGA_LIE_DOWN_HEIGHT/2)
        thigh_height = self._mean_body_height(physics, 'left_thigh', 'right_thigh')
        thigh_down = rewards.tolerance(thigh_height,
                                       bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
                                       margin=_YOGA_LIE_DOWN_HEIGHT/2)
        feet_height = self._mean_body_height(physics, 'left_foot', 'right_foot')
        legs_up = rewards.tolerance(feet_height,
                                    bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
                                    margin=_YOGA_LEGS_UP_HEIGHT/2)
        upright = (1 - physics.torso_upright()) / 2
        return (3*torso_down + 2*legs_up + thigh_down + upright) / 7

    def _flip_reward(self, physics):
        """Rewards raised thighs and feet, optionally scaled by a signed speed term.

        With ``self._move_speed == 0`` only the upside-down term is returned.
        Otherwise it is multiplied by ``(5*move_reward + 1) / 6``, where
        ``move_reward`` favours horizontal velocity beyond ``self._move_speed``
        in the direction of its sign.
        """
        thigh_height = self._mean_body_height(physics, 'left_thigh', 'right_thigh')
        thigh_up = rewards.tolerance(thigh_height,
                                     bounds=(_YOGA_STAND_HEIGHT, float('inf')),
                                     margin=_YOGA_STAND_HEIGHT/2)
        feet_height = self._mean_body_height(physics, 'left_foot', 'right_foot')
        legs_up = rewards.tolerance(feet_height,
                                    bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
                                    margin=_YOGA_LEGS_UP_HEIGHT/2)
        upside_down_reward = (3*legs_up + 2*thigh_up) / 5
        if self._move_speed == 0:
            return upside_down_reward

        # The sign of the target speed decides which side of it is rewarded.
        if self._move_speed > 0:
            speed_bounds = (self._move_speed, float('inf'))
        else:
            speed_bounds = (-float('inf'), self._move_speed)
        move_reward = rewards.tolerance(physics.horizontal_velocity(),
                                        bounds=speed_bounds,
                                        margin=abs(self._move_speed)/2,
                                        value_at_margin=0.5,
                                        sigmoid='linear')
        return upside_down_reward * (5*move_reward + 1) / 6

    def get_reward(self, physics):
        """Returns the reward for the configured goal.

        Raises:
            NotImplementedError: If ``self._goal`` is not a known goal name.
        """
        if self._goal == 'arabesque':
            return self._arabesque_reward(physics)
        if self._goal == 'lie_down':
            return self._lie_down_reward(physics)
        if self._goal == 'legs_up':
            return self._legs_up_reward(physics)
        if self._goal == 'flip':
            return self._flip_reward(physics)
        raise NotImplementedError(f'Goal {self._goal} is not implemented.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the Legs Up environment and take one zero-action step.
    import numpy as np

    env = legs_up()
    time_step = env.reset()
    # dm_control environments return a dm_env.TimeStep namedtuple whose fields
    # are (step_type, reward, discount, observation); there is no gym-style
    # (obs, reward, done, info) tuple here. The walker model has 6 actuators.
    step_type, reward, discount, next_obs = env.step(np.zeros(6))
|
||||
70
tdmpc2/envs/tasks/walker.xml
Normal file
70
tdmpc2/envs/tasks/walker.xml
Normal file
@@ -0,0 +1,70 @@
|
||||
<mujoco model="planar walker">
|
||||
<include file="./common/visual.xml"/>
|
||||
<include file="./common/skybox.xml"/>
|
||||
<include file="./common/materials.xml"/>
|
||||
|
||||
<option timestep="0.0025"/>
|
||||
|
||||
<statistic extent="2" center="0 0 1"/>
|
||||
|
||||
<default>
|
||||
<joint damping=".1" armature="0.01" limited="true" solimplimit="0 .99 .01"/>
|
||||
<geom contype="1" conaffinity="0" friction=".7 .1 .1"/>
|
||||
<motor ctrlrange="-1 1" ctrllimited="true"/>
|
||||
<site size="0.01"/>
|
||||
<default class="walker">
|
||||
<geom material="self" type="capsule"/>
|
||||
<joint axis="0 -1 0"/>
|
||||
</default>
|
||||
</default>
|
||||
|
||||
<worldbody>
|
||||
<geom name="floor" type="plane" conaffinity="1" pos="248 0 0" size="500 .8 .2" material="grid" zaxis="0 0 1"/>
|
||||
<body name="torso" pos="0 0 1.3" childclass="walker">
|
||||
<light name="light" pos="0 0 2" mode="trackcom"/>
|
||||
<camera name="side" pos="0 -2 .7" euler="60 0 0" mode="trackcom"/>
|
||||
<camera name="back" pos="-2 0 .5" xyaxes="0 -1 0 1 0 3" mode="trackcom"/>
|
||||
<joint name="rootz" axis="0 0 1" type="slide" limited="false" armature="0" damping="0"/>
|
||||
<joint name="rootx" axis="1 0 0" type="slide" limited="false" armature="0" damping="0"/>
|
||||
<joint name="rooty" axis="0 1 0" type="hinge" limited="false" armature="0" damping="0"/>
|
||||
<geom name="torso" size="0.07 0.3"/>
|
||||
<body name="right_thigh" pos="0 -.05 -0.3">
|
||||
<joint name="right_hip" range="-20 100"/>
|
||||
<geom name="right_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
|
||||
<body name="right_leg" pos="0 0 -0.7">
|
||||
<joint name="right_knee" pos="0 0 0.25" range="-150 0"/>
|
||||
<geom name="right_leg" size="0.04 0.25"/>
|
||||
<body name="right_foot" pos="0.06 0 -0.25">
|
||||
<joint name="right_ankle" pos="-0.06 0 0" range="-45 45"/>
|
||||
<geom name="right_foot" zaxis="1 0 0" size="0.05 0.1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="left_thigh" pos="0 .05 -0.3" >
|
||||
<joint name="left_hip" range="-20 100"/>
|
||||
<geom name="left_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
|
||||
<body name="left_leg" pos="0 0 -0.7">
|
||||
<joint name="left_knee" pos="0 0 0.25" range="-150 0"/>
|
||||
<geom name="left_leg" size="0.04 0.25"/>
|
||||
<body name="left_foot" pos="0.06 0 -0.25">
|
||||
<joint name="left_ankle" pos="-0.06 0 0" range="-45 45"/>
|
||||
<geom name="left_foot" zaxis="1 0 0" size="0.05 0.1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<sensor>
|
||||
<subtreelinvel name="torso_subtreelinvel" body="torso"/>
|
||||
</sensor>
|
||||
|
||||
<actuator>
|
||||
<motor name="right_hip" joint="right_hip" gear="100"/>
|
||||
<motor name="right_knee" joint="right_knee" gear="50"/>
|
||||
<motor name="right_ankle" joint="right_ankle" gear="20"/>
|
||||
<motor name="left_hip" joint="left_hip" gear="100"/>
|
||||
<motor name="left_knee" joint="left_knee" gear="50"/>
|
||||
<motor name="left_ankle" joint="left_ankle" gear="20"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
0
tdmpc2/envs/wrappers/__init__.py
Normal file
0
tdmpc2/envs/wrappers/__init__.py
Normal file
57
tdmpc2/envs/wrappers/multitask.py
Normal file
57
tdmpc2/envs/wrappers/multitask.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class MultitaskWrapper(gym.Wrapper):
    """
    Wrapper that exposes a list of per-task environments as one environment.

    Observations are zero-padded up to the largest observation size among the
    wrapped environments, and actions are truncated to the active
    environment's action dimension before stepping.
    """

    def __init__(self, cfg, envs):
        """
        Args:
            cfg: Config object; ``cfg.tasks`` is indexed by task index.
            envs: Sequence of environments, one per task. The first one is
                passed to ``gym.Wrapper`` as the initially wrapped env.
        """
        super().__init__(envs[0])
        self.cfg = cfg
        self.envs = envs
        self._task_idx = 0
        self._task = cfg.tasks[0]
        self._obs_dims = [e.observation_space.shape[0] for e in self.envs]
        self._action_dims = [e.action_space.shape[0] for e in self.envs]
        self._episode_lengths = [e.max_episode_steps for e in self.envs]
        # The shared spaces are sized for the largest task.
        self._obs_shape = (max(self._obs_dims),)
        self._action_dim = max(self._action_dims)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=self._obs_shape, dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=(self._action_dim,), dtype=np.float32
        )

    @property
    def task(self):
        """Entry of ``cfg.tasks`` for the currently selected task."""
        return self._task

    @property
    def task_idx(self):
        """Index of the currently selected task."""
        return self._task_idx

    @property
    def _env(self):
        """Environment at the currently selected task index."""
        return self.envs[self.task_idx]

    def rand_act(self):
        """Returns a float32 tensor sampled from the shared action space."""
        sample = self.action_space.sample()
        return torch.from_numpy(sample.astype(np.float32))

    def _pad_obs(self, obs):
        """Zero-pads ``obs`` at the end when its shape differs from the shared shape."""
        if obs.shape == self._obs_shape:
            return obs
        padding = torch.zeros(self._obs_shape[0]-obs.shape[0], dtype=obs.dtype, device=obs.device)
        return torch.cat((obs, padding))

    def reset(self, task_idx=-1):
        """Selects task ``task_idx``, resets its environment, returns the padded observation.

        The default of -1 indexes the last task and environment.
        """
        self._task_idx = task_idx
        self._task = self.cfg.tasks[task_idx]
        self.env = self._env
        first_obs = self.env.reset()
        return self._pad_obs(first_obs)

    def step(self, action):
        """Steps the active environment with ``action`` cut to its action dimension."""
        active_dim = self.env.action_space.shape[0]
        obs, reward, done, info = self.env.step(action[:active_dim])
        return self._pad_obs(obs), reward, done, info
|
||||
40
tdmpc2/envs/wrappers/tensor.py
Normal file
40
tdmpc2/envs/wrappers/tensor.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class TensorWrapper(gym.Wrapper):
    """
    Wrapper that converts numpy observations and rewards to torch tensors.
    """

    def __init__(self, env):
        super().__init__(env)

    def rand_act(self):
        """Returns a float32 tensor sampled from the action space."""
        sample = self.action_space.sample()
        return torch.from_numpy(sample.astype(np.float32))

    def _try_f32_tensor(self, x):
        """Wraps ``x`` as a tensor, casting float64 to float32; other dtypes are kept."""
        tensor = torch.from_numpy(x)
        return tensor.float() if tensor.dtype == torch.float64 else tensor

    def _obs_to_tensor(self, obs):
        """Converts an array observation, or each value of a dict observation (in place)."""
        if not isinstance(obs, dict):
            return self._try_f32_tensor(obs)
        for key in obs.keys():
            obs[key] = self._try_f32_tensor(obs[key])
        return obs

    def reset(self, task_idx=None):
        """Resets the wrapped environment; ``task_idx`` is accepted but not used."""
        return self._obs_to_tensor(self.env.reset())

    def step(self, action):
        """Steps with ``action.numpy()`` and returns tensor obs and float32 reward.

        ``info`` is replaced by a ``defaultdict(float)`` copy, so a missing
        'success' entry becomes 0.0.
        """
        obs, reward, done, info = self.env.step(action.numpy())
        info = defaultdict(float, info)
        info['success'] = float(info['success'])
        reward_tensor = torch.tensor(reward, dtype=torch.float32)
        return self._obs_to_tensor(obs), reward_tensor, done, info
|
||||
72
tdmpc2/envs/wrappers/time_limit.py
Normal file
72
tdmpc2/envs/wrappers/time_limit.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Wrapper for limiting the time steps of an environment.
|
||||
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
|
||||
|
||||
class TimeLimit(gym.Wrapper):
    """Issues a `done` signal once a maximum number of timesteps is reached.

    A `done` produced by this wrapper is a *truncation*; a `done` coming from
    the underlying environment is a *termination*. The two can be told apart
    through ``info``: the signal originates from the time limit if and only if
    the key `"TimeLimit.truncated"` exists in ``info`` and its value is ``True``.

    Example:
       >>> from gym.envs.classic_control import CartPoleEnv
       >>> from gym.wrappers import TimeLimit
       >>> env = CartPoleEnv()
       >>> env = TimeLimit(env, max_episode_steps=1000)
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
        """Stores the step limit and, when the env has a spec, writes it there too.

        Args:
            env: The environment to apply the wrapper
            max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
        """
        super().__init__(env)
        spec = self.env.spec
        if spec is not None:
            if max_episode_steps is None:
                max_episode_steps = spec.max_episode_steps
            spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        # Stays None until reset() is called.
        self._elapsed_steps = None

    def step(self, action):
        """Steps the environment and forces ``done`` when the step limit is reached.

        Args:
            action: The environment step action

        Returns:
            The environment step ``(observation, reward, done, info)``. Once the
            number of elapsed steps is >= max episode steps, ``done`` is True and
            ``info["TimeLimit.truncated"]`` is set.
        """
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        limit_reached = self._elapsed_steps >= self._max_episode_steps
        if limit_reached:
            # Keep a "TimeLimit.truncated" flag the environment already set;
            # otherwise the episode counts as truncated only if it was not done.
            info["TimeLimit.truncated"] = not done or info.get("TimeLimit.truncated", False)
            done = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Resets the environment with ``**kwargs`` and zeroes the elapsed step count.

        Args:
            **kwargs: The kwargs to reset the environment with

        Returns:
            Whatever the wrapped environment's ``reset`` returns
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
|
||||
103
tdmpc2/evaluate.py
Executable file
103
tdmpc2/evaluate.py
Executable file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
os.environ['MUJOCO_GL'] = 'egl'
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
|
||||
from common.parser import parse_cfg
|
||||
from common.seed import set_seed
|
||||
from envs import make_env
|
||||
from tdmpc2 import TDMPC2
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
def evaluate(cfg: dict):
    """
    Script for evaluating a single-task / multi-task TD-MPC2 checkpoint.

    Most relevant args:
        `task`: task name (or mt30/mt80 for multi-task evaluation)
        `model_size`: model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
        `checkpoint`: path to model checkpoint to load
        `eval_episodes`: number of episodes to evaluate on per task (default: 10)
        `save_video`: whether to save a video of the evaluation (default: True)
        `seed`: random seed (default: 1)

    See config.yaml for a full list of args.

    Example usage:
    ```
        $ python evaluate.py task=mt80 model_size=48 checkpoint=/path/to/mt80-48M.pt
        $ python evaluate.py task=mt30 model_size=317 checkpoint=/path/to/mt30-317M.pt
        $ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
    ```
    """
    assert torch.cuda.is_available()
    assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
    cfg = parse_cfg(cfg)
    set_seed(cfg.seed)
    print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
    print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold']))
    print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
    # A multi-task checkpoint name combined with a single-task run only warns.
    if not cfg.multitask and any(tag in cfg.checkpoint for tag in ('mt80', 'mt30')):
        print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
        print(colored('To evaluate a multi-task model, use task=mt80 or task=mt30.', 'red', attrs=['bold']))

    # Make environment
    env = make_env(cfg)

    # Load agent
    agent = TDMPC2(cfg)
    assert os.path.exists(cfg.checkpoint), f'Checkpoint {cfg.checkpoint} not found! Must be a valid filepath.'
    agent.load(cfg.checkpoint)

    # Evaluate
    if cfg.multitask:
        print(colored(f'Evaluating agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
    else:
        print(colored(f'Evaluating agent on {cfg.task}:', 'yellow', attrs=['bold']))
    if cfg.save_video:
        video_dir = os.path.join(cfg.work_dir, 'videos')
        os.makedirs(video_dir, exist_ok=True)
    scores = []
    tasks = cfg.tasks if cfg.multitask else [cfg.task]
    for position, task in enumerate(tasks):
        # Single-task runs pass no task index to the env or the agent.
        task_idx = position if cfg.multitask else None
        ep_rewards, ep_successes = [], []
        for i in range(cfg.eval_episodes):
            obs = env.reset(task_idx=task_idx)
            done, ep_reward, t = False, 0, 0
            if cfg.save_video:
                frames = [env.render()]
            while not done:
                action = agent.act(obs, t0=t==0, task=task_idx)
                obs, reward, done, info = env.step(action)
                ep_reward += reward
                t += 1
                if cfg.save_video:
                    frames.append(env.render())
            ep_rewards.append(ep_reward)
            # `info` is the one returned by the final step of the episode.
            ep_successes.append(info['success'])
            if cfg.save_video:
                video_path = os.path.join(video_dir, f'{task}-{i}.mp4')
                imageio.mimsave(video_path, frames, fps=15)
        ep_rewards = np.mean(ep_rewards)
        ep_successes = np.mean(ep_successes)
        if cfg.multitask:
            # Tasks prefixed 'mw-' score by success rate, the rest by reward / 10.
            scores.append(ep_successes*100 if task.startswith('mw-') else ep_rewards/10)
        print(colored(f' {task:<22}' \
            f'\tR: {ep_rewards:.01f} ' \
            f'\tS: {ep_successes:.02f}', 'yellow'))
    if cfg.multitask:
        print(colored(f'Normalized score: {np.mean(scores):.02f}', 'yellow', attrs=['bold']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point; called without arguments because the `hydra.main`
    # decorator on `evaluate` is what provides `cfg`.
    evaluate()
|
||||
286
tdmpc2/tdmpc2.py
Executable file
286
tdmpc2/tdmpc2.py
Executable file
@@ -0,0 +1,286 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from common import math
|
||||
from common.scale import RunningScale
|
||||
from common.world_model import WorldModel
|
||||
|
||||
|
||||
class TDMPC2:
|
||||
"""
|
||||
TD-MPC2 agent. Implements training + inference.
|
||||
Can be used for both single-task and multi-task experiments.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
    """
    Build the world model, its optimizers and the per-task discount(s).

    Args:
        cfg: Config object. Its ``iterations`` field is incremented in place
            by 2 when ``cfg.action_dim >= 20``.
    """
    self.cfg = cfg
    self.device = torch.device('cuda')
    self.model = WorldModel(cfg).to(self.device)
    base_lr = self.cfg.lr
    # The encoder gets its own scaled learning rate; the other groups use base_lr.
    param_groups = [
        {'params': self.model._encoder.parameters(), 'lr': base_lr*self.cfg.enc_lr_scale},
        {'params': self.model._dynamics.parameters()},
        {'params': self.model._reward.parameters()},
        {'params': self.model._Qs.parameters()},
        {'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
    ]
    self.optim = torch.optim.Adam(param_groups, lr=base_lr)
    self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=base_lr, eps=1e-5)
    self.model.eval()
    self.scale = RunningScale(cfg)
    self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
    if self.cfg.multitask:
        # One discount per task, taken from that task's episode length.
        per_task = [self._get_discount(ep_len) for ep_len in cfg.episode_lengths]
        self.discount = torch.tensor(per_task, device='cuda')
    else:
        self.discount = self._get_discount(cfg.episode_length)
|
||||
|
||||
def _get_discount(self, episode_length):
|
||||
"""
|
||||
Returns discount factor for a given episode length.
|
||||
Simple heuristic that scales discount linearly with episode length.
|
||||
Default values should work well for most tasks, but can be changed as needed.
|
||||
|
||||
Args:
|
||||
episode_length (int): Length of the episode. Assumes episodes are of fixed length.
|
||||
|
||||
Returns:
|
||||
float: Discount factor for the task.
|
||||
"""
|
||||
frac = episode_length/self.cfg.discount_denom
|
||||
return min(max((frac-1)/(frac), self.cfg.discount_min), self.cfg.discount_max)
|
||||
|
||||
def save(self, fp):
|
||||
"""
|
||||
Save state dict of the agent to filepath.
|
||||
|
||||
Args:
|
||||
fp (str): Filepath to save state dict to.
|
||||
"""
|
||||
torch.save({"model": self.model.state_dict()}, fp)
|
||||
|
||||
def load(self, fp):
    """
    Load a saved state dict from filepath (or dictionary) into current agent.

    Args:
        fp (str or dict): Filepath to a checkpoint written by `save`, or an
            already-loaded dictionary with a "model" entry.
    """
    # Map tensors onto this agent's device while deserialising, so a
    # checkpoint saved from a different device can still be restored.
    state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=self.device)
    self.model.load_state_dict(state_dict["model"])
|
||||
|
||||
@torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None):
    """
    Choose an action for a single observation by planning in latent space.

    The observation is moved to the agent's device and given a leading batch
    dimension, encoded by the world model, and handed to `plan`.

    Args:
        obs (torch.Tensor): Observation from the environment.
        t0 (bool): Whether this is the first observation in the episode.
        eval_mode (bool): Whether to use the mean of the action distribution.
        task (int): Task index (only used for multi-task experiments).

    Returns:
        torch.Tensor: Action to take in the environment (on the CPU).
    """
    # The task index becomes a one-element tensor; None is passed through.
    task_idx = None if task is None else torch.tensor([task], device=self.device)
    batched_obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
    latent = self.model.encode(batched_obs, task_idx)
    action = self.plan(latent, t0=t0, eval_mode=eval_mode, task=task_idx)
    return action.cpu()
|
||||
|
||||
@torch.no_grad()
def _estimate_value(self, z, actions, task):
    """
    Estimate value of a trajectory starting at latent state z and executing given actions.

    Rolls the world model forward for `cfg.horizon` steps, accumulating the
    discounted predicted rewards, then bootstraps with the averaged
    Q-value of the policy's action at the final latent state.

    Args:
        z (torch.Tensor): Starting latent state(s).
        actions (torch.Tensor): Action sequence, indexed by time step along
            the first dimension.
        task (torch.Tensor): Task index (only used for multi-task experiments).

    Returns:
        torch.Tensor: Estimated discounted return of each trajectory.
    """
    G, discount = 0, 1
    # Look the per-step discount up once. `task` is already a tensor on this
    # path, so it is used as an index directly (as in `_td_target`) rather
    # than being re-wrapped with `torch.tensor` on every step.
    step_discount = self.discount[task] if self.cfg.multitask else self.discount
    for t in range(self.cfg.horizon):
        reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
        z = self.model.next(z, actions[t], task)
        G += discount * reward
        discount *= step_discount
    # Terminal value: Q of the policy action at the last latent state.
    return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):
    """
    Plan a sequence of actions using the learned world model.

    Runs `cfg.iterations` rounds of sampling-based optimisation (MPPI):
    candidate action sequences are drawn from a per-step Gaussian, optionally
    mixed with `cfg.num_pi_trajs` sequences proposed by the policy, scored
    with `_estimate_value`, and the Gaussian is refit to the
    `cfg.num_elites` best candidates. The resulting mean is stored in
    `self._prev_mean` to warm-start the next call.

    Args:
        z (torch.Tensor): Latent state from which to plan.
        t0 (bool): Whether this is the first observation in the episode.
        eval_mode (bool): Whether to use the mean of the action distribution.
        task (torch.Tensor): Task index (only used for multi-task experiments).

    Returns:
        torch.Tensor: Action to take in the environment.
    """
    # Sample policy trajectories
    if self.cfg.num_pi_trajs > 0:
        pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
        _z = z.repeat(self.cfg.num_pi_trajs, 1)
        for t in range(self.cfg.horizon-1):
            pi_actions[t] = self.model.pi(_z, task)[1]
            _z = self.model.next(_z, pi_actions[t], task)
        pi_actions[-1] = self.model.pi(_z, task)[1]

    # Initialize state and parameters
    z = z.repeat(self.cfg.num_samples, 1)
    mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
    std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device)
    # Warm start: shift the previous solution one step forward in time.
    # The very first planning call may arrive with t0=False (e.g. when
    # planning starts in the middle of an episode); with no previous
    # solution available, fall back to the zero-mean initialisation.
    prev_mean = getattr(self, '_prev_mean', None)
    if not t0 and prev_mean is not None:
        mean[:-1] = prev_mean[1:]
    actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
    if self.cfg.num_pi_trajs > 0:
        actions[:, :self.cfg.num_pi_trajs] = pi_actions

    # Iterate MPPI
    for i in range(self.cfg.iterations):

        # Sample actions (the policy trajectories in the first slots are kept)
        actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
            torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
            .clamp(-1, 1)
        if self.cfg.multitask:
            actions = actions * self.model._action_masks[task]

        # Compute elite actions
        value = self._estimate_value(z, actions, task).nan_to_num_(0)
        elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
        elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

        # Update parameters
        # Subtracting the max before exponentiating keeps the weights finite.
        max_value = elite_value.max(0)[0]
        score = torch.exp(self.cfg.temperature*(elite_value - max_value))
        score /= score.sum(0)
        mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
        std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \
            .clamp_(self.cfg.min_std, self.cfg.max_std)
        if self.cfg.multitask:
            mean = mean * self.model._action_masks[task]
            std = std * self.model._action_masks[task]

    # Select action: draw one elite sequence with probability given by its score
    score = score.squeeze(1).cpu().numpy()
    actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
    self._prev_mean = mean
    a, std = actions[0], std[0]
    if not eval_mode:
        # Exploration noise on the first action only
        a += std * torch.randn(self.cfg.action_dim, device=std.device)
    return a.clamp_(-1, 1)
|
||||
|
||||
def update_pi(self, zs, task):
    """
    Update policy using a sequence of latent states.

    The policy is trained to maximise the (scaled) averaged Q-value of its
    own actions plus an entropy bonus. Gradient tracking for the Q-functions
    is switched off for the duration of the update and switched back on
    afterwards, even if the update raises.

    Args:
        zs (torch.Tensor): Sequence of latent states.
        task (torch.Tensor): Task index (only used for multi-task experiments).

    Returns:
        float: Loss of the policy update.
    """
    self.pi_optim.zero_grad(set_to_none=True)
    self.model.track_q_grad(False)
    try:
        _, pis, log_pis, _ = self.model.pi(zs, task)
        qs = self.model.Q(zs, pis, task, return_type='avg')
        # The running scale is updated from the first time step only.
        self.scale.update(qs[0])
        qs = self.scale(qs)

        # Loss is a weighted sum of Q-values
        rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
        pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
        pi_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
        self.pi_optim.step()
        # Drop the policy gradients once they have been applied. Otherwise
        # they are still attached to the policy parameters when `update`
        # next clips the gradient norm over all model parameters.
        self.pi_optim.zero_grad(set_to_none=True)
    finally:
        self.model.track_q_grad(True)

    return pi_loss.item()
|
||||
|
||||
@torch.no_grad()
def _td_target(self, next_z, reward, task):
    """
    Build the one-step TD target: reward plus discounted bootstrap value.

    The bootstrap value is the minimum over target Q-functions, evaluated at
    the policy's action for the next latent state.

    Args:
        next_z (torch.Tensor): Latent state at the following time step.
        reward (torch.Tensor): Reward at the current time step.
        task (torch.Tensor): Task index (only used for multi-task experiments).

    Returns:
        torch.Tensor: TD-target.
    """
    next_action = self.model.pi(next_z, task)[1]
    if self.cfg.multitask:
        # Per-task discounts, with a trailing dim added for broadcasting
        gamma = self.discount[task].unsqueeze(-1)
    else:
        gamma = self.discount
    bootstrap = self.model.Q(next_z, next_action, task, return_type='min', target=True)
    return reward + gamma * bootstrap
|
||||
|
||||
def update(self, buffer):
    """
    Run one training iteration on a batch sampled from the replay buffer.

    The world model (encoder, dynamics, reward and Q-functions) is trained on
    a weighted sum of latent-consistency, reward and value losses; the policy
    is then updated on the detached latent rollout and the target
    Q-functions are soft-updated.

    Args:
        buffer (common.buffer.Buffer): Replay buffer.

    Returns:
        dict: Dictionary of training statistics.
    """
    obs, action, reward, task = buffer.sample()
    horizon = self.cfg.horizon
    # Temporal weights rho**t shared by all three model losses
    step_weights = [self.cfg.rho**t for t in range(horizon)]

    # Targets: computed without gradients, before the model enters train mode
    with torch.no_grad():
        next_z = self.model.encode(obs[1:], task)
        td_targets = self._td_target(next_z, reward, task)

    self.optim.zero_grad(set_to_none=True)
    self.model.train()

    # Roll the dynamics model forward from the first encoded observation
    zs = torch.empty(horizon+1, self.cfg.batch_size, self.cfg.latent_dim, device=self.device)
    z = self.model.encode(obs[0], task)
    zs[0] = z
    consistency_loss = 0
    for t, weight in enumerate(step_weights):
        z = self.model.next(z, action[t], task)
        consistency_loss += F.mse_loss(z, next_z[t]) * weight
        zs[t+1] = z

    # Reward and Q predictions for every state the rollout acted from
    rollout_inputs = zs[:-1]
    qs = self.model.Q(rollout_inputs, action, task, return_type='all')
    reward_preds = self.model.reward(rollout_inputs, action, task)

    # Accumulate the weighted reward and value losses
    reward_loss, value_loss = 0, 0
    for t, weight in enumerate(step_weights):
        reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * weight
        for q in range(self.cfg.num_q):
            value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * weight

    # Normalise by the number of accumulated terms
    consistency_loss *= (1/horizon)
    reward_loss *= (1/horizon)
    value_loss *= (1/(horizon * self.cfg.num_q))
    total_loss = (
        self.cfg.consistency_coef * consistency_loss +
        self.cfg.reward_coef * reward_loss +
        self.cfg.value_coef * value_loss
    )

    # World-model step
    total_loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
    self.optim.step()

    # Policy step on the detached rollout, then target-network update
    pi_loss = self.update_pi(zs.detach(), task)
    self.model.soft_update_target_Q()

    # Back to eval mode before reporting statistics
    self.model.eval()
    stats = {
        "consistency_loss": float(consistency_loss.mean().item()),
        "reward_loss": float(reward_loss.mean().item()),
        "value_loss": float(value_loss.mean().item()),
        "pi_loss": pi_loss,
        "total_loss": float(total_loss.mean().item()),
        "grad_norm": float(grad_norm),
        "pi_scale": float(self.scale.value),
    }
    return stats
|
||||
61
tdmpc2/train.py
Executable file
61
tdmpc2/train.py
Executable file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
os.environ['MUJOCO_GL'] = 'egl'
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
import torch
|
||||
|
||||
import hydra
|
||||
from termcolor import colored
|
||||
|
||||
from common.parser import parse_cfg
|
||||
from common.seed import set_seed
|
||||
from common.buffer import Buffer
|
||||
from envs import make_env
|
||||
from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
from trainer.online_trainer import OnlineTrainer
|
||||
from common.logger import Logger
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
def train(cfg: dict):
    """
    Entry point for training single-task / multi-task TD-MPC2 agents.

    Multi-task configurations are trained with `OfflineTrainer`, everything
    else with `OnlineTrainer`.

    Most relevant args:
        `task`: task name (or mt30/mt80 for multi-task training)
        `model_size`: model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
        `steps`: number of training/environment steps (default: 10M)
        `seed`: random seed (default: 1)

    See config.yaml for a full list of args.

    Example usage:
    ```
        $ python train.py task=mt80 model_size=48
        $ python train.py task=mt30 model_size=317
        $ python train.py task=dog-run steps=7000000
    ```
    """
    assert torch.cuda.is_available()
    assert cfg.steps > 0, 'Must train for at least 1 step.'
    cfg = parse_cfg(cfg)
    set_seed(cfg.seed)
    print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)

    # Components are constructed in this order: env, agent, buffer, logger.
    components = dict(
        cfg=cfg,
        env=make_env(cfg),
        agent=TDMPC2(cfg),
        buffer=Buffer(cfg),
        logger=Logger(cfg),
    )
    if cfg.multitask:
        trainer = OfflineTrainer(**components)
    else:
        trainer = OnlineTrainer(**components)
    trainer.train()
    print('\nTraining completed successfully')


if __name__ == '__main__':
    train()
|
||||
0
tdmpc2/trainer/__init__.py
Normal file
0
tdmpc2/trainer/__init__.py
Normal file
19
tdmpc2/trainer/base.py
Executable file
19
tdmpc2/trainer/base.py
Executable file
@@ -0,0 +1,19 @@
|
||||
class Trainer:
    """Base trainer class for TD-MPC2."""

    def __init__(self, cfg, env, agent, buffer, logger):
        """Store the training components and print a summary of the agent's model."""
        self.cfg, self.env, self.agent = cfg, env, agent
        self.buffer, self.logger = buffer, logger
        model = self.agent.model
        print("Learnable parameters: {:,}".format(model.total_params))
        print('Architecture:', model)

    def eval(self):
        """Evaluate a TD-MPC2 agent. Must be implemented by subclasses."""
        raise NotImplementedError

    def train(self):
        """Train a TD-MPC2 agent. Must be implemented by subclasses."""
        raise NotImplementedError
|
||||
92
tdmpc2/trainer/offline_trainer.py
Executable file
92
tdmpc2/trainer/offline_trainer.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from time import time
|
||||
from pathlib import Path
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from common.buffer import Buffer
|
||||
from trainer.base import Trainer
|
||||
|
||||
|
||||
class OfflineTrainer(Trainer):
    """Trainer class for multi-task offline TD-MPC2 training."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to `Trainer` and start the wall-clock timer."""
        super().__init__(*args, **kwargs)
        self._start_time = time()

    def eval(self):
        """
        Evaluate a TD-MPC2 agent on every task in `cfg.tasks`.

        Returns:
            dict: Mean episode reward and success rate per task, keyed by
                `episode_reward+<task>` and `episode_success+<task>`.
        """
        results = {}
        num_tasks = len(self.cfg.tasks)
        for task_idx in tqdm(range(num_tasks), desc='Evaluating'):
            ep_rewards, ep_successes = [], []
            for _ in range(self.cfg.eval_episodes):
                obs = self.env.reset(task_idx)
                done, ep_reward, t = False, 0, 0
                while not done:
                    action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx)
                    obs, reward, done, info = self.env.step(action)
                    ep_reward += reward
                    t += 1
                ep_rewards.append(ep_reward)
                # `info` is the one returned by the final step of the episode
                ep_successes.append(info['success'])
            task_name = self.cfg.tasks[task_idx]
            results[f'episode_reward+{task_name}'] = np.nanmean(ep_rewards)
            results[f'episode_success+{task_name}'] = np.nanmean(ep_successes)
        return results

    def train(self):
        """
        Train a TD-MPC2 agent on an offline multi-task dataset.

        Loads every `*.pt` file in `cfg.data_dir` into a fresh buffer, then
        runs `cfg.steps` agent updates with periodic evaluation, logging and
        checkpointing.
        """
        task_set = self.cfg.task
        assert self.cfg.multitask and task_set in {'mt30', 'mt80'}, \
            'Offline training only supports multitask training with mt30 or mt80 task sets.'

        # Locate the dataset files
        assert task_set in self.cfg.data_dir, \
            f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
            f'please double-check your config.'
        pattern = Path(os.path.join(self.cfg.data_dir, '*.pt'))
        fps = sorted(glob(str(pattern)))
        assert len(fps) > 0, f'No data found at {pattern}'
        print(f'Found {len(fps)} files in {pattern}')

        # Build a buffer sized for the chosen task set (works on a config copy)
        is_mt80 = task_set == 'mt80'
        _cfg = deepcopy(self.cfg)
        _cfg.episode_length = 101 if is_mt80 else 501
        _cfg.buffer_size = 550_450_000 if is_mt80 else 345_690_000
        _cfg.steps = _cfg.buffer_size
        self.buffer = Buffer(_cfg)
        for data_file in tqdm(fps, desc='Loading data'):
            td = torch.load(data_file)
            assert td.shape[1] == _cfg.episode_length, \
                f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
                f'please double-check your config.'
            for i in range(len(td)):
                self.buffer.add(td[i])
        assert self.buffer.num_eps == self.buffer.capacity, \
            f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'

        print(f'Training agent for {self.cfg.steps} iterations...')
        for i in range(self.cfg.steps):

            # Update agent
            train_metrics = self.agent.update(self.buffer)

            # Log every `eval_freq` iterations and additionally at iteration 10,000
            run_eval = i % self.cfg.eval_freq == 0
            if not (run_eval or i == 10_000):
                continue
            metrics = {
                'iteration': i,
                'total_time': time() - self._start_time,
            }
            metrics.update(train_metrics)
            if run_eval:
                metrics.update(self.eval())
                self.logger.pprint_multitask(metrics, self.cfg)
                if i > 0:
                    self.logger.save_agent(self.agent, identifier=f'{i}')
            self.logger.log(metrics, 'pretrain')

        self.logger.finish(self.agent)
|
||||
117
tdmpc2/trainer/online_trainer.py
Executable file
117
tdmpc2/trainer/online_trainer.py
Executable file
@@ -0,0 +1,117 @@
|
||||
from time import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
|
||||
from trainer.base import Trainer
|
||||
|
||||
|
||||
class OnlineTrainer(Trainer):
    """Trainer class for single-task online TD-MPC2 training."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to `Trainer` and reset the step/episode counters."""
        super().__init__(*args, **kwargs)
        self._step = 0
        self._ep_idx = 0
        self._start_time = time()

    def common_metrics(self):
        """Return a dictionary of current metrics (step, episode, elapsed time)."""
        return dict(
            step=self._step,
            episode=self._ep_idx,
            total_time=time() - self._start_time,
        )

    def eval(self):
        """
        Evaluate a TD-MPC2 agent for `cfg.eval_episodes` episodes.

        Returns:
            dict: Mean `episode_reward` and `episode_success` over the episodes.
        """
        ep_rewards, ep_successes = [], []
        for i in range(self.cfg.eval_episodes):
            obs, done, ep_reward, t = self.env.reset(), False, 0, 0
            if self.cfg.save_video:
                # Video recording is enabled for the first episode only
                self.logger.video.init(self.env, enabled=(i==0))
            while not done:
                action = self.agent.act(obs, t0=t==0, eval_mode=True)
                obs, reward, done, info = self.env.step(action)
                ep_reward += reward
                t += 1
                if self.cfg.save_video:
                    self.logger.video.record(self.env)
            ep_rewards.append(ep_reward)
            ep_successes.append(info['success'])
            if self.cfg.save_video:
                self.logger.video.save(self._step)
        return dict(
            episode_reward=np.nanmean(ep_rewards),
            episode_success=np.nanmean(ep_successes),
        )

    def to_td(self, obs, action=None, reward=None):
        """
        Wrap one environment step in a TensorDict with a leading batch dim of 1.

        Args:
            obs (torch.Tensor or dict): Observation, or a dict of observation tensors.
            action (torch.Tensor, optional): Action that led to `obs`. Omitted
                for the first observation of an episode.
            reward (torch.Tensor, optional): Reward received with `obs`. Omitted
                for the first observation of an episode.

        Returns:
            TensorDict: Entry with keys `obs`, `action` and `reward`.
        """
        if isinstance(obs, dict):
            obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
        else:
            obs = obs.unsqueeze(0).cpu()
        if action is None:
            # Placeholder for the step that has no preceding action. Filled
            # with NaN (like the placeholder reward below) instead of
            # uninitialised memory, so the stored episode is deterministic.
            action = torch.full_like(self.env.rand_act(), float('nan'))
        if reward is None:
            reward = torch.tensor(float('nan'))
        td = TensorDict(dict(
            obs=obs,
            action=action.unsqueeze(0),
            reward=reward.unsqueeze(0),
        ), batch_size=(1,))
        return td

    def train(self):
        """
        Train a TD-MPC2 agent by interacting with the environment.

        Random actions are taken for the first `cfg.seed_steps` steps. The
        agent is then pretrained with `cfg.seed_steps` updates and from there
        on receives one update per environment step.
        """
        train_metrics, done, eval_next = {}, True, True
        while self._step <= self.cfg.steps:

            # Evaluate agent periodically (deferred until the episode ends)
            if self._step % self.cfg.eval_freq == 0:
                eval_next = True

            # Reset environment
            if done:
                if eval_next:
                    eval_metrics = self.eval()
                    eval_metrics.update(self.common_metrics())
                    self.logger.log(eval_metrics, 'eval')
                    eval_next = False

                if self._step > 0:
                    # Log the finished episode and store it in the buffer.
                    # The first entry is skipped: its reward is a NaN placeholder.
                    train_metrics.update(
                        episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
                        episode_success=info['success'],
                    )
                    train_metrics.update(self.common_metrics())
                    self.logger.log(train_metrics, 'train')
                    self._ep_idx = self.buffer.add(torch.cat(self._tds))

                obs = self.env.reset()
                self._tds = [self.to_td(obs)]

            # Collect experience
            if self._step > self.cfg.seed_steps:
                action = self.agent.act(obs, t0=len(self._tds)==1)
            else:
                action = self.env.rand_act()
            obs, reward, done, info = self.env.step(action)
            self._tds.append(self.to_td(obs, action, reward))

            # Update agent
            if self._step >= self.cfg.seed_steps:
                if self._step == self.cfg.seed_steps:
                    num_updates = self.cfg.seed_steps
                    print('Pretraining agent on seed data...')
                else:
                    num_updates = 1
                for _ in range(num_updates):
                    _train_metrics = self.agent.update(self.buffer)
                train_metrics.update(_train_metrics)

            self._step += 1

        self.logger.finish(self.agent)
|
||||
Reference in New Issue
Block a user