This commit is contained in:
Nicklas Hansen
2024-02-11 14:41:20 -08:00
parent 8bbc14ebab
commit 829e329b3b
10 changed files with 174 additions and 88 deletions

View File

@@ -83,11 +83,13 @@ class Buffer():
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
td = td.permute(1, 0)
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td[0])
self._buffer.extend(td) for i in range(self.cfg.num_envs):
self._num_eps += 1 self._buffer.extend(td[i])
self._num_eps += self.cfg.num_envs
return self._num_eps return self._num_eps
def _prepare_batch(self, td): def _prepare_batch(self, td):

View File

@@ -84,15 +84,12 @@ def two_hot_inv(x, cfg):
return symexp(x) return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0): def gumbel_softmax_sample(p, temperature=1.0, dim=1):
"""Sample from the Gumbel-Softmax distribution.""" """Sample indices from a Gumbel-Softmax distribution."""
logits = p.log() logits = torch.log(p + 1e-9)
gumbels = ( gumbels = -torch.empty_like(logits).exponential_().log()
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() y = (logits + gumbels) / temperature
) # ~Gumbel(0,1) return y.argmax(dim=dim)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1)
def termination_statistics(pred, target, eps=1e-9): def termination_statistics(pred, target, eps=1e-9):

View File

@@ -5,6 +5,7 @@ defaults:
task: dog-run task: dog-run
obs: state obs: state
episodic: false episodic: false
num_envs: 1
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -14,6 +15,7 @@ eval_freq: 50000
# training # training
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
steps_per_update: 1
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
termination_coef: 1 termination_coef: 1
@@ -64,8 +66,8 @@ dropout: 0.01
simnorm_dim: 8 simnorm_dim: 8
# logging # logging
wandb_project: ??? wandb_project: tdmpc3
wandb_entity: ??? wandb_entity: nicklashansen
wandb_silent: false wandb_silent: false
enable_wandb: true enable_wandb: true
save_csv: true save_csv: true

View File

@@ -5,6 +5,8 @@ import gymnasium as gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
from envs.wrappers.vectorized import Vectorized
def missing_dependencies(task): def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
@@ -62,16 +64,19 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break
except ValueError: except ValueError:
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
'Vectorized environments only support state observations.'
env = Vectorized(cfg, fn)
env = TensorWrapper(env) env = TensorWrapper(env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
@@ -79,5 +84,5 @@ def make_env(cfg):
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0] cfg.action_dim = env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
return env return env

View File

@@ -44,12 +44,16 @@ class DMControlWrapper:
def unwrapped(self): def unwrapped(self):
return self.env return self.env
@property
def metadata(self):
return None
def _obs_to_array(self, obs): def _obs_to_array(self, obs):
return torch.from_numpy( return torch.from_numpy(
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32)) np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
def reset(self): def reset(self):
return self._obs_to_array(self.env.reset().observation) return self._obs_to_array(self.env.reset().observation), defaultdict(float)
def step(self, action): def step(self, action):
reward = 0 reward = 0
@@ -62,6 +66,9 @@ class DMControlWrapper:
def render(self, width=384, height=384, camera_id=None): def render(self, width=384, height=384, camera_id=None):
return self.env.physics.render(height, width, camera_id or self.camera_id) return self.env.physics.render(height, width, camera_id or self.camera_id)
def close(self):
self.env.close()
class Pixels(gym.Wrapper): class Pixels(gym.Wrapper):
def __init__(self, env, cfg, num_frames=3, size=64): def __init__(self, env, cfg, num_frames=3, size=64):
@@ -88,6 +95,9 @@ class Pixels(gym.Wrapper):
_, reward, done, info = self.env.step(action) _, reward, done, info = self.env.step(action)
return self._get_obs(), reward, done, info return self._get_obs(), reward, done, info
def close(self):
self.env.close()
def make_env(cfg): def make_env(cfg):
""" """

View File

@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
def rand_act(self): def rand_act(self):
if self._wrapped_vectorized:
return self.env.rand_act()
return torch.from_numpy(self.action_space.sample().astype(np.float32)) return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x): def _try_f32_tensor(self, x):
@@ -31,12 +34,24 @@ class TensorWrapper(gym.Wrapper):
obs = self._try_f32_tensor(obs) obs = self._try_f32_tensor(obs)
return obs return obs
def reset(self, task_idx=None): def reset(self, task_idx=None, **kwargs):
return self._obs_to_tensor(self.env.reset()) if self._wrapped_vectorized:
obs = self.env.reset(**kwargs)
else:
obs = self.env.reset()
return self._obs_to_tensor(obs)
def step(self, action): def step(self, action, **kwargs):
obs, reward, done, info = self.env.step(action.numpy()) if self._wrapped_vectorized:
info = defaultdict(float, info) obs, reward, terminated, truncated, info = self.env.step(action.numpy(), **kwargs)
info['success'] = float(info['success']) else:
info['terminated'] = torch.tensor(float(info['terminated'])) obs, reward, terminated, truncated, info = self.env.step(action.numpy())
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info reward = torch.tensor(reward, dtype=torch.float32)
terminated = torch.tensor(terminated)
truncated = torch.tensor(truncated)
done = terminated | truncated
if 'success' not in info:
info['success'] = torch.zeros_like(reward)
info['terminated'] = terminated.float()
info['truncated'] = truncated.float()
return self._obs_to_tensor(obs), reward, done, info

View File

@@ -19,7 +19,9 @@ class Timeout(gym.Wrapper):
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
def step(self, action): def step(self, action):
obs, reward, done, info = self.env.step(action) obs, reward, terminated, info = self.env.step(action)
self._t += 1 self._t += 1
done = done or self._t >= self.max_episode_steps truncated = self._t >= self.max_episode_steps
return obs, reward, done, info info['terminated'] = terminated
info['truncated'] = truncated
return obs, reward, terminated, truncated, info

View File

@@ -0,0 +1,42 @@
from copy import deepcopy
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
import numpy as np
import torch
class Vectorized():
    """
    Vectorized environment for TD-MPC2 online training.

    Builds ``cfg.num_envs`` sub-environments from a single-environment
    factory, runs them inside a ``SyncVectorEnv``, and exposes the
    per-environment observation space, action space and episode length.

    Args:
        cfg: Experiment config. Reads ``cfg.num_envs`` and ``cfg.seed``; a deep
            copy with ``num_envs = 1`` is handed to every sub-environment.
        env_fn: Callable taking a config and returning one environment.
    """

    def __init__(self, cfg, env_fn):
        super().__init__()
        self.cfg = cfg

        def make():
            # Each sub-environment gets a private config so the shared one
            # is never mutated.
            _cfg = deepcopy(cfg)
            _cfg.num_envs = 1
            # NOTE(review): the offset comes from NumPy's global RNG, so two
            # sub-environments can draw the same seed -- confirm this is intended.
            _cfg.seed = cfg.seed + np.random.randint(1000)
            return env_fn(_cfg)

        print(f'Creating {cfg.num_envs} environments...')
        # self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
        self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)])

        # A throwaway instance is used only to read the single-environment
        # spaces and episode length; release it afterwards so it does not leak.
        probe = make()
        try:
            self.observation_space = probe.observation_space
            self.action_space = probe.action_space
            self.max_episode_steps = probe.max_episode_steps
        finally:
            close = getattr(probe, 'close', None)
            if callable(close):
                close()

    def rand_act(self):
        """Return random actions in [-1, 1) of shape (num_envs, *action_shape)."""
        return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1

    def reset(self, **kwargs):
        """
        Reset every sub-environment and return only the observations.

        Keyword arguments are forwarded to the wrapped vector environment's
        ``reset``; the info dict it returns is discarded.
        """
        obs, _ = self.env.reset(**kwargs)
        return obs

    def step(self, action):
        """Step all sub-environments; returns the wrapped env's result unchanged."""
        return self.env.step(action)

    def render(self, *args, **kwargs):
        """Forward rendering to the wrapped vector environment."""
        return self.env.render(*args, **kwargs)

    def close(self):
        """Close the wrapped vector environment and its sub-environments."""
        self.env.close()

View File

@@ -38,7 +38,7 @@ class TDMPC2(torch.nn.Module):
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
print('Episode length:', cfg.episode_length) print('Episode length:', cfg.episode_length)
print('Discount factor:', self.discount) print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
print('Compiling update function with torch.compile...') print('Compiling update function with torch.compile...')
self._update = torch.compile(self._update, mode="reduce-overhead") self._update = torch.compile(self._update, mode="reduce-overhead")
@@ -109,7 +109,7 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
@@ -118,7 +118,7 @@ class TDMPC2(torch.nn.Module):
action, info = self.model.pi(z, task) action, info = self.model.pi(z, task)
if eval_mode: if eval_mode:
action = info["mean"] action = info["mean"]
return action[0].cpu() return action.cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
@@ -126,8 +126,8 @@ class TDMPC2(torch.nn.Module):
G, discount = 0, 1 G, discount = 0, 1
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device) termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[:, t], task)
G = G + discount * (1-termination) * reward G = G + discount * (1-termination) * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
@@ -137,7 +137,7 @@ class TDMPC2(torch.nn.Module):
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg') return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None): def _plan(self, z, t0=False, eval_mode=False, task=None):
""" """
Plan a sequence of actions using the learned world model. Plan a sequence of actions using the learned world model.
@@ -151,60 +151,68 @@ class TDMPC2(torch.nn.Module):
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
# Sample policy trajectories # Sample policy trajectories
z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1).view(self.cfg.num_envs * self.cfg.num_pi_trajs, -1)
for t in range(self.cfg.horizon - 1): for t in range(self.cfg.horizon - 1):
pi_actions[t], _ = self.model.pi(_z, task) a, _ = self.model.pi(_z, task)
_z = self.model.next(_z, pi_actions[t], task) pi_actions[:, t] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
pi_actions[-1], _ = self.model.pi(_z, task) _z = self.model.next(_z, a, task)
a, _ = self.model.pi(_z, task)
pi_actions[:, -1] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) std = torch.full((self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, device=self.device)
if not t0: if not t0:
mean[:-1] = self._prev_mean[1:] mean[:, :-1] = self._prev_mean[:, 1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample new actions
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples - self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r
actions_sample = actions_sample.clamp(-1, 1) actions[:, :, self.cfg.num_pi_trajs:] = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
elite_actions = actions.gather(
dim=2,
index=elite_idxs[:, None, :, None].expand(-1, self.cfg.horizon, self.cfg.num_elites, self.cfg.action_dim)
)
# Update parameters # Update parameters
max_value = elite_value.max(0).values score = torch.exp(self.cfg.temperature * (elite_value - elite_value.max(1, keepdim=True).values))
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = score / (score.sum(dim=1, keepdim=True) + 1e-9)
score = score / score.sum(0) score_exp = score.unsqueeze(1)
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) mean = (score_exp * elite_actions).sum(dim=2) / (score_exp.sum(dim=2) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() std = ((score_exp * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) /
std = std.clamp(self.cfg.min_std, self.cfg.max_std) (score_exp.sum(dim=2) + 1e-9)).sqrt().clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) logits = torch.log(score.squeeze(2) + 1e-9)
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) rand_idx = math.gumbel_softmax_sample(logits, temperature=self.cfg.temperature, dim=1)
a, std = actions[0], std[0] selected_actions = elite_actions.gather(
dim=2,
index=rand_idx[:, None, None, None].expand(-1, self.cfg.horizon, 1, self.cfg.action_dim)
).squeeze(2)
action, std_out = selected_actions[:, 0], std[:, 0]
if not eval_mode: if not eval_mode:
a = a + std * torch.randn(self.cfg.action_dim, device=std.device) action = action + std_out * torch.randn_like(action)
self._prev_mean.copy_(mean) self._prev_mean.copy_(mean)
return a.clamp(-1, 1) return action.clamp(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """

View File

@@ -1,6 +1,5 @@
from time import time from time import time
import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from trainer.base import Trainer from trainer.base import Trainer
@@ -28,11 +27,11 @@ class OnlineTrainer(Trainer):
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes, ep_lengths = [], [], [] ep_rewards, ep_successes, ep_lengths = [], [], []
for i in range(self.cfg.eval_episodes): for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0 obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0)) self.logger.video.init(self.env, enabled=(i==0))
while not done: while not done.any():
torch.compiler.cudagraph_mark_step_begin() torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True) action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
@@ -40,15 +39,16 @@ class OnlineTrainer(Trainer):
t += 1 t += 1
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.record(self.env) self.logger.video.record(self.env)
assert done.all(), 'Vectorized environments must reset all environments at once.'
ep_rewards.append(ep_reward) ep_rewards.append(ep_reward)
ep_successes.append(info['success']) ep_successes.append(info['success'])
ep_lengths.append(t) ep_lengths.append(t)
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.save(self._step) self.logger.video.save(self._step)
return dict( return dict(
episode_reward=np.nanmean(ep_rewards), episode_reward=torch.cat(ep_rewards).mean(),
episode_success=np.nanmean(ep_successes), episode_success=info['success'].mean(),
episode_length= np.nanmean(ep_lengths), episode_length= torch.tensor(ep_lengths, dtype=torch.float32).mean(),
) )
def to_td(self, obs, action=None, reward=None, terminated=None): def to_td(self, obs, action=None, reward=None, terminated=None):
@@ -60,27 +60,28 @@ class OnlineTrainer(Trainer):
if action is None: if action is None:
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
if terminated is None: if terminated is None:
terminated = torch.tensor(float('nan')) terminated = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
td = TensorDict( td = TensorDict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
terminated=terminated.unsqueeze(0), terminated=terminated.unsqueeze(0),
batch_size=(1,)) batch_size=(1, self.cfg.num_envs,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False train_metrics, done, eval_next = {}, torch.tensor(True), True
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0: if self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True
# Reset environment # Reset environment
if done: if done.any():
assert done.all(), 'Vectorized environments must reset all environments at once.'
if eval_next: if eval_next:
eval_metrics = self.eval() eval_metrics = self.eval()
eval_metrics.update(self.common_metrics()) eval_metrics.update(self.common_metrics())
@@ -88,17 +89,19 @@ class OnlineTrainer(Trainer):
eval_next = False eval_next = False
if self._step > 0: if self._step > 0:
if info['terminated'] and not self.cfg.episodic: if info['terminated'].any() and not self.cfg.episodic:
raise ValueError('Termination detected but you are not in episodic mode. ' \ raise ValueError('Termination detected but you are not in episodic mode. ' \
'Set `episodic=true` to enable support for terminations.') 'Set `episodic=true` to enable support for terminations.')
tds = torch.cat(self._tds)
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_reward=tds['reward'].nansum(0).mean(),
episode_success=info['success'], episode_success=info['success'].nanmean(),
episode_length=len(self._tds), episode_length=len(self._tds),
episode_terminated=info['terminated']) episode_terminated=info['terminated'].nanmean(),
)
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(tds)
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
@@ -114,14 +117,14 @@ class OnlineTrainer(Trainer):
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps: if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
print('Pretraining agent on seed data...') print('Pretraining agent on seed data...')
else: else:
num_updates = 1 num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
for _ in range(num_updates): for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer) _train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics) train_metrics.update(_train_metrics)
self._step += 1 self._step += self.cfg.num_envs
self.logger.finish(self.agent) self.logger.finish(self.agent)