Compare commits
8 Commits
main
...
distribute
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1afbccb05 | ||
|
|
c218c0ff1b | ||
|
|
d3bff48d58 | ||
|
|
c16f2557bb | ||
|
|
de87519c60 | ||
|
|
20f4064dfa | ||
|
|
c6d1bd85bf | ||
|
|
33555b5982 |
@@ -39,9 +39,9 @@ dependencies:
|
||||
- protobuf==4.25.2
|
||||
- pillow==10.2.0
|
||||
- pyquaternion==0.9.9
|
||||
- tensordict-nightly==2024.1.10
|
||||
- tensordict-nightly==2024.3.26
|
||||
- termcolor==2.4.0
|
||||
- torchrl-nightly==2024.1.10
|
||||
- torchrl-nightly==2024.3.26
|
||||
- transforms3d==0.4.1
|
||||
- trimesh==4.0.9
|
||||
- tqdm==4.66.1
|
||||
|
||||
@@ -12,16 +12,18 @@ class Buffer():
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._device = torch.device('cuda')
|
||||
self._device = torch.device(self.cfg.rank)
|
||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||
self._sampler = SliceSampler(
|
||||
num_slices=self.cfg.batch_size,
|
||||
end_key=None,
|
||||
traj_key='episode',
|
||||
truncated_key=None,
|
||||
strict_length=True,
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
self._num_eps = 0
|
||||
self._num_transitions = 0
|
||||
|
||||
@property
|
||||
def capacity(self):
|
||||
@@ -32,6 +34,11 @@ class Buffer():
|
||||
def num_eps(self):
|
||||
"""Return the number of episodes in the buffer."""
|
||||
return self._num_eps
|
||||
|
||||
@property
def num_transitions(self):
    """Return the value of the buffer's transition counter."""
    n_transitions = self._num_transitions
    return n_transitions
|
||||
|
||||
def _reserve_buffer(self, storage):
|
||||
"""
|
||||
@@ -47,7 +54,11 @@ class Buffer():
|
||||
|
||||
def _init(self, tds):
|
||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||
print(f'Buffer capacity: {self._capacity:,}')
|
||||
if self.cfg.rank == 0:
|
||||
if self.cfg.world_size > 1:
|
||||
print(f'Buffer capacity per process: {self._capacity:,}')
|
||||
else:
|
||||
print(f'Buffer capacity: {self._capacity:,}')
|
||||
mem_free, _ = torch.cuda.mem_get_info()
|
||||
bytes_per_step = sum([
|
||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||
@@ -55,10 +66,15 @@ class Buffer():
|
||||
for v in tds.values()
|
||||
]) / len(tds)
|
||||
total_bytes = bytes_per_step*self._capacity
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
if self.cfg.rank == 0:
|
||||
if self.cfg.world_size > 1:
|
||||
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
|
||||
else:
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
# Heuristic: decide whether to use CUDA or CPU memory
|
||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
||||
print(f'Using {storage_device.upper()} memory for storage.')
|
||||
storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Using {storage_device.upper()} memory for storage.')
|
||||
return self._reserve_buffer(
|
||||
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
||||
)
|
||||
@@ -87,6 +103,7 @@ class Buffer():
|
||||
self._buffer = self._init(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
self._num_transitions += len(td)
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
|
||||
@@ -113,11 +113,13 @@ class Logger:
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
self._eval = []
|
||||
print_run(cfg)
|
||||
if cfg.rank == 0:
|
||||
print_run(cfg)
|
||||
self.project = cfg.get("wandb_project", "none")
|
||||
self.entity = cfg.get("wandb_entity", "none")
|
||||
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||
if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||
if cfg.rank == 0:
|
||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||
cfg.save_agent = False
|
||||
cfg.save_video = False
|
||||
self._wandb = None
|
||||
|
||||
@@ -6,8 +6,8 @@ class RunningScale:
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self._value, percentiles=self._percentiles)
|
||||
|
||||
@@ -3,13 +3,15 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tensordict.tensordict import TensorDict
|
||||
|
||||
from common import layers, math, init
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
"""
|
||||
TD-MPC2 implicit world model architecture.
|
||||
Distributed version of the TD-MPC2 world model architecture.
|
||||
Can be used for both single-task and multi-task experiments.
|
||||
"""
|
||||
|
||||
@@ -17,24 +19,36 @@ class WorldModel(nn.Module):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
if cfg.multitask:
|
||||
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||
self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
||||
for i in range(len(cfg.tasks)):
|
||||
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.__encoder = layers.enc(cfg)
|
||||
self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
|
||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
||||
self.log_std_min = torch.tensor(cfg.log_std_min)
|
||||
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
|
||||
init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
|
||||
self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
|
||||
self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
|
||||
self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
|
||||
self.to(cfg.rank)
|
||||
if cfg.multitask:
|
||||
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
|
||||
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
|
||||
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
|
||||
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
|
||||
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
|
||||
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
|
||||
|
||||
@property
def total_params(self):
    """Return the summed element count of every parameter with requires_grad set."""
    trainable = [param for param in self.parameters() if param.requires_grad]
    return sum(map(lambda param: param.numel(), trainable))
|
||||
|
||||
def __repr__(self):
|
||||
modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]])
|
||||
return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}"
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ eval_episodes: 10
|
||||
eval_freq: 50000
|
||||
|
||||
# training
|
||||
world_size: 1
|
||||
steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
@@ -74,6 +75,7 @@ save_agent: true
|
||||
seed: 1
|
||||
|
||||
# convenience
|
||||
rank: ???
|
||||
work_dir: ???
|
||||
task_title: ???
|
||||
multitask: ???
|
||||
|
||||
@@ -35,7 +35,8 @@ def make_multitask_env(cfg):
|
||||
"""
|
||||
Make a multi-task environment for TD-MPC2 experiments.
|
||||
"""
|
||||
print('Creating multi-task environment with tasks:', cfg.tasks)
|
||||
if cfg.rank == 0:
|
||||
print('Creating multi-task environment with tasks:', cfg.tasks)
|
||||
envs = []
|
||||
for task in cfg.tasks:
|
||||
_cfg = deepcopy(cfg)
|
||||
|
||||
@@ -16,8 +16,8 @@ class TDMPC2:
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.device = torch.device('cuda')
|
||||
self.model = WorldModel(cfg).to(self.device)
|
||||
self.device = torch.device(cfg.rank)
|
||||
self.model = WorldModel(cfg)
|
||||
self.optim = torch.optim.Adam([
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
@@ -30,7 +30,7 @@ class TDMPC2:
|
||||
self.scale = RunningScale(cfg)
|
||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
|
||||
def _get_discount(self, episode_length):
|
||||
|
||||
@@ -14,14 +14,28 @@ from common.buffer import Buffer
|
||||
from envs import make_env
|
||||
from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
from trainer.online_trainer import OnlineTrainer
|
||||
from common.logger import Logger
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
|
||||
def train(cfg: dict):
|
||||
def setup(rank, world_size):
    """
    Initialize the NCCL process group for one worker process.

    Args:
        rank: Index of this process within the group. It is also used as
            the CUDA device index for this process.
        world_size: Total number of processes taking part in the group.
    """
    # Rendezvous endpoint. Values already present in the environment win, so
    # several jobs on the same host can be given distinct ports.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # Bind this process to its own GPU before any NCCL communication, so that
    # device-less CUDA calls made later refer to this rank's device.
    torch.cuda.set_device(rank)

    # initialize the process group
    torch.distributed.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
|
||||
|
||||
|
||||
def cleanup():
    """Destroy the torch.distributed process group."""
    dist = torch.distributed
    dist.destroy_process_group()
|
||||
|
||||
|
||||
def train(rank: int, cfg: dict):
|
||||
"""
|
||||
Script for training single-task / multi-task TD-MPC2 agents.
|
||||
|
||||
@@ -40,14 +54,11 @@ def train(cfg: dict):
|
||||
$ python train.py task=dog-run steps=7000000
|
||||
```
|
||||
"""
|
||||
assert torch.cuda.is_available()
|
||||
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
setup(rank, cfg.world_size)
|
||||
set_seed(cfg.seed + rank)
|
||||
cfg.rank = rank
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
trainer = trainer_cls(
|
||||
trainer = OfflineTrainer(
|
||||
cfg=cfg,
|
||||
env=make_env(cfg),
|
||||
agent=TDMPC2(cfg),
|
||||
@@ -55,8 +66,26 @@ def train(cfg: dict):
|
||||
logger=Logger(cfg),
|
||||
)
|
||||
trainer.train()
|
||||
print('\nTraining completed successfully')
|
||||
if cfg.rank == 0:
|
||||
print('\nTraining completed successfully')
|
||||
cleanup()
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
def launch(cfg: dict):
    """
    Validate the config, then spawn one `train` process per requested GPU.

    Checks that CUDA is available, that `cfg.world_size` is positive and does
    not exceed the number of visible GPUs, that `cfg.task` is one of the
    multi-task sets ('mt30', 'mt80'), and that `cfg.steps` is positive. The
    parsed config is then handed to `cfg.world_size` worker processes, and
    this call blocks until they have all finished.
    """
    assert torch.cuda.is_available()
    assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
    # Each worker uses its rank as a CUDA device index, so fail early with a
    # clear message instead of an invalid-device error inside a child process.
    num_gpus = torch.cuda.device_count()
    assert cfg.world_size <= num_gpus, \
        f'World size {cfg.world_size} cannot be greater than number of visible GPUs {num_gpus}.'
    assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
    assert cfg.steps > 0, 'Must train for at least 1 step.'
    cfg = parse_cfg(cfg)
    print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
    # spawn passes the process index as the first positional argument,
    # which `train` receives as `rank`.
    torch.multiprocessing.spawn(
        train,
        args=(cfg,),
        nprocs=cfg.world_size,
        join=True,
    )
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
launch()
|
||||
|
||||
@@ -7,8 +7,9 @@ class Trainer:
|
||||
self.agent = agent
|
||||
self.buffer = buffer
|
||||
self.logger = logger
|
||||
print('Architecture:', self.agent.model)
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
if cfg.rank == 0:
|
||||
print('Architecture:', self.agent.model)
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
|
||||
@@ -50,12 +50,21 @@ class OfflineTrainer(Trainer):
|
||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||
fps = sorted(glob(str(fp)))
|
||||
assert len(fps) > 0, f'No data found at {fp}'
|
||||
print(f'Found {len(fps)} files in {fp}')
|
||||
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Found {len(fps)} files in {fp}')
|
||||
|
||||
# Distribute data across processes
|
||||
assert len(fps) >= self.cfg.world_size, \
|
||||
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
|
||||
fps = fps[self.cfg.rank::self.cfg.world_size]
|
||||
print(f'Process {self.cfg.rank} has {len(fps)} files')
|
||||
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
|
||||
|
||||
# Create buffer for sampling
|
||||
_cfg = deepcopy(self.cfg)
|
||||
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
|
||||
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
|
||||
_cfg.buffer_size //= self.cfg.world_size
|
||||
_cfg.steps = _cfg.buffer_size
|
||||
self.buffer = Buffer(_cfg)
|
||||
for fp in tqdm(fps, desc='Loading data'):
|
||||
@@ -65,10 +74,12 @@ class OfflineTrainer(Trainer):
|
||||
f'please double-check your config.'
|
||||
for i in range(len(td)):
|
||||
self.buffer.add(td[i])
|
||||
assert self.buffer.num_eps == self.buffer.capacity, \
|
||||
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
|
||||
if self.buffer.num_transitions > self.buffer.capacity:
|
||||
print(f'Buffer has {self.buffer.num_transitions} transitions,' \
|
||||
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
|
||||
|
||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||
metrics = {}
|
||||
for i in range(self.cfg.steps):
|
||||
|
||||
@@ -76,7 +87,7 @@ class OfflineTrainer(Trainer):
|
||||
train_metrics = self.agent.update(self.buffer)
|
||||
|
||||
# Evaluate agent periodically
|
||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
||||
if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
|
||||
metrics = {
|
||||
'iteration': i,
|
||||
'total_time': time() - self._start_time,
|
||||
@@ -89,4 +100,5 @@ class OfflineTrainer(Trainer):
|
||||
self.logger.save_agent(self.agent, identifier=f'{i}')
|
||||
self.logger.log(metrics, 'pretrain')
|
||||
|
||||
self.logger.finish(self.agent)
|
||||
if self.cfg.rank == 0:
|
||||
self.logger.finish(self.agent)
|
||||
|
||||
Reference in New Issue
Block a user