12 Commits

Author SHA1 Message Date
Nicklas Hansen
4c03df676c update pinned torchrl version 2024-07-02 10:12:30 -07:00
Nicklas Hansen
8c299529a8 Update README.md 2024-07-02 10:12:30 -07:00
Nicklas Hansen
e96d4ae1a6 reduce # wandb calls 2024-07-02 10:12:30 -07:00
Nicklas Hansen
d28b03b3f9 update dockerfile 2024-07-02 10:12:30 -07:00
Nicklas Hansen
614122644d update dockerfile + pin all versions 2024-07-02 10:12:30 -07:00
Nicklas Hansen
dc39c23067 minor fix in print 2024-07-02 10:12:30 -07:00
Nicklas Hansen
173131ca48 migrate to slicebuffer from torchrl-nightly 2024-07-02 10:12:30 -07:00
Nicklas Hansen
594299d7d1 Merge branch 'uncertainty-regularization' of github.com:nicklashansen/tdmpc2 into uncertainty-regularization 2024-01-08 11:00:17 -08:00
Nicklas Hansen
188bd201aa disable uncertainty estimation when coef=0 2024-01-08 10:55:46 -08:00
Nicklas Hansen
392b16ac89 add uncertainty regularization 2024-01-08 10:55:46 -08:00
Nicklas Hansen
e5c9029c86 disable uncertainty estimation when coef=0 2024-01-04 19:39:44 -08:00
Nicklas Hansen
194c92331c add uncertainty regularization 2024-01-03 18:11:32 -08:00
10 changed files with 58 additions and 125 deletions

View File

@@ -12,7 +12,7 @@ class Buffer():
def __init__(self, cfg):
self.cfg = cfg
self._device = torch.device(self.cfg.rank)
self._device = torch.device('cuda')
self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler(
num_slices=self.cfg.batch_size,
@@ -23,7 +23,6 @@ class Buffer():
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0
self._num_transitions = 0
@property
def capacity(self):
@@ -34,11 +33,6 @@ class Buffer():
def num_eps(self):
"""Return the number of episodes in the buffer."""
return self._num_eps
@property
def num_transitions(self):
    """Running transition count kept in ``_num_transitions`` (read-only)."""
    transition_count = self._num_transitions
    return transition_count
def _reserve_buffer(self, storage):
"""
@@ -54,11 +48,7 @@ class Buffer():
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Buffer capacity per process: {self._capacity:,}')
else:
print(f'Buffer capacity: {self._capacity:,}')
print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
@@ -66,15 +56,10 @@ class Buffer():
for v in tds.values()
]) / len(tds)
total_bytes = bytes_per_step*self._capacity
if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
else:
print(f'Storage required: {total_bytes/1e9:.2f} GB')
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
if self.cfg.rank == 0:
print(f'Using {storage_device.upper()} memory for storage.')
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
)
@@ -103,7 +88,6 @@ class Buffer():
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_eps += 1
self._num_transitions += len(td)
return self._num_eps
def sample(self):

View File

@@ -113,13 +113,11 @@ class Logger:
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._eval = []
if cfg.rank == 0:
print_run(cfg)
print_run(cfg)
self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none")
if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
if cfg.rank == 0:
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False
cfg.save_video = False
self._wandb = None

View File

@@ -6,8 +6,8 @@ class RunningScale:
def __init__(self, cfg):
self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
def state_dict(self):
    """Return the scale state as a dict with the keys ``value`` and ``percentiles``."""
    # The stored objects are returned as-is (no copy is made here).
    return {
        'value': self._value,
        'percentiles': self._percentiles,
    }

View File

@@ -3,15 +3,13 @@ from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from tensordict.tensordict import TensorDict
from common import layers, math, init
class WorldModel(nn.Module):
"""
Distributed version of the TD-MPC2 world model architecture.
TD-MPC2 implicit world model architecture.
Can be used for both single-task and multi-task experiments.
"""
@@ -19,36 +17,24 @@ class WorldModel(nn.Module):
super().__init__()
self.cfg = cfg
if cfg.multitask:
self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1.
self.__encoder = layers.enc(cfg)
self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init)
init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
self.to(cfg.rank)
if cfg.multitask:
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min)
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
@property
def total_params(self):
    """Count the elements of every parameter that has ``requires_grad`` set."""
    count = 0
    for param in self.parameters():
        # Parameters with gradients disabled are excluded from the total.
        if param.requires_grad:
            count += param.numel()
    return count
def __repr__(self):
    """
    Build a printable summary of the model.

    The string is the class name followed, in parentheses, by the ``str()`` of the
    encoder, dynamics, reward, policy and Q components (one per line), and then a
    final line with the learnable parameter count using thousands separators.
    """
    components = (self._encoder, self._dynamics, self._reward, self._pi, self._Qs)
    body = '\n'.join(map(str, components))
    header = self.__class__.__name__
    return f"{header}({body})\nLearnable parameters: {self.total_params:,}"
def to(self, *args, **kwargs):
"""

View File

@@ -11,7 +11,6 @@ eval_episodes: 10
eval_freq: 50000
# training
world_size: 1
steps: 10_000_000
batch_size: 256
reward_coef: 0.1
@@ -39,6 +38,7 @@ horizon: 3
min_std: 0.05
max_std: 2
temperature: 0.5
uncertainty_coef: 0
# actor
log_std_min: -10
@@ -75,7 +75,6 @@ save_agent: true
seed: 1
# convenience
rank: ???
work_dir: ???
task_title: ???
multitask: ???

View File

@@ -35,8 +35,7 @@ def make_multitask_env(cfg):
"""
Make a multi-task environment for TD-MPC2 experiments.
"""
if cfg.rank == 0:
print('Creating multi-task environment with tasks:', cfg.tasks)
print('Creating multi-task environment with tasks:', cfg.tasks)
envs = []
for task in cfg.tasks:
_cfg = deepcopy(cfg)

View File

@@ -16,8 +16,8 @@ class TDMPC2:
def __init__(self, cfg):
self.cfg = cfg
self.device = torch.device(cfg.rank)
self.model = WorldModel(cfg)
self.device = torch.device('cuda')
self.model = WorldModel(cfg).to(self.device)
self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
@@ -30,7 +30,7 @@ class TDMPC2:
self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
def _get_discount(self, episode_length):
@@ -90,6 +90,14 @@ class TDMPC2:
else:
a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu()
@torch.no_grad()
def _estimate_uncertainty(self, z, task):
    """
    Compute the uncertainty term that the value estimate subtracts.

    Returns the integer ``0`` when ``cfg.uncertainty_coef`` is 0, without calling
    the model. Otherwise all Q outputs (``return_type='all'``) are evaluated at
    the policy output for ``z``, converted with ``math.two_hot_inv``, and the
    result is ``q.mean() * q.std(0) * cfg.uncertainty_coef``.
    """
    coef = self.cfg.uncertainty_coef
    if coef == 0:
        return 0
    # presumably element 1 of pi's output is the sampled action -- TODO confirm
    action = self.model.pi(z, task)[1]
    q_logits = self.model.Q(z, action, task, return_type='all')
    q_values = math.two_hot_inv(q_logits, self.cfg)
    # NOTE(review): the std over dim 0 (presumably the ensemble axis) is MULTIPLIED
    # by the mean over all elements; the earlier docstring said "normalized by
    # predicted value" -- confirm that a product rather than a ratio is intended.
    return q_values.mean() * q_values.std(0) * coef
@torch.no_grad()
def _estimate_value(self, z, actions, task):
@@ -98,9 +106,10 @@ class TDMPC2:
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
G += discount * (reward - self._estimate_uncertainty(z, task))
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):

View File

@@ -14,28 +14,14 @@ from common.buffer import Buffer
from envs import make_env
from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger
torch.backends.cudnn.benchmark = True
def setup(rank, world_size):
    """
    Prepare the current process for distributed training.

    Points the rendezvous at ``localhost:12355`` through the ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables, then joins the ``nccl`` process group
    as ``rank`` out of ``world_size``.
    """
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12355")
    # initialize the process group
    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the ``torch.distributed`` process group."""
    dist = torch.distributed
    dist.destroy_process_group()
def train(rank: int, cfg: dict):
@hydra.main(config_name='config', config_path='.')
def train(cfg: dict):
"""
Script for training single-task / multi-task TD-MPC2 agents.
@@ -54,11 +40,14 @@ def train(rank: int, cfg: dict):
$ python train.py task=dog-run steps=7000000
```
"""
setup(rank, cfg.world_size)
set_seed(cfg.seed + rank)
cfg.rank = rank
assert torch.cuda.is_available()
assert cfg.steps > 0, 'Must train for at least 1 step.'
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer = OfflineTrainer(
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls(
cfg=cfg,
env=make_env(cfg),
agent=TDMPC2(cfg),
@@ -66,26 +55,8 @@ def train(rank: int, cfg: dict):
logger=Logger(cfg),
)
trainer.train()
if cfg.rank == 0:
print('\nTraining completed successfully')
cleanup()
@hydra.main(config_name='config', config_path='.')
def launch(cfg: dict):
    """
    Entry point: validate the config, then spawn one training process per GPU.

    Asserts that CUDA is available, that ``cfg.world_size`` and ``cfg.steps`` are
    positive, and that ``cfg.task`` is ``mt30`` or ``mt80``. The config is then
    passed through ``parse_cfg`` and ``train`` is started in ``cfg.world_size``
    processes, which are joined before the completion message is printed.
    """
    # Fail fast before any worker process is created.
    assert torch.cuda.is_available()
    assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
    assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
    assert cfg.steps > 0, 'Must train for at least 1 step.'

    cfg = parse_cfg(cfg)
    work_dir_label = colored('Work dir:', 'yellow', attrs=['bold'])
    print(work_dir_label, cfg.work_dir)

    # spawn() supplies the process index as the first argument of `train`;
    # `cfg` is forwarded as the remaining argument.
    torch.multiprocessing.spawn(train, args=(cfg,), nprocs=cfg.world_size, join=True)
    print('\nTraining completed successfully')
if __name__ == '__main__':
launch()
train()

View File

@@ -7,9 +7,8 @@ class Trainer:
self.agent = agent
self.buffer = buffer
self.logger = logger
if cfg.rank == 0:
print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self):
"""Evaluate a TD-MPC2 agent."""

View File

@@ -50,21 +50,12 @@ class OfflineTrainer(Trainer):
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}'
if self.cfg.rank == 0:
print(f'Found {len(fps)} files in {fp}')
# Distribute data across processes
assert len(fps) >= self.cfg.world_size, \
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
fps = fps[self.cfg.rank::self.cfg.world_size]
print(f'Process {self.cfg.rank} has {len(fps)} files')
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
print(f'Found {len(fps)} files in {fp}')
# Create buffer for sampling
_cfg = deepcopy(self.cfg)
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
_cfg.buffer_size //= self.cfg.world_size
_cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'):
@@ -74,12 +65,10 @@ class OfflineTrainer(Trainer):
f'please double-check your config.'
for i in range(len(td)):
self.buffer.add(td[i])
if self.buffer.num_transitions > self.buffer.capacity:
print(f'Buffer has {self.buffer.num_transitions} transitions,' \
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
assert self.buffer.num_eps == self.buffer.capacity, \
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
if self.cfg.rank == 0:
print(f'Training agent for {self.cfg.steps} iterations...')
print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {}
for i in range(self.cfg.steps):
@@ -87,7 +76,7 @@ class OfflineTrainer(Trainer):
train_metrics = self.agent.update(self.buffer)
# Evaluate agent periodically
if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
metrics = {
'iteration': i,
'total_time': time() - self._start_time,
@@ -100,5 +89,4 @@ class OfflineTrainer(Trainer):
self.logger.save_agent(self.agent, identifier=f'{i}')
self.logger.log(metrics, 'pretrain')
if self.cfg.rank == 0:
self.logger.finish(self.agent)
self.logger.finish(self.agent)