This commit is contained in:
Nicklas Hansen
2024-02-11 14:41:20 -08:00
parent 1bfbcb7794
commit fa41a3e450
8 changed files with 128 additions and 59 deletions

View File

@@ -81,11 +81,13 @@ class Buffer():
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
td = td.permute(1, 0)
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td[0])
self._buffer.extend(td) for i in range(self.cfg.num_envs):
self._num_eps += 1 self._buffer.extend(td[i])
self._num_eps += self.cfg.num_envs
return self._num_eps return self._num_eps
def sample(self): def sample(self):

View File

@@ -4,6 +4,7 @@ defaults:
# environment # environment
task: dog-run task: dog-run
obs: state obs: state
num_envs: 1
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -13,6 +14,7 @@ eval_freq: 50000
# training # training
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
steps_per_update: 1
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
consistency_coef: 20 consistency_coef: 20

View File

@@ -6,6 +6,8 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
from envs.wrappers.vectorized import Vectorized
def missing_dependencies(task): def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
@@ -59,16 +61,19 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break
except ValueError: except ValueError:
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
'Vectorized environments only support state observations.'
env = Vectorized(cfg, fn)
env = TensorWrapper(env) env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)
@@ -78,5 +83,5 @@ def make_env(cfg):
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0] cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
return env return env

View File

@@ -177,6 +177,9 @@ class TimeStepToGymWrapper:
camera_id = dict(quadruped=2).get(self.domain, camera_id) camera_id = dict(quadruped=2).get(self.domain, camera_id)
return self.env.physics.render(height, width, camera_id) return self.env.physics.render(height, width, camera_id)
def close(self):
self.env.close()
def make_env(cfg): def make_env(cfg):
""" """

View File

@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
def rand_act(self): def rand_act(self):
if self._wrapped_vectorized:
return self.env.rand_act()
return torch.from_numpy(self.action_space.sample().astype(np.float32)) return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x): def _try_f32_tensor(self, x):
@@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper):
obs = self._try_f32_tensor(obs) obs = self._try_f32_tensor(obs)
return obs return obs
def reset(self, task_idx=None): def reset(self, task_idx=None, **kwargs):
return self._obs_to_tensor(self.env.reset()) if self._wrapped_vectorized:
obs = self.env.reset(**kwargs)
else:
obs = self.env.reset()
return self._obs_to_tensor(obs)
def step(self, action): def step(self, action, **kwargs):
obs, reward, done, info = self.env.step(action.numpy()) if self._wrapped_vectorized:
info = defaultdict(float, info) obs, reward, done, info = self.env.step(action.numpy(), **kwargs)
info['success'] = float(info['success']) else:
obs, reward, done, info = self.env.step(action.numpy())
if isinstance(info, tuple):
info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()}
if 'success' not in info.keys():
info['success'] = torch.zeros(len(done))
else:
info = defaultdict(float, info)
info['success'] = float(info['success'])
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -0,0 +1,40 @@
from copy import deepcopy
from gym.vector import AsyncVectorEnv
import numpy as np
import torch
class Vectorized():
"""
Vectorized environment for TD-MPC2 online training.
"""
def __init__(self, cfg, env_fn):
super().__init__()
self.cfg = cfg
def make():
_cfg = deepcopy(cfg)
_cfg.num_envs = 1
_cfg.seed = cfg.seed + np.random.randint(1000)
return env_fn(_cfg)
print(f'Creating {cfg.num_envs} environments...')
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
env = make()
self.observation_space = env.observation_space
self.action_space = env.action_space
self.max_episode_steps = env.max_episode_steps
def rand_act(self):
return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
def reset(self):
return self.env.reset()
def step(self, action):
return self.env.step(action)
def render(self, *args, **kwargs):
return self.env.render(*args, **kwargs)

View File

@@ -114,8 +114,8 @@ class TDMPC2(torch.nn.Module):
"""Estimate value of a trajectory starting at latent state z and executing given actions.""" """Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1 G, discount = 0, 1
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[:, t], task)
G = G + discount * reward G = G + discount * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
@@ -138,45 +138,45 @@ class TDMPC2(torch.nn.Module):
# Sample policy trajectories # Sample policy trajectories
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1): for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1] pi_actions[:,t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task) _z = self.model.next(_z, pi_actions[:,t], task)
pi_actions[-1] = self.model.pi(_z, task)[1] pi_actions[:,-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
if not t0: if not t0:
mean[:-1] = self._prev_mean[1:] mean[:, :-1] = self._prev_mean[:, 1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample actions
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r
actions_sample = actions_sample.clamp(-1, 1) actions_sample = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample actions[:, :, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs]
# Update parameters # Update parameters
max_value = elite_value.max(0).values max_value = elite_value.max(1).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
score = score / score.sum(0) score = score / score.sum(1)
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt()
std = std.clamp(self.cfg.min_std, self.cfg.max_std) std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
@@ -185,12 +185,13 @@ class TDMPC2(torch.nn.Module):
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
a, std = actions[0], std[0] action, std = actions[:, 0], std[:, 0]
if not eval_mode: if not eval_mode:
a = a + std * torch.randn(self.cfg.action_dim, device=std.device) action = action + std * torch.randn(self.cfg.action_dim, device=std.device)
self._prev_mean.copy_(mean) self._prev_mean.copy_(mean)
return a.clamp(-1, 1) return action.clamp(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """
Update policy using a sequence of latent states. Update policy using a sequence of latent states.

View File

@@ -1,6 +1,5 @@
from time import time from time import time
import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from trainer.base import Trainer from trainer.base import Trainer
@@ -25,12 +24,12 @@ class OnlineTrainer(Trainer):
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], [] ep_rewards = []
for i in range(self.cfg.eval_episodes): for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0 obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0)) self.logger.video.init(self.env, enabled=(i==0))
while not done: while not done.any():
torch.compiler.cudagraph_mark_step_begin() torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True) action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
@@ -38,13 +37,13 @@ class OnlineTrainer(Trainer):
t += 1 t += 1
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.record(self.env) self.logger.video.record(self.env)
assert done.all(), 'Vectorized environments must reset all environments at once.'
ep_rewards.append(ep_reward) ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.save(self._step) self.logger.video.save(self._step)
return dict( return dict(
episode_reward=np.nanmean(ep_rewards), episode_reward=torch.cat(ep_rewards).mean(),
episode_success=np.nanmean(ep_successes), episode_success=info['success'].mean(),
) )
def to_td(self, obs, action=None, reward=None): def to_td(self, obs, action=None, reward=None):
@@ -56,24 +55,25 @@ class OnlineTrainer(Trainer):
if action is None: if action is None:
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
td = TensorDict( td = TensorDict(dict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
batch_size=(1,)) ), batch_size=(1, self.cfg.num_envs,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False train_metrics, done, eval_next = {}, torch.tensor(True), True
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0: if self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True
# Reset environment # Reset environment
if done: if done.any():
assert done.all(), 'Vectorized environments must reset all environments at once.'
if eval_next: if eval_next:
eval_metrics = self.eval() eval_metrics = self.eval()
eval_metrics.update(self.common_metrics()) eval_metrics.update(self.common_metrics())
@@ -81,13 +81,14 @@ class OnlineTrainer(Trainer):
eval_next = False eval_next = False
if self._step > 0: if self._step > 0:
tds = torch.cat(self._tds)
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_reward=tds['reward'].nansum(0).mean(),
episode_success=info['success'], episode_success=info['success'].nanmean(),
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(tds)
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
@@ -103,14 +104,14 @@ class OnlineTrainer(Trainer):
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps: if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
print('Pretraining agent on seed data...') print('Pretraining agent on seed data...')
else: else:
num_updates = 1 num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
for _ in range(num_updates): for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer) _train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics) train_metrics.update(_train_metrics)
self._step += 1 self._step += self.cfg.num_envs
self.logger.finish(self.agent) self.logger.finish(self.agent)