From c6d1bd85bf6d86f6501a943bf69956efb381e6ee Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 11:52:53 -0800 Subject: [PATCH] support distributed training --- tdmpc2/common/buffer.py | 26 ++++++++++++--- tdmpc2/common/logger.py | 8 +++-- tdmpc2/common/scale.py | 4 +-- tdmpc2/common/world_model.py | 36 ++++++++++++++------- tdmpc2/config.yaml | 2 ++ tdmpc2/envs/__init__.py | 3 +- tdmpc2/tdmpc2.py | 6 ++-- tdmpc2/train.py | 53 ++++++++++++++++++++++++------- tdmpc2/trainer/base.py | 5 +-- tdmpc2/trainer/offline_trainer.py | 26 +++++++++++---- 10 files changed, 123 insertions(+), 46 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 29cc293..512d613 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -13,7 +13,7 @@ class Buffer(): def __init__(self, cfg): self.cfg = cfg - self._device = torch.device('cuda') + self._device = torch.device(self.cfg.rank) self._capacity = min(cfg.buffer_size, cfg.steps) self._sampler = SliceSampler( num_slices=self.cfg.batch_size, @@ -23,6 +23,7 @@ class Buffer(): ) self._batch_size = cfg.batch_size * (cfg.horizon+1) self._num_eps = 0 + self._num_transitions = 0 @property def capacity(self): @@ -33,6 +34,11 @@ class Buffer(): def num_eps(self): """Return the number of episodes in the buffer.""" return self._num_eps + + @property + def num_transitions(self): + """Return the number of transitions in the buffer.""" + return self._num_transitions def _reserve_buffer(self, storage): """ @@ -48,7 +54,11 @@ class Buffer(): def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" - print(f'Buffer capacity: {self._capacity:,}') + if self.cfg.rank == 0: + if self.cfg.world_size > 1: + print(f'Buffer capacity per process: {self._capacity:,}') + else: + print(f'Buffer capacity: {self._capacity:,}') mem_free, _ = torch.cuda.mem_get_info() bytes_per_step = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ @@ -56,10 +66,15 @@ class Buffer(): for v in tds.values() ]) / len(tds) total_bytes = bytes_per_step*self._capacity - print(f'Storage required: {total_bytes/1e9:.2f} GB') + if self.cfg.rank == 0: + if self.cfg.world_size > 1: + print(f'Storage required per process: {total_bytes/1e9:.2f} GB') + else: + print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory - storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' - print(f'Using {storage_device.upper()} memory for storage.') + storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu' + if self.cfg.rank == 0: + print(f'Using {storage_device.upper()} memory for storage.') return self._reserve_buffer( LazyTensorStorage(self._capacity, device=torch.device(storage_device)) ) @@ -88,6 +103,7 @@ class Buffer(): self._buffer = self._init(td) self._buffer.extend(td) self._num_eps += 1 + self._num_transitions += len(td) return self._num_eps def sample(self): diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index ea26996..96ad50b 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -113,11 +113,13 @@ class Logger: self._group = cfg_to_group(cfg) self._seed = cfg.seed self._eval = [] - print_run(cfg) + if cfg.rank == 0: + print_run(cfg) self.project = cfg.get("wandb_project", "none") self.entity = cfg.get("wandb_entity", "none") - if cfg.disable_wandb or self.project == "none" or self.entity == "none": - print(colored("Wandb disabled.", "blue", attrs=["bold"])) + if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none": + if cfg.rank == 0: + print(colored("Wandb disabled.", "blue", attrs=["bold"])) cfg.save_agent = False cfg.save_video = False self._wandb = None diff --git a/tdmpc2/common/scale.py b/tdmpc2/common/scale.py index 63f0bb2..2430dd7 100644 --- a/tdmpc2/common/scale.py +++ b/tdmpc2/common/scale.py @@ -6,8 +6,8 @@ class RunningScale: def __init__(self, cfg): self.cfg = cfg - self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) - self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) + self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank)) + self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank)) def state_dict(self): return dict(value=self._value, percentiles=self._percentiles) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index a780ad0..996da29 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -3,13 +3,15 @@ from copy import deepcopy import numpy as np import torch import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from tensordict.tensordict import TensorDict from common import layers, math, init class WorldModel(nn.Module): """ - TD-MPC2 implicit world model architecture. + Distributed version of the TD-MPC2 world model architecture. Can be used for both single-task and multi-task experiments. """ @@ -17,24 +19,36 @@ class WorldModel(nn.Module): super().__init__() self.cfg = cfg if cfg.multitask: - self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) + self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) for i in range(len(cfg.tasks)): self._action_masks[i, :cfg.action_dims[i]] = 1. - self._encoder = layers.enc(cfg) - self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) - self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) - self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) + self.__encoder = layers.enc(cfg) + self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) + self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) + self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) + self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) - init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) - self._target_Qs = deepcopy(self._Qs).requires_grad_(False) - self.log_std_min = torch.tensor(cfg.log_std_min) - self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min + init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]]) + self._target_Qs = deepcopy(self.__Qs).requires_grad_(False) + self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False) + self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min + self.to(cfg.rank) + if cfg.multitask: + self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank]) + self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()}) + self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank]) + self._reward = DDP(self.__reward, device_ids=[cfg.rank]) + self._pi = DDP(self.__pi, device_ids=[cfg.rank]) + self._Qs = DDP(self.__Qs, device_ids=[cfg.rank]) @property def total_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def __repr__(self): + modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]]) + return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}" def to(self, *args, **kwargs): """ diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..45f3711 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -11,6 +11,7 @@ eval_episodes: 10 eval_freq: 50000 # training +world_size: 1 steps: 10_000_000 batch_size: 256 reward_coef: 0.1 @@ -74,6 +75,7 @@ save_agent: true seed: 1 # convenience +rank: ??? work_dir: ??? task_title: ??? multitask: ??? diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 6326a9e..75f5cc3 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -35,7 +35,8 @@ def make_multitask_env(cfg): """ Make a multi-task environment for TD-MPC2 experiments. """ - print('Creating multi-task environment with tasks:', cfg.tasks) + if cfg.rank == 0: + print('Creating multi-task environment with tasks:', cfg.tasks) envs = [] for task in cfg.tasks: _cfg = deepcopy(cfg) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..92bd58e 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -16,8 +16,8 @@ class TDMPC2: def __init__(self, cfg): self.cfg = cfg - self.device = torch.device('cuda') - self.model = WorldModel(cfg).to(self.device) + self.device = torch.device(cfg.rank) + self.model = WorldModel(cfg) self.optim = torch.optim.Adam([ {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, @@ -30,7 +30,7 @@ class TDMPC2: self.scale = RunningScale(cfg) self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.discount = torch.tensor( - [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' + [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank ) if self.cfg.multitask else self._get_discount(cfg.episode_length) def _get_discount(self, episode_length): diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5953bb2..f5ac2b1 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -14,14 +14,28 @@ from common.buffer import Buffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer -from trainer.online_trainer import OnlineTrainer from common.logger import Logger torch.backends.cudnn.benchmark = True -@hydra.main(config_name='config', config_path='.') -def train(cfg: dict): +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + torch.distributed.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size + ) + + +def cleanup(): + torch.distributed.destroy_process_group() + + +def train(rank: int, cfg: dict): """ Script for training single-task / multi-task TD-MPC2 agents. @@ -40,14 +54,11 @@ def train(cfg: dict): $ python train.py task=dog-run steps=7000000 ``` """ - assert torch.cuda.is_available() - assert cfg.steps > 0, 'Must train for at least 1 step.' - cfg = parse_cfg(cfg) - set_seed(cfg.seed) - print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) + setup(rank, cfg.world_size) + set_seed(cfg.seed + rank) + cfg.rank = rank - trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer - trainer = trainer_cls( + trainer = OfflineTrainer( cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), @@ -55,8 +66,26 @@ def train(cfg: dict): logger=Logger(cfg), ) trainer.train() - print('\nTraining completed successfully') + if cfg.rank == 0: + print('\nTraining completed successfully') + cleanup() + + +@hydra.main(config_name='config', config_path='.') +def launch(cfg: dict): + assert torch.cuda.is_available() + assert cfg.world_size > 0, 'Must train with at least 1 GPU.' + assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.' + assert cfg.steps > 0, 'Must train for at least 1 step.' + cfg = parse_cfg(cfg) + print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) + torch.multiprocessing.spawn( + train, + args=(cfg,), + nprocs=cfg.world_size, + join=True, + ) if __name__ == '__main__': - train() + launch() diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index aaf1a39..3c4d0b1 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -7,8 +7,9 @@ class Trainer: self.agent = agent self.buffer = buffer self.logger = logger - print("Learnable parameters: {:,}".format(self.agent.model.total_params)) - print('Architecture:', self.agent.model) + if cfg.rank == 0: + print("Learnable parameters: {:,}".format(self.agent.model.total_params)) + print('Architecture:', self.agent.model) def eval(self): """Evaluate a TD-MPC2 agent.""" diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index 1bace8e..a9b74b1 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -50,12 +50,21 @@ class OfflineTrainer(Trainer): fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fps = sorted(glob(str(fp))) assert len(fps) > 0, f'No data found at {fp}' - print(f'Found {len(fps)} files in {fp}') - + if self.cfg.rank == 0: + print(f'Found {len(fps)} files in {fp}') + + # Distribute data across processes + assert len(fps) >= self.cfg.world_size, \ + f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}' + fps = fps[self.cfg.rank::self.cfg.world_size] + print(f'Process {self.cfg.rank} has {len(fps)} files') + assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}' + # Create buffer for sampling _cfg = deepcopy(self.cfg) _cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501 _cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000 + _cfg.buffer_size //= self.cfg.world_size _cfg.steps = _cfg.buffer_size self.buffer = Buffer(_cfg) for fp in tqdm(fps, desc='Loading data'): @@ -65,10 +74,12 @@ class OfflineTrainer(Trainer): f'please double-check your config.' for i in range(len(td)): self.buffer.add(td[i]) - assert self.buffer.num_eps == self.buffer.capacity, \ - f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' + if self.buffer.num_transitions > self.buffer.capacity: + print(f'Buffer has {self.buffer.num_transitions} transitions,' \ + f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.') - print(f'Training agent for {self.cfg.steps} iterations...') + if self.cfg.rank == 0: + print(f'Training agent for {self.cfg.steps} iterations...') metrics = {} for i in range(self.cfg.steps): @@ -76,7 +87,7 @@ class OfflineTrainer(Trainer): train_metrics = self.agent.update(self.buffer) # Evaluate agent periodically - if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: + if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0): metrics = { 'iteration': i, 'total_time': time() - self._start_time, @@ -89,4 +100,5 @@ class OfflineTrainer(Trainer): self.logger.save_agent(self.agent, identifier=f'{i}') self.logger.log(metrics, 'pretrain') - self.logger.finish(self.agent) + if self.cfg.rank == 0: + self.logger.finish(self.agent)