From 51d6b8d7a9755705839cd910fcaa5260401d1378 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 11 Feb 2024 14:41:20 -0800 Subject: [PATCH] init --- tdmpc2/common/buffer.py | 12 +++-- tdmpc2/config.yaml | 2 + tdmpc2/envs/__init__.py | 9 +++- tdmpc2/envs/dmcontrol.py | 3 ++ tdmpc2/envs/wrappers/tensor.py | 27 ++++++++--- tdmpc2/envs/wrappers/vectorized.py | 40 +++++++++++++++ tdmpc2/tdmpc2.py | 78 +++++++++++++++++------------- tdmpc2/trainer/online_trainer.py | 37 +++++++------- 8 files changed, 144 insertions(+), 64 deletions(-) create mode 100644 tdmpc2/envs/wrappers/vectorized.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index c14aa1f..2a1592e 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -41,7 +41,7 @@ class Buffer(): storage=storage, sampler=self._sampler, pin_memory=True, - prefetch=1, + prefetch=int(self.cfg.num_envs / self.cfg.steps_per_update), batch_size=self._batch_size, ) @@ -82,11 +82,13 @@ class Buffer(): def add(self, td): """Add an episode to the buffer.""" - td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs) + td = td.permute(1, 0) if self._num_eps == 0: - self._buffer = self._init(td) - self._buffer.extend(td) - self._num_eps += 1 + self._buffer = self._init(td[0]) + for i in range(self.cfg.num_envs): + self._buffer.extend(td[i]) + self._num_eps += self.cfg.num_envs return self._num_eps def sample(self): diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..30e4279 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -4,6 +4,7 @@ defaults: # environment task: dog-run obs: state +num_envs: 1 # evaluation checkpoint: ??? @@ -13,6 +14,7 @@ eval_freq: 50000 # training steps: 10_000_000 batch_size: 256 +steps_per_update: 1 reward_coef: 0.1 value_coef: 0.1 consistency_coef: 20 diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 6326a9e..c780efd 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -6,6 +6,8 @@ import gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper +from envs.wrappers.vectorized import Vectorized + def missing_dependencies(task): raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') @@ -59,16 +61,19 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) - else: env = None for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) + break except ValueError: pass if env is None: raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') + assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \ + 'Vectorized environments only support state observations.' + env = Vectorized(cfg, fn) env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) @@ -78,5 +83,5 @@ def make_env(cfg): cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.action_dim = env.action_space.shape[0] cfg.episode_length = env.max_episode_steps - cfg.seed_steps = max(1000, 5*cfg.episode_length) + cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs return env diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..eca8d10 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -177,6 +177,9 @@ class TimeStepToGymWrapper: camera_id = dict(quadruped=2).get(self.domain, camera_id) return self.env.physics.render(height, width, camera_id) + def close(self): + self.env.close() + def make_env(cfg): """ diff --git a/tdmpc2/envs/wrappers/tensor.py b/tdmpc2/envs/wrappers/tensor.py index 548a5f4..d331da6 100644 --- a/tdmpc2/envs/wrappers/tensor.py +++ b/tdmpc2/envs/wrappers/tensor.py @@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper): def __init__(self, env): super().__init__(env) + self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized' def rand_act(self): + if self._wrapped_vectorized: + return self.env.rand_act() return torch.from_numpy(self.action_space.sample().astype(np.float32)) def _try_f32_tensor(self, x): @@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper): obs = self._try_f32_tensor(obs) return obs - def reset(self, task_idx=None): - return self._obs_to_tensor(self.env.reset()) + def reset(self, task_idx=None, **kwargs): + if self._wrapped_vectorized: + obs = self.env.reset(**kwargs) + else: + obs = self.env.reset() + return self._obs_to_tensor(obs) - def step(self, action): - obs, reward, done, info = self.env.step(action.numpy()) - info = defaultdict(float, info) - info['success'] = float(info['success']) + def step(self, action, **kwargs): + if self._wrapped_vectorized: + obs, reward, done, info = self.env.step(action.numpy(), **kwargs) + else: + obs, reward, done, info = self.env.step(action.numpy()) + if isinstance(info, tuple): + info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()} + if 'success' not in info.keys(): + info['success'] = torch.zeros(len(done)) + else: + info = defaultdict(float, info) + info['success'] = float(info['success']) return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info diff --git a/tdmpc2/envs/wrappers/vectorized.py b/tdmpc2/envs/wrappers/vectorized.py new file mode 100644 index 0000000..ca9848f --- /dev/null +++ b/tdmpc2/envs/wrappers/vectorized.py @@ -0,0 +1,40 @@ +from copy import deepcopy + +from gym.vector import AsyncVectorEnv +import numpy as np +import torch + + +class Vectorized(): + """ + Vectorized environment for TD-MPC2 online training. + """ + + def __init__(self, cfg, env_fn): + super().__init__() + self.cfg = cfg + + def make(): + _cfg = deepcopy(cfg) + _cfg.num_envs = 1 + _cfg.seed = cfg.seed + np.random.randint(1000) + return env_fn(_cfg) + + print(f'Creating {cfg.num_envs} environments...') + self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)]) + env = make() + self.observation_space = env.observation_space + self.action_space = env.action_space + self.max_episode_steps = env.max_episode_steps + + def rand_act(self): + return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1 + + def reset(self): + return self.env.reset() + + def step(self, action): + return self.env.step(action) + + def render(self, *args, **kwargs): + return self.env.render(*args, **kwargs) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..d0a54c4 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -81,23 +81,23 @@ class TDMPC2: Returns: torch.Tensor: Action to take in the environment. """ - obs = obs.to(self.device, non_blocking=True).unsqueeze(0) + obs = obs.to(self.device, non_blocking=True) if task is not None: task = torch.tensor([task], device=self.device) z = self.model.encode(obs, task) if self.cfg.mpc: - a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) + action = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) else: - a = self.model.pi(z, task)[int(not eval_mode)][0] - return a.cpu() + action = self.model.pi(z, task)[int(not eval_mode)] + return action.cpu() @torch.no_grad() def _estimate_value(self, z, actions, task): """Estimate value of a trajectory starting at latent state z and executing given actions.""" G, discount = 0, 1 for t in range(self.cfg.horizon): - reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) - z = self.model.next(z, actions[t], task) + reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg) + z = self.model.next(z, actions[:, t], task) G += discount * reward discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') @@ -115,60 +115,72 @@ class TDMPC2: Returns: torch.Tensor: Action to take in the environment. - """ + """ # Sample policy trajectories if self.cfg.num_pi_trajs > 0: - pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) - _z = z.repeat(self.cfg.num_pi_trajs, 1) + pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) + _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1) for t in range(self.cfg.horizon-1): - pi_actions[t] = self.model.pi(_z, task)[1] - _z = self.model.next(_z, pi_actions[t], task) - pi_actions[-1] = self.model.pi(_z, task)[1] + pi_actions[:,t] = self.model.pi(_z, task)[1] + _z = self.model.next(_z, pi_actions[:,t], task) + pi_actions[:,-1] = self.model.pi(_z, task)[1] # Initialize state and parameters - z = z.repeat(self.cfg.num_samples, 1) - mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) + z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1) + mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device) if not t0: - mean[:-1] = self._prev_mean[1:] - actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) + mean[:, :-1] = self._prev_mean[:, 1:] + actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: - actions[:, :self.cfg.num_pi_trajs] = pi_actions + actions[:, :, :self.cfg.num_pi_trajs] = pi_actions # Iterate MPPI for _ in range(self.cfg.iterations): # Sample actions - actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ - torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ + actions[:, :, self.cfg.num_pi_trajs:] = (mean.unsqueeze(2) + std.unsqueeze(2) * \ + torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ .clamp(-1, 1) if self.cfg.multitask: actions = actions * self.model._action_masks[task] # Compute elite actions value = self._estimate_value(z, actions, task).nan_to_num_(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + + elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices + elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2)) + elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim)) + + # vectorized version + # elite_value, elite_actions = [], [] + # for i in range(self.cfg.num_envs): + # elite_value.append(value[i, elite_idxs[i]]) + # elite_actions.append(actions[i, elite_idxs[i]]) + # elite_value = torch.stack(elite_value, dim=0) # Update parameters - max_value = elite_value.max(0)[0] - score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score /= score.sum(0) - mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) - std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ + max_value = elite_value.max(1)[0] + score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1))) + score /= score.sum(1, keepdim=True) + mean = torch.sum(score.unsqueeze(1) * elite_actions, dim=2) / (score.sum(1, keepdim=True) + 1e-9) + std = torch.sqrt(torch.sum(score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2, dim=2) / (score.sum(1, keepdim=True) + 1e-9)) \ .clamp_(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * self.model._action_masks[task] std = std * self.model._action_masks[task] - # Select action - score = score.squeeze(1).cpu().numpy() - actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] + # Select action sequence with probability `score` + score = score.squeeze(1).squeeze(-1).cpu().numpy() + actions = torch.stack([ + elite_actions[i, :, np.random.choice(np.arange(score.shape[1]), p=score[i])] \ + for i in range(score.shape[0])], dim=0) + self._prev_mean = mean - a, std = actions[0], std[0] + action, std = actions[:, 0], std[:, 0] if not eval_mode: - a += std * torch.randn(self.cfg.action_dim, device=std.device) - return a.clamp_(-1, 1) + action += std * torch.randn(self.cfg.action_dim, device=std.device) + return action.clamp_(-1, 1) def update_pi(self, zs, task): """ diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index a3326bc..9a60d83 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -1,6 +1,5 @@ from time import time -import numpy as np import torch from tensordict.tensordict import TensorDict @@ -26,25 +25,25 @@ class OnlineTrainer(Trainer): def eval(self): """Evaluate a TD-MPC2 agent.""" - ep_rewards, ep_successes = [], [] - for i in range(self.cfg.eval_episodes): - obs, done, ep_reward, t = self.env.reset(), False, 0, 0 + ep_rewards = [] + for i in range(self.cfg.eval_episodes // self.cfg.num_envs): + obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0 if self.cfg.save_video: self.logger.video.init(self.env, enabled=(i==0)) - while not done: + while not done.any(): action = self.agent.act(obs, t0=t==0, eval_mode=True) obs, reward, done, info = self.env.step(action) ep_reward += reward t += 1 if self.cfg.save_video: self.logger.video.record(self.env) + assert done.all(), 'Vectorized environments must reset all environments at once.' ep_rewards.append(ep_reward) - ep_successes.append(info['success']) if self.cfg.save_video: self.logger.video.save(self._step) return dict( - episode_reward=np.nanmean(ep_rewards), - episode_success=np.nanmean(ep_successes), + episode_reward=torch.cat(ep_rewards).mean(), + episode_success=info['success'].mean(), ) def to_td(self, obs, action=None, reward=None): @@ -56,17 +55,17 @@ class OnlineTrainer(Trainer): if action is None: action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: - reward = torch.tensor(float('nan')) + reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs) td = TensorDict(dict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), - ), batch_size=(1,)) + ), batch_size=(1, self.cfg.num_envs,)) return td def train(self): """Train a TD-MPC2 agent.""" - train_metrics, done, eval_next = {}, True, True + train_metrics, done, eval_next = {}, torch.tensor(True), True while self._step <= self.cfg.steps: # Evaluate agent periodically @@ -74,7 +73,8 @@ class OnlineTrainer(Trainer): eval_next = True # Reset environment - if done: + if done.any(): + assert done.all(), 'Vectorized environments must reset all environments at once.' if eval_next: eval_metrics = self.eval() eval_metrics.update(self.common_metrics()) @@ -82,13 +82,14 @@ class OnlineTrainer(Trainer): eval_next = False if self._step > 0: + tds = torch.cat(self._tds) train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], + episode_reward=tds['reward'].nansum(0).mean(), + episode_success=info['success'].nanmean(), ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + self._ep_idx = self.buffer.add(tds) obs = self.env.reset() self._tds = [self.to_td(obs)] @@ -104,14 +105,14 @@ class OnlineTrainer(Trainer): # Update agent if self._step >= self.cfg.seed_steps: if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps + num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update) print('Pretraining agent on seed data...') else: - num_updates = 1 + num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update)) for _ in range(num_updates): _train_metrics = self.agent.update(self.buffer) train_metrics.update(_train_metrics) - self._step += 1 + self._step += self.cfg.num_envs self.logger.finish(self.agent)