This commit is contained in:
Nicklas Hansen
2024-02-11 14:41:20 -08:00
parent 57158282b4
commit 51d6b8d7a9
8 changed files with 144 additions and 64 deletions

View File

@@ -41,7 +41,7 @@ class Buffer():
storage=storage, storage=storage,
sampler=self._sampler, sampler=self._sampler,
pin_memory=True, pin_memory=True,
prefetch=1, prefetch=int(self.cfg.num_envs / self.cfg.steps_per_update),
batch_size=self._batch_size, batch_size=self._batch_size,
) )
@@ -82,11 +82,13 @@ class Buffer():
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
td = td.permute(1, 0)
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td[0])
self._buffer.extend(td) for i in range(self.cfg.num_envs):
self._num_eps += 1 self._buffer.extend(td[i])
self._num_eps += self.cfg.num_envs
return self._num_eps return self._num_eps
def sample(self): def sample(self):

View File

@@ -4,6 +4,7 @@ defaults:
# environment # environment
task: dog-run task: dog-run
obs: state obs: state
num_envs: 1
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -13,6 +14,7 @@ eval_freq: 50000
# training # training
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
steps_per_update: 1
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
consistency_coef: 20 consistency_coef: 20

View File

@@ -6,6 +6,8 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
from envs.wrappers.vectorized import Vectorized
def missing_dependencies(task): def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
@@ -59,16 +61,19 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break
except ValueError: except ValueError:
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
'Vectorized environments only support state observations.'
env = Vectorized(cfg, fn)
env = TensorWrapper(env) env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)
@@ -78,5 +83,5 @@ def make_env(cfg):
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0] cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
return env return env

View File

@@ -177,6 +177,9 @@ class TimeStepToGymWrapper:
camera_id = dict(quadruped=2).get(self.domain, camera_id) camera_id = dict(quadruped=2).get(self.domain, camera_id)
return self.env.physics.render(height, width, camera_id) return self.env.physics.render(height, width, camera_id)
def close(self):
self.env.close()
def make_env(cfg): def make_env(cfg):
""" """

View File

@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
def rand_act(self): def rand_act(self):
if self._wrapped_vectorized:
return self.env.rand_act()
return torch.from_numpy(self.action_space.sample().astype(np.float32)) return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x): def _try_f32_tensor(self, x):
@@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper):
obs = self._try_f32_tensor(obs) obs = self._try_f32_tensor(obs)
return obs return obs
def reset(self, task_idx=None): def reset(self, task_idx=None, **kwargs):
return self._obs_to_tensor(self.env.reset()) if self._wrapped_vectorized:
obs = self.env.reset(**kwargs)
else:
obs = self.env.reset()
return self._obs_to_tensor(obs)
def step(self, action): def step(self, action, **kwargs):
obs, reward, done, info = self.env.step(action.numpy()) if self._wrapped_vectorized:
info = defaultdict(float, info) obs, reward, done, info = self.env.step(action.numpy(), **kwargs)
info['success'] = float(info['success']) else:
obs, reward, done, info = self.env.step(action.numpy())
if isinstance(info, tuple):
info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()}
if 'success' not in info.keys():
info['success'] = torch.zeros(len(done))
else:
info = defaultdict(float, info)
info['success'] = float(info['success'])
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -0,0 +1,40 @@
from copy import deepcopy
from gym.vector import AsyncVectorEnv
import numpy as np
import torch
class Vectorized():
    """
    Vectorized environment for TD-MPC2 online training.

    Runs `cfg.num_envs` copies of a single-environment factory inside a
    `gym.vector.AsyncVectorEnv` and exposes the per-environment
    `observation_space`, `action_space` and `max_episode_steps`.
    """

    def __init__(self, cfg, env_fn):
        """
        Args:
            cfg: Config providing `num_envs` and `seed`. It is deep-copied
                for every sub-environment, so the caller's object is never
                modified.
            env_fn: Callable that takes a config and returns one environment.
        """
        super().__init__()
        self.cfg = cfg

        def make(rank=0):
            _cfg = deepcopy(cfg)
            _cfg.num_envs = 1
            # The seed offset is fixed here, in the parent process. Drawing it
            # from the global NumPy RNG inside the worker gives every forked
            # worker the same inherited RNG state, hence the same seed.
            _cfg.seed = cfg.seed + rank
            return env_fn(_cfg)

        def make_fn(rank):
            # Bind `rank` now; a bare closure over the loop variable would
            # late-bind and hand every worker the last rank.
            return lambda: make(rank)

        print(f'Creating {cfg.num_envs} environments...')
        self.env = AsyncVectorEnv([make_fn(rank) for rank in range(cfg.num_envs)])

        # A throwaway local environment is the source of the per-environment
        # metadata; release it as soon as the attributes are copied.
        env = make()
        try:
            self.observation_space = env.observation_space
            self.action_space = env.action_space
            self.max_episode_steps = env.max_episode_steps
        finally:
            close = getattr(env, 'close', None)
            if callable(close):
                close()

    def rand_act(self):
        """Return uniform random actions in [-1, 1), shape (num_envs, *action_space.shape)."""
        return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1

    def reset(self, **kwargs):
        """Reset the sub-environments; keyword arguments go to the vector env."""
        return self.env.reset(**kwargs)

    def step(self, action, **kwargs):
        """Step the sub-environments with `action`; keyword arguments go to the vector env."""
        return self.env.step(action, **kwargs)

    def render(self, *args, **kwargs):
        """Forward a render call to the underlying vector env."""
        return self.env.render(*args, **kwargs)

    def close(self):
        """Close the underlying vector env."""
        return self.env.close()

View File

@@ -81,23 +81,23 @@ class TDMPC2:
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) action = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
else: else:
a = self.model.pi(z, task)[int(not eval_mode)][0] action = self.model.pi(z, task)[int(not eval_mode)]
return a.cpu() return action.cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions.""" """Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1 G, discount = 0, 1
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[:, t], task)
G += discount * reward G += discount * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@@ -118,57 +118,69 @@ class TDMPC2:
""" """
# Sample policy trajectories # Sample policy trajectories
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1): for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1] pi_actions[:,t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task) _z = self.model.next(_z, pi_actions[:,t], task)
pi_actions[-1] = self.model.pi(_z, task)[1] pi_actions[:,-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
if not t0: if not t0:
mean[:-1] = self._prev_mean[1:] mean[:, :-1] = self._prev_mean[:, 1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ actions[:, :, self.cfg.num_pi_trajs:] = (mean.unsqueeze(2) + std.unsqueeze(2) * \
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
.clamp(-1, 1) .clamp(-1, 1)
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0) value = self._estimate_value(z, actions, task).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
# vectorized version
# elite_value, elite_actions = [], []
# for i in range(self.cfg.num_envs):
# elite_value.append(value[i, elite_idxs[i]])
# elite_actions.append(actions[i, elite_idxs[i]])
# elite_value = torch.stack(elite_value, dim=0)
# Update parameters # Update parameters
max_value = elite_value.max(0)[0] max_value = elite_value.max(1)[0]
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
score /= score.sum(0) score /= score.sum(1, keepdim=True)
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) mean = torch.sum(score.unsqueeze(1) * elite_actions, dim=2) / (score.sum(1, keepdim=True) + 1e-9)
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ std = torch.sqrt(torch.sum(score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2, dim=2) / (score.sum(1, keepdim=True) + 1e-9)) \
.clamp_(self.cfg.min_std, self.cfg.max_std) .clamp_(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action sequence with probability `score`
score = score.squeeze(1).cpu().numpy() score = score.squeeze(1).squeeze(-1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] actions = torch.stack([
elite_actions[i, :, np.random.choice(np.arange(score.shape[1]), p=score[i])] \
for i in range(score.shape[0])], dim=0)
self._prev_mean = mean self._prev_mean = mean
a, std = actions[0], std[0] action, std = actions[:, 0], std[:, 0]
if not eval_mode: if not eval_mode:
a += std * torch.randn(self.cfg.action_dim, device=std.device) action += std * torch.randn(self.cfg.action_dim, device=std.device)
return a.clamp_(-1, 1) return action.clamp_(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """

View File

@@ -1,6 +1,5 @@
from time import time from time import time
import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
@@ -26,25 +25,25 @@ class OnlineTrainer(Trainer):
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], [] ep_rewards = []
for i in range(self.cfg.eval_episodes): for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0 obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0)) self.logger.video.init(self.env, enabled=(i==0))
while not done: while not done.any():
action = self.agent.act(obs, t0=t==0, eval_mode=True) action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
ep_reward += reward ep_reward += reward
t += 1 t += 1
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.record(self.env) self.logger.video.record(self.env)
assert done.all(), 'Vectorized environments must reset all environments at once.'
ep_rewards.append(ep_reward) ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.save(self._step) self.logger.video.save(self._step)
return dict( return dict(
episode_reward=np.nanmean(ep_rewards), episode_reward=torch.cat(ep_rewards).mean(),
episode_success=np.nanmean(ep_successes), episode_success=info['success'].mean(),
) )
def to_td(self, obs, action=None, reward=None): def to_td(self, obs, action=None, reward=None):
@@ -56,17 +55,17 @@ class OnlineTrainer(Trainer):
if action is None: if action is None:
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
td = TensorDict(dict( td = TensorDict(dict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
), batch_size=(1,)) ), batch_size=(1, self.cfg.num_envs,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, True train_metrics, done, eval_next = {}, torch.tensor(True), True
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
@@ -74,7 +73,8 @@ class OnlineTrainer(Trainer):
eval_next = True eval_next = True
# Reset environment # Reset environment
if done: if done.any():
assert done.all(), 'Vectorized environments must reset all environments at once.'
if eval_next: if eval_next:
eval_metrics = self.eval() eval_metrics = self.eval()
eval_metrics.update(self.common_metrics()) eval_metrics.update(self.common_metrics())
@@ -82,13 +82,14 @@ class OnlineTrainer(Trainer):
eval_next = False eval_next = False
if self._step > 0: if self._step > 0:
tds = torch.cat(self._tds)
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_reward=tds['reward'].nansum(0).mean(),
episode_success=info['success'], episode_success=info['success'].nanmean(),
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(tds)
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
@@ -104,14 +105,14 @@ class OnlineTrainer(Trainer):
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps: if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
print('Pretraining agent on seed data...') print('Pretraining agent on seed data...')
else: else:
num_updates = 1 num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
for _ in range(num_updates): for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer) _train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics) train_metrics.update(_train_metrics)
self._step += 1 self._step += self.cfg.num_envs
self.logger.finish(self.agent) self.logger.finish(self.agent)