This commit is contained in:
Nicklas Hansen
2024-02-11 14:41:20 -08:00
parent 1bfbcb7794
commit fa41a3e450
8 changed files with 128 additions and 59 deletions

View File

@@ -81,11 +81,13 @@ class Buffer():
def add(self, td):
"""Add an episode to the buffer."""
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
td = td.permute(1, 0)
if self._num_eps == 0:
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_eps += 1
self._buffer = self._init(td[0])
for i in range(self.cfg.num_envs):
self._buffer.extend(td[i])
self._num_eps += self.cfg.num_envs
return self._num_eps
def sample(self):

View File

@@ -4,6 +4,7 @@ defaults:
# environment
task: dog-run
obs: state
num_envs: 1
# evaluation
checkpoint: ???
@@ -13,6 +14,7 @@ eval_freq: 50000
# training
steps: 10_000_000
batch_size: 256
steps_per_update: 1
reward_coef: 0.1
value_coef: 0.1
consistency_coef: 20

View File

@@ -6,6 +6,8 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
from envs.wrappers.vectorized import Vectorized
def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
@@ -59,16 +61,19 @@ def make_env(cfg):
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
else:
env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
break
except ValueError:
pass
if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
'Vectorized environments only support state observations.'
env = Vectorized(cfg, fn)
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
@@ -78,5 +83,5 @@ def make_env(cfg):
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length)
cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
return env

View File

@@ -177,6 +177,9 @@ class TimeStepToGymWrapper:
camera_id = dict(quadruped=2).get(self.domain, camera_id)
return self.env.physics.render(height, width, camera_id)
def close(self):
self.env.close()
def make_env(cfg):
"""

View File

@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
def rand_act(self):
if self._wrapped_vectorized:
return self.env.rand_act()
return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x):
@@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper):
obs = self._try_f32_tensor(obs)
return obs
def reset(self, task_idx=None):
return self._obs_to_tensor(self.env.reset())
def reset(self, task_idx=None, **kwargs):
if self._wrapped_vectorized:
obs = self.env.reset(**kwargs)
else:
obs = self.env.reset()
return self._obs_to_tensor(obs)
def step(self, action):
obs, reward, done, info = self.env.step(action.numpy())
info = defaultdict(float, info)
info['success'] = float(info['success'])
def step(self, action, **kwargs):
if self._wrapped_vectorized:
obs, reward, done, info = self.env.step(action.numpy(), **kwargs)
else:
obs, reward, done, info = self.env.step(action.numpy())
if isinstance(info, tuple):
info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()}
if 'success' not in info.keys():
info['success'] = torch.zeros(len(done))
else:
info = defaultdict(float, info)
info['success'] = float(info['success'])
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -0,0 +1,40 @@
from copy import deepcopy
from gym.vector import AsyncVectorEnv
import numpy as np
import torch
class Vectorized():
"""
Vectorized environment for TD-MPC2 online training.
"""
def __init__(self, cfg, env_fn):
super().__init__()
self.cfg = cfg
def make():
_cfg = deepcopy(cfg)
_cfg.num_envs = 1
_cfg.seed = cfg.seed + np.random.randint(1000)
return env_fn(_cfg)
print(f'Creating {cfg.num_envs} environments...')
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
env = make()
self.observation_space = env.observation_space
self.action_space = env.action_space
self.max_episode_steps = env.max_episode_steps
def rand_act(self):
return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
def reset(self):
return self.env.reset()
def step(self, action):
return self.env.step(action)
def render(self, *args, **kwargs):
return self.env.render(*args, **kwargs)

View File

@@ -114,8 +114,8 @@ class TDMPC2(torch.nn.Module):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
z = self.model.next(z, actions[:, t], task)
G = G + discount * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
@@ -138,45 +138,45 @@ class TDMPC2(torch.nn.Module):
# Sample policy trajectories
z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1)
pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1]
pi_actions[:,t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[:,t], task)
pi_actions[:,-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
if not t0:
mean[:-1] = self._prev_mean[1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
mean[:, :-1] = self._prev_mean[:, 1:]
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI
for _ in range(self.cfg.iterations):
# Sample actions
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r
actions_sample = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample
actions[:, :, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask:
actions = actions * self.model._action_masks[task]
# Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs]
# Update parameters
max_value = elite_value.max(0).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score = score / score.sum(0)
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
max_value = elite_value.max(1).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
score = score / score.sum(1)
mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9)
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt()
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask:
mean = mean * self.model._action_masks[task]
@@ -185,11 +185,12 @@ class TDMPC2(torch.nn.Module):
# Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
a, std = actions[0], std[0]
action, std = actions[:, 0], std[:, 0]
if not eval_mode:
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
action = action + std * torch.randn(self.cfg.action_dim, device=std.device)
self._prev_mean.copy_(mean)
return a.clamp(-1, 1)
return action.clamp(-1, 1)
def update_pi(self, zs, task):
"""

View File

@@ -1,6 +1,5 @@
from time import time
import numpy as np
import torch
from tensordict.tensordict import TensorDict
from trainer.base import Trainer
@@ -25,12 +24,12 @@ class OnlineTrainer(Trainer):
def eval(self):
"""Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], []
for i in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
ep_rewards = []
for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0))
while not done:
while not done.any():
torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action)
@@ -38,13 +37,13 @@ class OnlineTrainer(Trainer):
t += 1
if self.cfg.save_video:
self.logger.video.record(self.env)
assert done.all(), 'Vectorized environments must reset all environments at once.'
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if self.cfg.save_video:
self.logger.video.save(self._step)
return dict(
episode_reward=np.nanmean(ep_rewards),
episode_success=np.nanmean(ep_successes),
episode_reward=torch.cat(ep_rewards).mean(),
episode_success=info['success'].mean(),
)
def to_td(self, obs, action=None, reward=None):
@@ -56,24 +55,25 @@ class OnlineTrainer(Trainer):
if action is None:
action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(
reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
td = TensorDict(dict(
obs=obs,
action=action.unsqueeze(0),
reward=reward.unsqueeze(0),
batch_size=(1,))
), batch_size=(1, self.cfg.num_envs,))
return td
def train(self):
"""Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False
train_metrics, done, eval_next = {}, torch.tensor(True), True
while self._step <= self.cfg.steps:
# Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0:
eval_next = True
# Reset environment
if done:
if done.any():
assert done.all(), 'Vectorized environments must reset all environments at once.'
if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
@@ -81,13 +81,14 @@ class OnlineTrainer(Trainer):
eval_next = False
if self._step > 0:
tds = torch.cat(self._tds)
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
episode_reward=tds['reward'].nansum(0).mean(),
episode_success=info['success'].nanmean(),
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
self._ep_idx = self.buffer.add(tds)
obs = self.env.reset()
self._tds = [self.to_td(obs)]
@@ -103,14 +104,14 @@ class OnlineTrainer(Trainer):
# Update agent
if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps
num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
print('Pretraining agent on seed data...')
else:
num_updates = 1
num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
self._step += 1
self._step += self.cfg.num_envs
self.logger.finish(self.agent)