diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 4b16f6f..c0eab21 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -83,11 +83,13 @@ class Buffer(): def add(self, td): """Add an episode to the buffer.""" - td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs) + td = td.permute(1, 0) if self._num_eps == 0: - self._buffer = self._init(td) - self._buffer.extend(td) - self._num_eps += 1 + self._buffer = self._init(td[0]) + for i in range(self.cfg.num_envs): + self._buffer.extend(td[i]) + self._num_eps += self.cfg.num_envs return self._num_eps def _prepare_batch(self, td): diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 57a8da8..a752499 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -84,15 +84,12 @@ def two_hot_inv(x, cfg): return symexp(x) -def gumbel_softmax_sample(p, temperature=1.0, dim=0): - """Sample from the Gumbel-Softmax distribution.""" - logits = p.log() - gumbels = ( - -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() - ) # ~Gumbel(0,1) - gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau) - y_soft = gumbels.softmax(dim) - return y_soft.argmax(-1) +def gumbel_softmax_sample(p, temperature=1.0, dim=1): + """Sample indices from a Gumbel-Softmax distribution.""" + logits = torch.log(p + 1e-9) + gumbels = -torch.empty_like(logits).exponential_().log() + y = (logits + gumbels) / temperature + return y.argmax(dim=dim) def termination_statistics(pred, target, eps=1e-9): diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index ff15dbb..4b7d9b1 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -5,6 +5,7 @@ defaults: task: dog-run obs: state episodic: false +num_envs: 1 # evaluation checkpoint: ??? @@ -14,6 +15,7 @@ eval_freq: 50000 # training steps: 10_000_000 batch_size: 256 +steps_per_update: 1 reward_coef: 0.1 value_coef: 0.1 termination_coef: 1 @@ -64,8 +66,8 @@ dropout: 0.01 simnorm_dim: 8 # logging -wandb_project: ??? -wandb_entity: ??? +wandb_project: tdmpc3 +wandb_entity: nicklashansen wandb_silent: false enable_wandb: true save_csv: true diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 46f99a8..6a78be6 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -5,6 +5,8 @@ import gymnasium as gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.tensor import TensorWrapper +from envs.wrappers.vectorized import Vectorized + def missing_dependencies(task): raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') @@ -62,16 +64,19 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) - else: env = None for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]: try: env = fn(cfg) + break except ValueError: pass if env is None: raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') + assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \ + 'Vectorized environments only support state observations.' + env = Vectorized(cfg, fn) env = TensorWrapper(env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} @@ -79,5 +84,5 @@ def make_env(cfg): cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.action_dim = env.action_space.shape[0] cfg.episode_length = env.max_episode_steps - cfg.seed_steps = max(1000, 5*cfg.episode_length) + cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs return env diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index a6e21b3..f549bd5 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -44,12 +44,16 @@ class DMControlWrapper: def unwrapped(self): return self.env + @property + def metadata(self): + return None + def _obs_to_array(self, obs): return torch.from_numpy( np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32)) def reset(self): - return self._obs_to_array(self.env.reset().observation) + return self._obs_to_array(self.env.reset().observation), defaultdict(float) def step(self, action): reward = 0 @@ -61,6 +65,9 @@ class DMControlWrapper: def render(self, width=384, height=384, camera_id=None): return self.env.physics.render(height, width, camera_id or self.camera_id) + + def close(self): + self.env.close() class Pixels(gym.Wrapper): @@ -88,6 +95,9 @@ class Pixels(gym.Wrapper): _, reward, done, info = self.env.step(action) return self._get_obs(), reward, done, info + def close(self): + self.env.close() + def make_env(cfg): """ diff --git a/tdmpc2/envs/wrappers/tensor.py b/tdmpc2/envs/wrappers/tensor.py index 4a6819a..be636a3 100644 --- a/tdmpc2/envs/wrappers/tensor.py +++ b/tdmpc2/envs/wrappers/tensor.py @@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper): def __init__(self, env): super().__init__(env) + self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized' def rand_act(self): + if self._wrapped_vectorized: + return self.env.rand_act() return torch.from_numpy(self.action_space.sample().astype(np.float32)) def _try_f32_tensor(self, x): @@ -31,12 +34,24 @@ class TensorWrapper(gym.Wrapper): obs = self._try_f32_tensor(obs) return obs - def reset(self, task_idx=None): - return self._obs_to_tensor(self.env.reset()) + def reset(self, task_idx=None, **kwargs): + if self._wrapped_vectorized: + obs = self.env.reset(**kwargs) + else: + obs = self.env.reset() + return self._obs_to_tensor(obs) - def step(self, action): - obs, reward, done, info = self.env.step(action.numpy()) - info = defaultdict(float, info) - info['success'] = float(info['success']) - info['terminated'] = torch.tensor(float(info['terminated'])) - return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info + def step(self, action, **kwargs): + if self._wrapped_vectorized: + obs, reward, terminated, truncated, info = self.env.step(action.numpy(), **kwargs) + else: + obs, reward, terminated, truncated, info = self.env.step(action.numpy()) + reward = torch.tensor(reward, dtype=torch.float32) + terminated = torch.tensor(terminated) + truncated = torch.tensor(truncated) + done = terminated | truncated + if 'success' not in info: + info['success'] = torch.zeros_like(reward) + info['terminated'] = terminated.float() + info['truncated'] = truncated.float() + return self._obs_to_tensor(obs), reward, done, info diff --git a/tdmpc2/envs/wrappers/timeout.py b/tdmpc2/envs/wrappers/timeout.py index cc2081c..1499e82 100644 --- a/tdmpc2/envs/wrappers/timeout.py +++ b/tdmpc2/envs/wrappers/timeout.py @@ -19,7 +19,9 @@ class Timeout(gym.Wrapper): return self.env.reset(**kwargs) def step(self, action): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, info = self.env.step(action) self._t += 1 - done = done or self._t >= self.max_episode_steps - return obs, reward, done, info + truncated = self._t >= self.max_episode_steps + info['terminated'] = terminated + info['truncated'] = truncated + return obs, reward, terminated, truncated, info diff --git a/tdmpc2/envs/wrappers/vectorized.py b/tdmpc2/envs/wrappers/vectorized.py new file mode 100644 index 0000000..4dadc9b --- /dev/null +++ b/tdmpc2/envs/wrappers/vectorized.py @@ -0,0 +1,42 @@ +from copy import deepcopy + +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +import numpy as np +import torch + + +class Vectorized(): + """ + Vectorized environment for TD-MPC2 online training. + """ + + def __init__(self, cfg, env_fn): + super().__init__() + self.cfg = cfg + + def make(): + _cfg = deepcopy(cfg) + _cfg.num_envs = 1 + _cfg.seed = cfg.seed + np.random.randint(1000) + return env_fn(_cfg) + + print(f'Creating {cfg.num_envs} environments...') + # self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)]) + self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)]) + env = make() + self.observation_space = env.observation_space + self.action_space = env.action_space + self.max_episode_steps = env.max_episode_steps + + def rand_act(self): + return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1 + + def reset(self): + obs, _ = self.env.reset() + return obs + + def step(self, action): + return self.env.step(action) + + def render(self, *args, **kwargs): + return self.env.render(*args, **kwargs) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 7111329..6029b4f 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -38,7 +38,7 @@ class TDMPC2(torch.nn.Module): ) if self.cfg.multitask else self._get_discount(cfg.episode_length) print('Episode length:', cfg.episode_length) print('Discount factor:', self.discount) - self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) + self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: print('Compiling update function with torch.compile...') self._update = torch.compile(self._update, mode="reduce-overhead") @@ -109,7 +109,7 @@ class TDMPC2(torch.nn.Module): Returns: torch.Tensor: Action to take in the environment. """ - obs = obs.to(self.device, non_blocking=True).unsqueeze(0) + obs = obs.to(self.device, non_blocking=True) if task is not None: task = torch.tensor([task], device=self.device) if self.cfg.mpc: @@ -118,7 +118,7 @@ class TDMPC2(torch.nn.Module): action, info = self.model.pi(z, task) if eval_mode: action = info["mean"] - return action[0].cpu() + return action.cpu() @torch.no_grad() def _estimate_value(self, z, actions, task): @@ -126,8 +126,8 @@ class TDMPC2(torch.nn.Module): G, discount = 0, 1 termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device) for t in range(self.cfg.horizon): - reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) - z = self.model.next(z, actions[t], task) + reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg) + z = self.model.next(z, actions[:, t], task) G = G + discount * (1-termination) * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update @@ -137,7 +137,7 @@ class TDMPC2(torch.nn.Module): return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg') @torch.no_grad() - def _plan(self, obs, t0=False, eval_mode=False, task=None): + def _plan(self, z, t0=False, eval_mode=False, task=None): """ Plan a sequence of actions using the learned world model. @@ -151,61 +151,69 @@ class TDMPC2(torch.nn.Module): torch.Tensor: Action to take in the environment. """ # Sample policy trajectories - z = self.model.encode(obs, task) if self.cfg.num_pi_trajs > 0: - pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) - _z = z.repeat(self.cfg.num_pi_trajs, 1) - for t in range(self.cfg.horizon-1): - pi_actions[t], _ = self.model.pi(_z, task) - _z = self.model.next(_z, pi_actions[t], task) - pi_actions[-1], _ = self.model.pi(_z, task) + pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) + _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1).view(self.cfg.num_envs * self.cfg.num_pi_trajs, -1) + for t in range(self.cfg.horizon - 1): + a, _ = self.model.pi(_z, task) + pi_actions[:, t] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim) + _z = self.model.next(_z, a, task) + a, _ = self.model.pi(_z, task) + pi_actions[:, -1] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim) # Initialize state and parameters - z = z.repeat(self.cfg.num_samples, 1) - mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) + z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1) + mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = torch.full((self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, device=self.device) if not t0: - mean[:-1] = self._prev_mean[1:] - actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) + mean[:, :-1] = self._prev_mean[:, 1:] + actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: - actions[:, :self.cfg.num_pi_trajs] = pi_actions + actions[:, :, :self.cfg.num_pi_trajs] = pi_actions # Iterate MPPI for _ in range(self.cfg.iterations): - # Sample actions - r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) - actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r - actions_sample = actions_sample.clamp(-1, 1) - actions[:, self.cfg.num_pi_trajs:] = actions_sample + # Sample new actions + r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples - self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) + actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r + actions[:, :, self.cfg.num_pi_trajs:] = actions_sample.clamp(-1, 1) if self.cfg.multitask: actions = actions * self.model._action_masks[task] # Compute elite actions value = self._estimate_value(z, actions, task).nan_to_num(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices + elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2)) + elite_actions = actions.gather( + dim=2, + index=elite_idxs[:, None, :, None].expand(-1, self.cfg.horizon, self.cfg.num_elites, self.cfg.action_dim) + ) # Update parameters - max_value = elite_value.max(0).values - score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score = score / score.sum(0) - mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) - std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() - std = std.clamp(self.cfg.min_std, self.cfg.max_std) + score = torch.exp(self.cfg.temperature * (elite_value - elite_value.max(1, keepdim=True).values)) + score = score / (score.sum(dim=1, keepdim=True) + 1e-9) + score_exp = score.unsqueeze(1) + mean = (score_exp * elite_actions).sum(dim=2) / (score_exp.sum(dim=2) + 1e-9) + std = ((score_exp * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / + (score_exp.sum(dim=2) + 1e-9)).sqrt().clamp(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * self.model._action_masks[task] std = std * self.model._action_masks[task] # Select action - rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) - actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) - a, std = actions[0], std[0] + logits = torch.log(score.squeeze(2) + 1e-9) + rand_idx = math.gumbel_softmax_sample(logits, temperature=self.cfg.temperature, dim=1) + selected_actions = elite_actions.gather( + dim=2, + index=rand_idx[:, None, None, None].expand(-1, self.cfg.horizon, 1, self.cfg.action_dim) + ).squeeze(2) + action, std_out = selected_actions[:, 0], std[:, 0] if not eval_mode: - a = a + std * torch.randn(self.cfg.action_dim, device=std.device) + action = action + std_out * torch.randn_like(action) self._prev_mean.copy_(mean) - return a.clamp(-1, 1) - + return action.clamp(-1, 1) + def update_pi(self, zs, task): """ Update policy using a sequence of latent states. diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 83128f7..3cee3f9 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -1,6 +1,5 @@ from time import time -import numpy as np import torch from tensordict.tensordict import TensorDict from trainer.base import Trainer @@ -28,11 +27,11 @@ class OnlineTrainer(Trainer): def eval(self): """Evaluate a TD-MPC2 agent.""" ep_rewards, ep_successes, ep_lengths = [], [], [] - for i in range(self.cfg.eval_episodes): - obs, done, ep_reward, t = self.env.reset(), False, 0, 0 + for i in range(self.cfg.eval_episodes // self.cfg.num_envs): + obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0 if self.cfg.save_video: self.logger.video.init(self.env, enabled=(i==0)) - while not done: + while not done.any(): torch.compiler.cudagraph_mark_step_begin() action = self.agent.act(obs, t0=t==0, eval_mode=True) obs, reward, done, info = self.env.step(action) @@ -40,15 +39,16 @@ class OnlineTrainer(Trainer): t += 1 if self.cfg.save_video: self.logger.video.record(self.env) + assert done.all(), 'Vectorized environments must reset all environments at once.' ep_rewards.append(ep_reward) ep_successes.append(info['success']) ep_lengths.append(t) if self.cfg.save_video: self.logger.video.save(self._step) return dict( - episode_reward=np.nanmean(ep_rewards), - episode_success=np.nanmean(ep_successes), - episode_length= np.nanmean(ep_lengths), + episode_reward=torch.cat(ep_rewards).mean(), + episode_success=info['success'].mean(), + episode_length= torch.tensor(ep_lengths, dtype=torch.float32).mean(), ) def to_td(self, obs, action=None, reward=None, terminated=None): @@ -60,27 +60,28 @@ class OnlineTrainer(Trainer): if action is None: action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: - reward = torch.tensor(float('nan')) + reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs) if terminated is None: - terminated = torch.tensor(float('nan')) + terminated = torch.tensor(float('nan')).repeat(self.cfg.num_envs) td = TensorDict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), terminated=terminated.unsqueeze(0), - batch_size=(1,)) + batch_size=(1, self.cfg.num_envs,)) return td def train(self): """Train a TD-MPC2 agent.""" - train_metrics, done, eval_next = {}, True, False + train_metrics, done, eval_next = {}, torch.tensor(True), True while self._step <= self.cfg.steps: # Evaluate agent periodically if self._step % self.cfg.eval_freq == 0: eval_next = True # Reset environment - if done: + if done.any(): + assert done.all(), 'Vectorized environments must reset all environments at once.' if eval_next: eval_metrics = self.eval() eval_metrics.update(self.common_metrics()) @@ -88,17 +89,19 @@ class OnlineTrainer(Trainer): eval_next = False if self._step > 0: - if info['terminated'] and not self.cfg.episodic: + if info['terminated'].any() and not self.cfg.episodic: raise ValueError('Termination detected but you are not in episodic mode. ' \ 'Set `episodic=true` to enable support for terminations.') + tds = torch.cat(self._tds) train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], + episode_reward=tds['reward'].nansum(0).mean(), + episode_success=info['success'].nanmean(), episode_length=len(self._tds), - episode_terminated=info['terminated']) + episode_terminated=info['terminated'].nanmean(), + ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + self._ep_idx = self.buffer.add(tds) obs = self.env.reset() self._tds = [self.to_td(obs)] @@ -114,14 +117,14 @@ class OnlineTrainer(Trainer): # Update agent if self._step >= self.cfg.seed_steps: if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps + num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update) print('Pretraining agent on seed data...') else: - num_updates = 1 + num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update)) for _ in range(num_updates): _train_metrics = self.agent.update(self.buffer) train_metrics.update(_train_metrics) - self._step += 1 - + self._step += self.cfg.num_envs + self.logger.finish(self.agent)