From fa41a3e45093f3e938d02bd5aa08f509cd61ce46 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 11 Feb 2024 14:41:20 -0800 Subject: [PATCH] init --- tdmpc2/common/buffer.py | 10 +++--- tdmpc2/config.yaml | 2 ++ tdmpc2/envs/__init__.py | 9 +++-- tdmpc2/envs/dmcontrol.py | 3 ++ tdmpc2/envs/wrappers/tensor.py | 27 +++++++++++---- tdmpc2/envs/wrappers/vectorized.py | 40 ++++++++++++++++++++++ tdmpc2/tdmpc2.py | 55 +++++++++++++++--------------- tdmpc2/trainer/online_trainer.py | 41 +++++++++++----------- 8 files changed, 128 insertions(+), 59 deletions(-) create mode 100644 tdmpc2/envs/wrappers/vectorized.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 3ff5b28..e30fe65 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -81,11 +81,13 @@ class Buffer(): def add(self, td): """Add an episode to the buffer.""" - td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs) + td = td.permute(1, 0) if self._num_eps == 0: - self._buffer = self._init(td) - self._buffer.extend(td) - self._num_eps += 1 + self._buffer = self._init(td[0]) + for i in range(self.cfg.num_envs): + self._buffer.extend(td[i]) + self._num_eps += self.cfg.num_envs return self._num_eps def sample(self): diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 597c829..214140e 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -4,6 +4,7 @@ defaults: # environment task: dog-run obs: state +num_envs: 1 # evaluation checkpoint: ??? 
@@ -13,6 +14,7 @@ eval_freq: 50000 # training steps: 10_000_000 batch_size: 256 +steps_per_update: 1 reward_coef: 0.1 value_coef: 0.1 consistency_coef: 20 diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 6326a9e..c780efd 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -6,6 +6,8 @@ import gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper +from envs.wrappers.vectorized import Vectorized + def missing_dependencies(task): raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') @@ -59,16 +61,19 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) - else: env = None for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) + break except ValueError: pass if env is None: raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') + assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \ + 'Vectorized environments only support state observations.' 
+ env = Vectorized(cfg, fn) env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) @@ -78,5 +83,5 @@ def make_env(cfg): cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.action_dim = env.action_space.shape[0] cfg.episode_length = env.max_episode_steps - cfg.seed_steps = max(1000, 5*cfg.episode_length) + cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs return env diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..eca8d10 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -177,6 +177,9 @@ class TimeStepToGymWrapper: camera_id = dict(quadruped=2).get(self.domain, camera_id) return self.env.physics.render(height, width, camera_id) + def close(self): + self.env.close() + def make_env(cfg): """ diff --git a/tdmpc2/envs/wrappers/tensor.py b/tdmpc2/envs/wrappers/tensor.py index 548a5f4..d331da6 100644 --- a/tdmpc2/envs/wrappers/tensor.py +++ b/tdmpc2/envs/wrappers/tensor.py @@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper): def __init__(self, env): super().__init__(env) + self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized' def rand_act(self): + if self._wrapped_vectorized: + return self.env.rand_act() return torch.from_numpy(self.action_space.sample().astype(np.float32)) def _try_f32_tensor(self, x): @@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper): obs = self._try_f32_tensor(obs) return obs - def reset(self, task_idx=None): - return self._obs_to_tensor(self.env.reset()) + def reset(self, task_idx=None, **kwargs): + if self._wrapped_vectorized: + obs = self.env.reset(**kwargs) + else: + obs = self.env.reset() + return self._obs_to_tensor(obs) - def step(self, action): - obs, reward, done, info = self.env.step(action.numpy()) - info = defaultdict(float, info) - info['success'] = float(info['success']) + def step(self, action, **kwargs): + if self._wrapped_vectorized: + obs, reward, done, info = self.env.step(action.numpy(), 
**kwargs) + else: + obs, reward, done, info = self.env.step(action.numpy()) + if isinstance(info, tuple): + info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()} + if 'success' not in info.keys(): + info['success'] = torch.zeros(len(done)) + else: + info = defaultdict(float, info) + info['success'] = float(info['success']) return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
diff --git a/tdmpc2/envs/wrappers/vectorized.py b/tdmpc2/envs/wrappers/vectorized.py
new file mode 100644
index 0000000..ca9848f
--- /dev/null
+++ b/tdmpc2/envs/wrappers/vectorized.py
@@ -0,0 +1,59 @@
+from copy import deepcopy
+from functools import partial
+
+from gym.vector import AsyncVectorEnv
+import numpy as np
+import torch
+
+
+class Vectorized():
+	"""
+	Vectorized environment for TD-MPC2 online training.
+
+	Steps `cfg.num_envs` copies of one environment in worker processes
+	through `gym.vector.AsyncVectorEnv`.
+	"""
+
+	def __init__(self, cfg, env_fn):
+		super().__init__()
+		self.cfg = cfg
+
+		def make(idx=0):
+			# Seed from the worker index, not the global NumPy RNG: workers
+			# started by fork inherit the same RNG state and would all draw
+			# the same offset, yielding identically seeded environments.
+			_cfg = deepcopy(cfg)
+			_cfg.num_envs = 1
+			_cfg.seed = cfg.seed + idx
+			return env_fn(_cfg)
+
+		print(f'Creating {cfg.num_envs} environments...')
+		self.env = AsyncVectorEnv([partial(make, i) for i in range(cfg.num_envs)])
+		# Probe one local instance for spaces and episode length, then
+		# release it so it does not linger next to the workers.
+		env = make()
+		self.observation_space = env.observation_space
+		self.action_space = env.action_space
+		self.max_episode_steps = env.max_episode_steps
+		if hasattr(env, 'close'):
+			env.close()
+
+	def rand_act(self):
+		"""Sample uniform actions in [-1, 1) with shape (num_envs, *action_shape)."""
+		return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
+
+	def reset(self):
+		"""Forward to the underlying AsyncVectorEnv."""
+		return self.env.reset()
+
+	def step(self, action):
+		"""Forward to the underlying AsyncVectorEnv."""
+		return self.env.step(action)
+
+	def render(self, *args, **kwargs):
+		"""Forward to the underlying AsyncVectorEnv."""
+		return self.env.render(*args, **kwargs)
+
+	def close(self):
+		"""Close the underlying AsyncVectorEnv."""
+		return self.env.close()
diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index e4d8ec2..0e1a7bd 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -114,8 +114,8 @@ class TDMPC2(torch.nn.Module): """Estimate value of a trajectory starting at latent state z and executing given actions.""" G, discount = 0, 1 for t in range(self.cfg.horizon): - reward = math.two_hot_inv(self.model.reward(z, actions[t], task), 
self.cfg) - z = self.model.next(z, actions[t], task) + reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg) + z = self.model.next(z, actions[:, t], task) G = G + discount * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update @@ -138,45 +138,45 @@ class TDMPC2(torch.nn.Module): # Sample policy trajectories z = self.model.encode(obs, task) if self.cfg.num_pi_trajs > 0: - pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) - _z = z.repeat(self.cfg.num_pi_trajs, 1) + pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) + _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1) for t in range(self.cfg.horizon-1): - pi_actions[t] = self.model.pi(_z, task)[1] - _z = self.model.next(_z, pi_actions[t], task) - pi_actions[-1] = self.model.pi(_z, task)[1] + pi_actions[:,t] = self.model.pi(_z, task)[1] + _z = self.model.next(_z, pi_actions[:,t], task) + pi_actions[:,-1] = self.model.pi(_z, task)[1] # Initialize state and parameters - z = z.repeat(self.cfg.num_samples, 1) - mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) + z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1) + mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device) if not t0: - mean[:-1] = self._prev_mean[1:] - actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) + mean[:, :-1] = self._prev_mean[:, 1:] + actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if 
self.cfg.num_pi_trajs > 0: - actions[:, :self.cfg.num_pi_trajs] = pi_actions + actions[:, :, :self.cfg.num_pi_trajs] = pi_actions # Iterate MPPI for _ in range(self.cfg.iterations): # Sample actions - r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) - actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r + r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) + actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r actions_sample = actions_sample.clamp(-1, 1) - actions[:, self.cfg.num_pi_trajs:] = actions_sample + actions[:, :, self.cfg.num_pi_trajs:] = actions_sample if self.cfg.multitask: actions = actions * self.model._action_masks[task] # Compute elite actions value = self._estimate_value(z, actions, task).nan_to_num(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices + elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs] # Update parameters - max_value = elite_value.max(0).values - score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score = score / score.sum(0) - mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) - std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() + max_value = elite_value.max(1).values + score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1))) + score = score / score.sum(1) + mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9) + std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt() std = std.clamp(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * 
self.model._action_masks[task] @@ -185,12 +185,13 @@ class TDMPC2(torch.nn.Module): # Select action rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) - a, std = actions[0], std[0] + action, std = actions[:, 0], std[:, 0] if not eval_mode: - a = a + std * torch.randn(self.cfg.action_dim, device=std.device) + action = action + std * torch.randn(self.cfg.action_dim, device=std.device) self._prev_mean.copy_(mean) - return a.clamp(-1, 1) - + return action.clamp(-1, 1) + + def update_pi(self, zs, task): """ Update policy using a sequence of latent states. diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 0d2f062..1ad26ab 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -1,6 +1,5 @@ from time import time -import numpy as np import torch from tensordict.tensordict import TensorDict from trainer.base import Trainer @@ -25,12 +24,12 @@ class OnlineTrainer(Trainer): def eval(self): """Evaluate a TD-MPC2 agent.""" - ep_rewards, ep_successes = [], [] - for i in range(self.cfg.eval_episodes): - obs, done, ep_reward, t = self.env.reset(), False, 0, 0 + ep_rewards = [] + for i in range(self.cfg.eval_episodes // self.cfg.num_envs): + obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0 if self.cfg.save_video: self.logger.video.init(self.env, enabled=(i==0)) - while not done: + while not done.any(): torch.compiler.cudagraph_mark_step_begin() action = self.agent.act(obs, t0=t==0, eval_mode=True) obs, reward, done, info = self.env.step(action) @@ -38,13 +37,13 @@ class OnlineTrainer(Trainer): t += 1 if self.cfg.save_video: self.logger.video.record(self.env) + assert done.all(), 'Vectorized environments must reset all environments at once.' 
ep_rewards.append(ep_reward) - ep_successes.append(info['success']) if self.cfg.save_video: self.logger.video.save(self._step) return dict( - episode_reward=np.nanmean(ep_rewards), - episode_success=np.nanmean(ep_successes), + episode_reward=torch.cat(ep_rewards).mean(), + episode_success=info['success'].mean(), ) def to_td(self, obs, action=None, reward=None): @@ -56,24 +55,25 @@ class OnlineTrainer(Trainer): if action is None: action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: - reward = torch.tensor(float('nan')) - td = TensorDict( + reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs) + td = TensorDict(dict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), - batch_size=(1,)) + ), batch_size=(1, self.cfg.num_envs,)) return td def train(self): """Train a TD-MPC2 agent.""" - train_metrics, done, eval_next = {}, True, False + train_metrics, done, eval_next = {}, torch.tensor(True), True while self._step <= self.cfg.steps: # Evaluate agent periodically if self._step % self.cfg.eval_freq == 0: eval_next = True # Reset environment - if done: + if done.any(): + assert done.all(), 'Vectorized environments must reset all environments at once.' 
if eval_next: eval_metrics = self.eval() eval_metrics.update(self.common_metrics()) @@ -81,13 +81,14 @@ class OnlineTrainer(Trainer): eval_next = False if self._step > 0: + tds = torch.cat(self._tds) train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], + episode_reward=tds['reward'].nansum(0).mean(), + episode_success=info['success'].nanmean(), ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + self._ep_idx = self.buffer.add(tds) obs = self.env.reset() self._tds = [self.to_td(obs)] @@ -103,14 +104,14 @@ class OnlineTrainer(Trainer): # Update agent if self._step >= self.cfg.seed_steps: if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps + num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update) print('Pretraining agent on seed data...') else: - num_updates = 1 + num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update)) for _ in range(num_updates): _train_metrics = self.agent.update(self.buffer) train_metrics.update(_train_metrics) - self._step += 1 - + self._step += self.cfg.num_envs + self.logger.finish(self.agent)