support distributed training
This commit is contained in:
@@ -13,7 +13,7 @@ class Buffer():
|
|||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._device = torch.device('cuda')
|
self._device = torch.device(self.cfg.rank)
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
self._sampler = SliceSampler(
|
self._sampler = SliceSampler(
|
||||||
num_slices=self.cfg.batch_size,
|
num_slices=self.cfg.batch_size,
|
||||||
@@ -23,6 +23,7 @@ class Buffer():
|
|||||||
)
|
)
|
||||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
|
self._num_transitions = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def capacity(self):
|
def capacity(self):
|
||||||
@@ -34,6 +35,11 @@ class Buffer():
|
|||||||
"""Return the number of episodes in the buffer."""
|
"""Return the number of episodes in the buffer."""
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_transitions(self):
|
||||||
|
"""Return the number of transitions in the buffer."""
|
||||||
|
return self._num_transitions
|
||||||
|
|
||||||
def _reserve_buffer(self, storage):
|
def _reserve_buffer(self, storage):
|
||||||
"""
|
"""
|
||||||
Reserve a buffer with the given storage.
|
Reserve a buffer with the given storage.
|
||||||
@@ -48,7 +54,11 @@ class Buffer():
|
|||||||
|
|
||||||
def _init(self, tds):
|
def _init(self, tds):
|
||||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||||
print(f'Buffer capacity: {self._capacity:,}')
|
if self.cfg.rank == 0:
|
||||||
|
if self.cfg.world_size > 1:
|
||||||
|
print(f'Buffer capacity per process: {self._capacity:,}')
|
||||||
|
else:
|
||||||
|
print(f'Buffer capacity: {self._capacity:,}')
|
||||||
mem_free, _ = torch.cuda.mem_get_info()
|
mem_free, _ = torch.cuda.mem_get_info()
|
||||||
bytes_per_step = sum([
|
bytes_per_step = sum([
|
||||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||||
@@ -56,10 +66,15 @@ class Buffer():
|
|||||||
for v in tds.values()
|
for v in tds.values()
|
||||||
]) / len(tds)
|
]) / len(tds)
|
||||||
total_bytes = bytes_per_step*self._capacity
|
total_bytes = bytes_per_step*self._capacity
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
if self.cfg.rank == 0:
|
||||||
|
if self.cfg.world_size > 1:
|
||||||
|
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
|
||||||
|
else:
|
||||||
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
|
||||||
print(f'Using {storage_device.upper()} memory for storage.')
|
if self.cfg.rank == 0:
|
||||||
|
print(f'Using {storage_device.upper()} memory for storage.')
|
||||||
return self._reserve_buffer(
|
return self._reserve_buffer(
|
||||||
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
||||||
)
|
)
|
||||||
@@ -88,6 +103,7 @@ class Buffer():
|
|||||||
self._buffer = self._init(td)
|
self._buffer = self._init(td)
|
||||||
self._buffer.extend(td)
|
self._buffer.extend(td)
|
||||||
self._num_eps += 1
|
self._num_eps += 1
|
||||||
|
self._num_transitions += len(td)
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
|
|||||||
@@ -113,11 +113,13 @@ class Logger:
|
|||||||
self._group = cfg_to_group(cfg)
|
self._group = cfg_to_group(cfg)
|
||||||
self._seed = cfg.seed
|
self._seed = cfg.seed
|
||||||
self._eval = []
|
self._eval = []
|
||||||
print_run(cfg)
|
if cfg.rank == 0:
|
||||||
|
print_run(cfg)
|
||||||
self.project = cfg.get("wandb_project", "none")
|
self.project = cfg.get("wandb_project", "none")
|
||||||
self.entity = cfg.get("wandb_entity", "none")
|
self.entity = cfg.get("wandb_entity", "none")
|
||||||
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
if cfg.rank == 0:
|
||||||
|
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||||
cfg.save_agent = False
|
cfg.save_agent = False
|
||||||
cfg.save_video = False
|
cfg.save_video = False
|
||||||
self._wandb = None
|
self._wandb = None
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ class RunningScale:
|
|||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
|
self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
|
||||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
|
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return dict(value=self._value, percentiles=self._percentiles)
|
return dict(value=self._value, percentiles=self._percentiles)
|
||||||
|
|||||||
@@ -3,13 +3,15 @@ from copy import deepcopy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from tensordict.tensordict import TensorDict
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
|
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
TD-MPC2 implicit world model architecture.
|
Distributed version of the TD-MPC2 world model architecture.
|
||||||
Can be used for both single-task and multi-task experiments.
|
Can be used for both single-task and multi-task experiments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -17,25 +19,37 @@ class WorldModel(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
if cfg.multitask:
|
if cfg.multitask:
|
||||||
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||||
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
||||||
for i in range(len(cfg.tasks)):
|
for i in range(len(cfg.tasks)):
|
||||||
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
||||||
self._encoder = layers.enc(cfg)
|
self.__encoder = layers.enc(cfg)
|
||||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
|
init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
|
||||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
|
||||||
self.log_std_min = torch.tensor(cfg.log_std_min)
|
self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
|
||||||
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
|
self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
|
||||||
|
self.to(cfg.rank)
|
||||||
|
if cfg.multitask:
|
||||||
|
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
|
||||||
|
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
|
||||||
|
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
|
||||||
|
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
|
||||||
|
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
|
||||||
|
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_params(self):
|
def total_params(self):
|
||||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]])
|
||||||
|
return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}"
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Overriding `to` method to also move additional tensors to device.
|
Overriding `to` method to also move additional tensors to device.
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ eval_episodes: 10
|
|||||||
eval_freq: 50000
|
eval_freq: 50000
|
||||||
|
|
||||||
# training
|
# training
|
||||||
|
world_size: 1
|
||||||
steps: 10_000_000
|
steps: 10_000_000
|
||||||
batch_size: 256
|
batch_size: 256
|
||||||
reward_coef: 0.1
|
reward_coef: 0.1
|
||||||
@@ -74,6 +75,7 @@ save_agent: true
|
|||||||
seed: 1
|
seed: 1
|
||||||
|
|
||||||
# convenience
|
# convenience
|
||||||
|
rank: ???
|
||||||
work_dir: ???
|
work_dir: ???
|
||||||
task_title: ???
|
task_title: ???
|
||||||
multitask: ???
|
multitask: ???
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ def make_multitask_env(cfg):
|
|||||||
"""
|
"""
|
||||||
Make a multi-task environment for TD-MPC2 experiments.
|
Make a multi-task environment for TD-MPC2 experiments.
|
||||||
"""
|
"""
|
||||||
print('Creating multi-task environment with tasks:', cfg.tasks)
|
if cfg.rank == 0:
|
||||||
|
print('Creating multi-task environment with tasks:', cfg.tasks)
|
||||||
envs = []
|
envs = []
|
||||||
for task in cfg.tasks:
|
for task in cfg.tasks:
|
||||||
_cfg = deepcopy(cfg)
|
_cfg = deepcopy(cfg)
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ class TDMPC2:
|
|||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.device = torch.device('cuda')
|
self.device = torch.device(cfg.rank)
|
||||||
self.model = WorldModel(cfg).to(self.device)
|
self.model = WorldModel(cfg)
|
||||||
self.optim = torch.optim.Adam([
|
self.optim = torch.optim.Adam([
|
||||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||||
{'params': self.model._dynamics.parameters()},
|
{'params': self.model._dynamics.parameters()},
|
||||||
@@ -30,7 +30,7 @@ class TDMPC2:
|
|||||||
self.scale = RunningScale(cfg)
|
self.scale = RunningScale(cfg)
|
||||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
|
|
||||||
def _get_discount(self, episode_length):
|
def _get_discount(self, episode_length):
|
||||||
|
|||||||
@@ -14,14 +14,28 @@ from common.buffer import Buffer
|
|||||||
from envs import make_env
|
from envs import make_env
|
||||||
from tdmpc2 import TDMPC2
|
from tdmpc2 import TDMPC2
|
||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
from trainer.online_trainer import OnlineTrainer
|
|
||||||
from common.logger import Logger
|
from common.logger import Logger
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_name='config', config_path='.')
|
def setup(rank, world_size):
|
||||||
def train(cfg: dict):
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12355"
|
||||||
|
|
||||||
|
# initialize the process group
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup():
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
def train(rank: int, cfg: dict):
|
||||||
"""
|
"""
|
||||||
Script for training single-task / multi-task TD-MPC2 agents.
|
Script for training single-task / multi-task TD-MPC2 agents.
|
||||||
|
|
||||||
@@ -40,14 +54,11 @@ def train(cfg: dict):
|
|||||||
$ python train.py task=dog-run steps=7000000
|
$ python train.py task=dog-run steps=7000000
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
assert torch.cuda.is_available()
|
setup(rank, cfg.world_size)
|
||||||
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
set_seed(cfg.seed + rank)
|
||||||
cfg = parse_cfg(cfg)
|
cfg.rank = rank
|
||||||
set_seed(cfg.seed)
|
|
||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer = OfflineTrainer(
|
||||||
trainer = trainer_cls(
|
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
agent=TDMPC2(cfg),
|
agent=TDMPC2(cfg),
|
||||||
@@ -55,8 +66,26 @@ def train(cfg: dict):
|
|||||||
logger=Logger(cfg),
|
logger=Logger(cfg),
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
print('\nTraining completed successfully')
|
if cfg.rank == 0:
|
||||||
|
print('\nTraining completed successfully')
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_name='config', config_path='.')
|
||||||
|
def launch(cfg: dict):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
|
||||||
|
assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
|
||||||
|
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
||||||
|
cfg = parse_cfg(cfg)
|
||||||
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
|
torch.multiprocessing.spawn(
|
||||||
|
train,
|
||||||
|
args=(cfg,),
|
||||||
|
nprocs=cfg.world_size,
|
||||||
|
join=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train()
|
launch()
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ class Trainer:
|
|||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
if cfg.rank == 0:
|
||||||
print('Architecture:', self.agent.model)
|
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||||
|
print('Architecture:', self.agent.model)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Evaluate a TD-MPC2 agent."""
|
"""Evaluate a TD-MPC2 agent."""
|
||||||
|
|||||||
@@ -50,12 +50,21 @@ class OfflineTrainer(Trainer):
|
|||||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||||
fps = sorted(glob(str(fp)))
|
fps = sorted(glob(str(fp)))
|
||||||
assert len(fps) > 0, f'No data found at {fp}'
|
assert len(fps) > 0, f'No data found at {fp}'
|
||||||
print(f'Found {len(fps)} files in {fp}')
|
if self.cfg.rank == 0:
|
||||||
|
print(f'Found {len(fps)} files in {fp}')
|
||||||
|
|
||||||
|
# Distribute data across processes
|
||||||
|
assert len(fps) >= self.cfg.world_size, \
|
||||||
|
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
|
||||||
|
fps = fps[self.cfg.rank::self.cfg.world_size]
|
||||||
|
print(f'Process {self.cfg.rank} has {len(fps)} files')
|
||||||
|
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
|
||||||
|
|
||||||
# Create buffer for sampling
|
# Create buffer for sampling
|
||||||
_cfg = deepcopy(self.cfg)
|
_cfg = deepcopy(self.cfg)
|
||||||
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
|
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
|
||||||
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
|
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
|
||||||
|
_cfg.buffer_size //= self.cfg.world_size
|
||||||
_cfg.steps = _cfg.buffer_size
|
_cfg.steps = _cfg.buffer_size
|
||||||
self.buffer = Buffer(_cfg)
|
self.buffer = Buffer(_cfg)
|
||||||
for fp in tqdm(fps, desc='Loading data'):
|
for fp in tqdm(fps, desc='Loading data'):
|
||||||
@@ -65,10 +74,12 @@ class OfflineTrainer(Trainer):
|
|||||||
f'please double-check your config.'
|
f'please double-check your config.'
|
||||||
for i in range(len(td)):
|
for i in range(len(td)):
|
||||||
self.buffer.add(td[i])
|
self.buffer.add(td[i])
|
||||||
assert self.buffer.num_eps == self.buffer.capacity, \
|
if self.buffer.num_transitions > self.buffer.capacity:
|
||||||
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
|
print(f'Buffer has {self.buffer.num_transitions} transitions,' \
|
||||||
|
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
|
||||||
|
|
||||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
if self.cfg.rank == 0:
|
||||||
|
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for i in range(self.cfg.steps):
|
for i in range(self.cfg.steps):
|
||||||
|
|
||||||
@@ -76,7 +87,7 @@ class OfflineTrainer(Trainer):
|
|||||||
train_metrics = self.agent.update(self.buffer)
|
train_metrics = self.agent.update(self.buffer)
|
||||||
|
|
||||||
# Evaluate agent periodically
|
# Evaluate agent periodically
|
||||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
|
||||||
metrics = {
|
metrics = {
|
||||||
'iteration': i,
|
'iteration': i,
|
||||||
'total_time': time() - self._start_time,
|
'total_time': time() - self._start_time,
|
||||||
@@ -89,4 +100,5 @@ class OfflineTrainer(Trainer):
|
|||||||
self.logger.save_agent(self.agent, identifier=f'{i}')
|
self.logger.save_agent(self.agent, identifier=f'{i}')
|
||||||
self.logger.log(metrics, 'pretrain')
|
self.logger.log(metrics, 'pretrain')
|
||||||
|
|
||||||
self.logger.finish(self.agent)
|
if self.cfg.rank == 0:
|
||||||
|
self.logger.finish(self.agent)
|
||||||
|
|||||||
Reference in New Issue
Block a user