support distributed training

This commit is contained in:
Nicklas Hansen
2024-01-07 11:52:53 -08:00
parent a7ff00b0cd
commit 33555b5982
11 changed files with 125 additions and 46 deletions

View File

@@ -112,6 +112,8 @@ $ python train.py task=walker-walk obs=rgb
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
**As of Jan 7, 2024, the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results as single-GPU training, but the results appear to be similar based on our limited testing.
---- ----
## Citation ## Citation

View File

@@ -13,7 +13,7 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device(self.cfg.rank)
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler( self._sampler = SliceSampler(
num_slices=self.cfg.batch_size, num_slices=self.cfg.batch_size,
@@ -23,6 +23,7 @@ class Buffer():
) )
self._batch_size = cfg.batch_size * (cfg.horizon+1) self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0 self._num_eps = 0
self._num_transitions = 0
@property @property
def capacity(self): def capacity(self):
@@ -34,6 +35,11 @@ class Buffer():
"""Return the number of episodes in the buffer.""" """Return the number of episodes in the buffer."""
return self._num_eps return self._num_eps
@property
def num_transitions(self):
    """Transition count tracked by the buffer (the value of `_num_transitions`)."""
    return self._num_transitions
def _reserve_buffer(self, storage): def _reserve_buffer(self, storage):
""" """
Reserve a buffer with the given storage. Reserve a buffer with the given storage.
@@ -48,6 +54,10 @@ class Buffer():
def _init(self, tds): def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements.""" """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Buffer capacity per process: {self._capacity:,}')
else:
print(f'Buffer capacity: {self._capacity:,}') print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info() mem_free, _ = torch.cuda.mem_get_info()
bytes_per_step = sum([ bytes_per_step = sum([
@@ -56,9 +66,14 @@ class Buffer():
for v in tds.values() for v in tds.values()
]) / len(tds) ]) / len(tds)
total_bytes = bytes_per_step*self._capacity total_bytes = bytes_per_step*self._capacity
if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
else:
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
if self.cfg.rank == 0:
print(f'Using {storage_device.upper()} memory for storage.') print(f'Using {storage_device.upper()} memory for storage.')
return self._reserve_buffer( return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device)) LazyTensorStorage(self._capacity, device=torch.device(storage_device))
@@ -88,6 +103,7 @@ class Buffer():
self._buffer = self._init(td) self._buffer = self._init(td)
self._buffer.extend(td) self._buffer.extend(td)
self._num_eps += 1 self._num_eps += 1
self._num_transitions += len(td)
return self._num_eps return self._num_eps
def sample(self): def sample(self):

View File

@@ -113,10 +113,12 @@ class Logger:
self._group = cfg_to_group(cfg) self._group = cfg_to_group(cfg)
self._seed = cfg.seed self._seed = cfg.seed
self._eval = [] self._eval = []
if cfg.rank == 0:
print_run(cfg) print_run(cfg)
self.project = cfg.get("wandb_project", "none") self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none") self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none": if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
if cfg.rank == 0:
print(colored("Wandb disabled.", "blue", attrs=["bold"])) print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False cfg.save_agent = False
cfg.save_video = False cfg.save_video = False

View File

@@ -6,8 +6,8 @@ class RunningScale:
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
def state_dict(self): def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles) return dict(value=self._value, percentiles=self._percentiles)

View File

@@ -3,13 +3,15 @@ from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from tensordict.tensordict import TensorDict
from common import layers, math, init from common import layers, math, init
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
TD-MPC2 implicit world model architecture. Distributed version of the TD-MPC2 world model architecture.
Can be used for both single-task and multi-task experiments. Can be used for both single-task and multi-task experiments.
""" """
@@ -17,25 +19,37 @@ class WorldModel(nn.Module):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
if cfg.multitask: if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
for i in range(len(cfg.tasks)): for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1. self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg) self.__encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False) self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min) self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
self.to(cfg.rank)
if cfg.multitask:
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
@property @property
def total_params(self): def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters() if p.requires_grad)
def __repr__(self):
modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]])
return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}"
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
""" """
Overriding `to` method to also move additional tensors to device. Overriding `to` method to also move additional tensors to device.

View File

@@ -11,6 +11,7 @@ eval_episodes: 10
eval_freq: 50000 eval_freq: 50000
# training # training
world_size: 1
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
reward_coef: 0.1 reward_coef: 0.1
@@ -74,6 +75,7 @@ save_agent: true
seed: 1 seed: 1
# convenience # convenience
rank: ???
work_dir: ??? work_dir: ???
task_title: ??? task_title: ???
multitask: ??? multitask: ???

View File

@@ -35,6 +35,7 @@ def make_multitask_env(cfg):
""" """
Make a multi-task environment for TD-MPC2 experiments. Make a multi-task environment for TD-MPC2 experiments.
""" """
if cfg.rank == 0:
print('Creating multi-task environment with tasks:', cfg.tasks) print('Creating multi-task environment with tasks:', cfg.tasks)
envs = [] envs = []
for task in cfg.tasks: for task in cfg.tasks:

View File

@@ -16,8 +16,8 @@ class TDMPC2:
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda') self.device = torch.device(cfg.rank)
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg)
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
@@ -30,7 +30,7 @@ class TDMPC2:
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
def _get_discount(self, episode_length): def _get_discount(self, episode_length):

View File

@@ -14,14 +14,28 @@ from common.buffer import Buffer
from envs import make_env from envs import make_env
from tdmpc2 import TDMPC2 from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger from common.logger import Logger
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@hydra.main(config_name='config', config_path='.') def setup(rank, world_size):
def train(cfg: dict): os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
torch.distributed.init_process_group(
backend="nccl",
rank=rank,
world_size=world_size
)
def cleanup():
    """
    Tear down the default torch.distributed process group, if one exists.

    This is a no-op when the distributed package is unavailable or when no
    process group has been initialized, so it is safe to call from error
    paths and safe to call more than once.
    """
    dist = torch.distributed
    # destroy_process_group() raises if there is no default group to destroy,
    # so only call it once we know a group was actually initialized.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def train(rank: int, cfg: dict):
""" """
Script for training single-task / multi-task TD-MPC2 agents. Script for training single-task / multi-task TD-MPC2 agents.
@@ -40,14 +54,11 @@ def train(cfg: dict):
$ python train.py task=dog-run steps=7000000 $ python train.py task=dog-run steps=7000000
``` ```
""" """
assert torch.cuda.is_available() setup(rank, cfg.world_size)
assert cfg.steps > 0, 'Must train for at least 1 step.' set_seed(cfg.seed + rank)
cfg = parse_cfg(cfg) cfg.rank = rank
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = OfflineTrainer(
trainer = trainer_cls(
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),
agent=TDMPC2(cfg), agent=TDMPC2(cfg),
@@ -55,8 +66,26 @@ def train(cfg: dict):
logger=Logger(cfg), logger=Logger(cfg),
) )
trainer.train() trainer.train()
if cfg.rank == 0:
print('\nTraining completed successfully') print('\nTraining completed successfully')
cleanup()
@hydra.main(config_name='config', config_path='.')
def launch(cfg: dict):
    """
    Validate the config and spawn one training process per GPU.

    Checks that CUDA is available, that `world_size` is positive and does not
    exceed the number of visible CUDA devices, that the task is one of the
    multi-task datasets (`mt30` / `mt80`), and that `steps` is positive. The
    config is then parsed with `parse_cfg` and `train` is started in
    `cfg.world_size` processes.

    Args:
        cfg: Hydra config for the run.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert torch.cuda.is_available()
    assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
    # Fail fast with a clear message instead of an opaque device error
    # raised later from inside one of the spawned workers.
    num_gpus = torch.cuda.device_count()
    assert cfg.world_size <= num_gpus, \
        f'world_size={cfg.world_size} exceeds the {num_gpus} visible CUDA device(s).'
    assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
    assert cfg.steps > 0, 'Must train for at least 1 step.'
    cfg = parse_cfg(cfg)
    print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
    # spawn() invokes train(i, cfg) in each child, with i being the process
    # index in [0, nprocs); join=True blocks until all workers have exited.
    torch.multiprocessing.spawn(
        train,
        args=(cfg,),
        nprocs=cfg.world_size,
        join=True,
    )
if __name__ == '__main__': if __name__ == '__main__':
train() launch()

View File

@@ -7,6 +7,7 @@ class Trainer:
self.agent = agent self.agent = agent
self.buffer = buffer self.buffer = buffer
self.logger = logger self.logger = logger
if cfg.rank == 0:
print("Learnable parameters: {:,}".format(self.agent.model.total_params)) print("Learnable parameters: {:,}".format(self.agent.model.total_params))
print('Architecture:', self.agent.model) print('Architecture:', self.agent.model)

View File

@@ -50,12 +50,21 @@ class OfflineTrainer(Trainer):
fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp))) fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}' assert len(fps) > 0, f'No data found at {fp}'
if self.cfg.rank == 0:
print(f'Found {len(fps)} files in {fp}') print(f'Found {len(fps)} files in {fp}')
# Distribute data across processes
assert len(fps) >= self.cfg.world_size, \
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
fps = fps[self.cfg.rank::self.cfg.world_size]
print(f'Process {self.cfg.rank} has {len(fps)} files')
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
# Create buffer for sampling # Create buffer for sampling
_cfg = deepcopy(self.cfg) _cfg = deepcopy(self.cfg)
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501 _cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000 _cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
_cfg.buffer_size //= self.cfg.world_size
_cfg.steps = _cfg.buffer_size _cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg) self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'): for fp in tqdm(fps, desc='Loading data'):
@@ -65,9 +74,11 @@ class OfflineTrainer(Trainer):
f'please double-check your config.' f'please double-check your config.'
for i in range(len(td)): for i in range(len(td)):
self.buffer.add(td[i]) self.buffer.add(td[i])
assert self.buffer.num_eps == self.buffer.capacity, \ if self.buffer.num_transitions > self.buffer.capacity:
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' print(f'Buffer has {self.buffer.num_transitions} transitions,' \
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
if self.cfg.rank == 0:
print(f'Training agent for {self.cfg.steps} iterations...') print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {} metrics = {}
for i in range(self.cfg.steps): for i in range(self.cfg.steps):
@@ -76,7 +87,7 @@ class OfflineTrainer(Trainer):
train_metrics = self.agent.update(self.buffer) train_metrics = self.agent.update(self.buffer)
# Evaluate agent periodically # Evaluate agent periodically
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
metrics = { metrics = {
'iteration': i, 'iteration': i,
'total_time': time() - self._start_time, 'total_time': time() - self._start_time,
@@ -89,4 +100,5 @@ class OfflineTrainer(Trainer):
self.logger.save_agent(self.agent, identifier=f'{i}') self.logger.save_agent(self.agent, identifier=f'{i}')
self.logger.log(metrics, 'pretrain') self.logger.log(metrics, 'pretrain')
if self.cfg.rank == 0:
self.logger.finish(self.agent) self.logger.finish(self.agent)