Compare commits
12 Commits
main
...
vectorized
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97c1447199 | ||
|
|
a586d8f393 | ||
|
|
6116eb3fa5 | ||
|
|
491d367fc6 | ||
|
|
10f368f20d | ||
|
|
829e329b3b | ||
|
|
10a0be2724 | ||
|
|
ad2342e258 | ||
|
|
fa41a3e450 | ||
|
|
f6d1bfe12d | ||
|
|
9dd3e673c4 | ||
|
|
51d6b8d7a9 |
@@ -83,11 +83,13 @@ class Buffer():
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
|
||||
td = td.permute(1, 0)
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
self._buffer = self._init(td[0])
|
||||
for i in range(self.cfg.num_envs):
|
||||
self._buffer.extend(td[i])
|
||||
self._num_eps += self.cfg.num_envs
|
||||
return self._num_eps
|
||||
|
||||
def _prepare_batch(self, td):
|
||||
|
||||
@@ -84,15 +84,12 @@ def two_hot_inv(x, cfg):
|
||||
return symexp(x)
|
||||
|
||||
|
||||
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
||||
"""Sample from the Gumbel-Softmax distribution."""
|
||||
logits = p.log()
|
||||
gumbels = (
|
||||
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
||||
) # ~Gumbel(0,1)
|
||||
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
|
||||
y_soft = gumbels.softmax(dim)
|
||||
return y_soft.argmax(-1)
|
||||
def gumbel_softmax_sample(p, temperature=1.0, dim=1):
|
||||
"""Sample indices from a Gumbel-Softmax distribution."""
|
||||
logits = torch.log(p + 1e-9)
|
||||
gumbels = -torch.empty_like(logits).exponential_().log()
|
||||
y = (logits + gumbels) / temperature
|
||||
return y.argmax(dim=dim)
|
||||
|
||||
|
||||
def termination_statistics(pred, target, eps=1e-9):
|
||||
|
||||
@@ -77,4 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||
cfg.task_dim = 0
|
||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||
|
||||
# Ensure that eval_episodes is divisible by num_envs and is at least 1*num_envs
|
||||
cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs)
|
||||
cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs)
|
||||
|
||||
return cfg_to_dataclass(cfg)
|
||||
|
||||
@@ -5,6 +5,7 @@ defaults:
|
||||
task: dog-run
|
||||
obs: state
|
||||
episodic: false
|
||||
num_envs: 1
|
||||
|
||||
# evaluation
|
||||
checkpoint: ???
|
||||
@@ -14,6 +15,7 @@ eval_freq: 50000
|
||||
# training
|
||||
steps: 10_000_000
|
||||
batch_size: 256
|
||||
steps_per_update: 1
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
termination_coef: 1
|
||||
@@ -64,8 +66,8 @@ dropout: 0.01
|
||||
simnorm_dim: 8
|
||||
|
||||
# logging
|
||||
wandb_project: ???
|
||||
wandb_entity: ???
|
||||
wandb_project: tdmpc3
|
||||
wandb_entity: nicklashansen
|
||||
wandb_silent: false
|
||||
enable_wandb: true
|
||||
save_csv: true
|
||||
|
||||
@@ -5,6 +5,8 @@ import gymnasium as gym
|
||||
|
||||
from envs.wrappers.multitask import MultitaskWrapper
|
||||
from envs.wrappers.tensor import TensorWrapper
|
||||
from envs.wrappers.vectorized import Vectorized
|
||||
|
||||
|
||||
def missing_dependencies(task):
|
||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||
@@ -62,16 +64,19 @@ def make_env(cfg):
|
||||
gym.logger.set_level(40)
|
||||
if cfg.multitask:
|
||||
env = make_multitask_env(cfg)
|
||||
|
||||
else:
|
||||
env = None
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
|
||||
try:
|
||||
env = fn(cfg)
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
if env is None:
|
||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
|
||||
'Vectorized environments only support state observations.'
|
||||
env = Vectorized(cfg, fn)
|
||||
env = TensorWrapper(env)
|
||||
try: # Dict
|
||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||
@@ -79,5 +84,5 @@ def make_env(cfg):
|
||||
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
||||
cfg.action_dim = env.action_space.shape[0]
|
||||
cfg.episode_length = env.max_episode_steps
|
||||
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
||||
cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
|
||||
return env
|
||||
|
||||
@@ -44,12 +44,16 @@ class DMControlWrapper:
|
||||
def unwrapped(self):
|
||||
return self.env
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return None
|
||||
|
||||
def _obs_to_array(self, obs):
|
||||
return torch.from_numpy(
|
||||
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
|
||||
|
||||
def reset(self):
|
||||
return self._obs_to_array(self.env.reset().observation)
|
||||
return self._obs_to_array(self.env.reset().observation), defaultdict(float)
|
||||
|
||||
def step(self, action):
|
||||
reward = 0
|
||||
@@ -61,6 +65,9 @@ class DMControlWrapper:
|
||||
|
||||
def render(self, width=384, height=384, camera_id=None):
|
||||
return self.env.physics.render(height, width, camera_id or self.camera_id)
|
||||
|
||||
def close(self):
|
||||
self.env.close()
|
||||
|
||||
|
||||
class Pixels(gym.Wrapper):
|
||||
@@ -88,6 +95,9 @@ class Pixels(gym.Wrapper):
|
||||
_, reward, done, info = self.env.step(action)
|
||||
return self._get_obs(), reward, done, info
|
||||
|
||||
def close(self):
|
||||
self.env.close()
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
"""
|
||||
|
||||
@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
|
||||
|
||||
def rand_act(self):
|
||||
if self._wrapped_vectorized:
|
||||
return self.env.rand_act()
|
||||
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
||||
|
||||
def _try_f32_tensor(self, x):
|
||||
@@ -31,12 +34,24 @@ class TensorWrapper(gym.Wrapper):
|
||||
obs = self._try_f32_tensor(obs)
|
||||
return obs
|
||||
|
||||
def reset(self, task_idx=None):
|
||||
return self._obs_to_tensor(self.env.reset())
|
||||
def reset(self, task_idx=None, **kwargs):
|
||||
if self._wrapped_vectorized:
|
||||
obs = self.env.reset(**kwargs)
|
||||
else:
|
||||
obs = self.env.reset()
|
||||
return self._obs_to_tensor(obs)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action.numpy())
|
||||
info = defaultdict(float, info)
|
||||
info['success'] = float(info['success'])
|
||||
info['terminated'] = torch.tensor(float(info['terminated']))
|
||||
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
||||
def step(self, action, **kwargs):
|
||||
if self._wrapped_vectorized:
|
||||
obs, reward, terminated, truncated, info = self.env.step(action.numpy(), **kwargs)
|
||||
else:
|
||||
obs, reward, terminated, truncated, info = self.env.step(action.numpy())
|
||||
reward = torch.tensor(reward, dtype=torch.float32)
|
||||
terminated = torch.tensor(terminated)
|
||||
truncated = torch.tensor(truncated)
|
||||
done = terminated | truncated
|
||||
if 'success' not in info:
|
||||
info['success'] = torch.zeros_like(reward)
|
||||
info['terminated'] = terminated.float()
|
||||
info['truncated'] = truncated.float()
|
||||
return self._obs_to_tensor(obs), reward, done, info
|
||||
|
||||
@@ -19,7 +19,9 @@ class Timeout(gym.Wrapper):
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
obs, reward, terminated, info = self.env.step(action)
|
||||
self._t += 1
|
||||
done = done or self._t >= self.max_episode_steps
|
||||
return obs, reward, done, info
|
||||
truncated = self._t >= self.max_episode_steps
|
||||
info['terminated'] = terminated
|
||||
info['truncated'] = truncated
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
41
tdmpc2/envs/wrappers/vectorized.py
Normal file
41
tdmpc2/envs/wrappers/vectorized.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class Vectorized():
|
||||
"""
|
||||
Vectorized environment for TD-MPC2 online training.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env_fn):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
def make():
|
||||
_cfg = deepcopy(cfg)
|
||||
_cfg.num_envs = 1
|
||||
_cfg.seed = cfg.seed + np.random.randint(1000)
|
||||
return env_fn(_cfg)
|
||||
|
||||
print(f'Creating {cfg.num_envs} environments...')
|
||||
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||
env = make()
|
||||
self.observation_space = env.observation_space
|
||||
self.action_space = env.action_space
|
||||
self.max_episode_steps = env.max_episode_steps
|
||||
|
||||
def rand_act(self):
|
||||
return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
|
||||
|
||||
def reset(self):
|
||||
obs, _ = self.env.reset()
|
||||
return obs
|
||||
|
||||
def step(self, action):
|
||||
return self.env.step(action)
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
return self.env.render(*args, **kwargs)
|
||||
@@ -38,7 +38,7 @@ class TDMPC2(torch.nn.Module):
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
print('Episode length:', cfg.episode_length)
|
||||
print('Discount factor:', self.discount)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
print('Compiling update function with torch.compile...')
|
||||
self._update = torch.compile(self._update, mode="reduce-overhead")
|
||||
@@ -109,7 +109,7 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
||||
obs = obs.to(self.device, non_blocking=True)
|
||||
if task is not None:
|
||||
task = torch.tensor([task], device=self.device)
|
||||
if self.cfg.mpc:
|
||||
@@ -118,7 +118,7 @@ class TDMPC2(torch.nn.Module):
|
||||
action, info = self.model.pi(z, task)
|
||||
if eval_mode:
|
||||
action = info["mean"]
|
||||
return action[0].cpu()
|
||||
return action.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def _estimate_value(self, z, actions, task):
|
||||
@@ -126,8 +126,8 @@ class TDMPC2(torch.nn.Module):
|
||||
G, discount = 0, 1
|
||||
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||
for t in range(self.cfg.horizon):
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||
z = self.model.next(z, actions[t], task)
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
|
||||
z = self.model.next(z, actions[:, t], task)
|
||||
G = G + discount * (1-termination) * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
@@ -142,7 +142,7 @@ class TDMPC2(torch.nn.Module):
|
||||
Plan a sequence of actions using the learned world model.
|
||||
|
||||
Args:
|
||||
z (torch.Tensor): Latent state from which to plan.
|
||||
obs (torch.Tensor): Observation from which to plan.
|
||||
t0 (bool): Whether this is the first observation in the episode.
|
||||
eval_mode (bool): Whether to use the mean of the action distribution.
|
||||
task (Torch.Tensor): Task index (only used for multi-task experiments).
|
||||
@@ -150,62 +150,72 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
# Sample policy trajectories
|
||||
z = self.model.encode(obs, task)
|
||||
|
||||
# Sample policy trajectories
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||
for t in range(self.cfg.horizon-1):
|
||||
pi_actions[t], _ = self.model.pi(_z, task)
|
||||
_z = self.model.next(_z, pi_actions[t], task)
|
||||
pi_actions[-1], _ = self.model.pi(_z, task)
|
||||
pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
_z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1).view(self.cfg.num_envs * self.cfg.num_pi_trajs, -1)
|
||||
for t in range(self.cfg.horizon - 1):
|
||||
a, _ = self.model.pi(_z, task)
|
||||
pi_actions[:, t] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
|
||||
_z = self.model.next(_z, a, task)
|
||||
a, _ = self.model.pi(_z, task)
|
||||
pi_actions[:, -1] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
|
||||
|
||||
# Initialize state and parameters
|
||||
z = z.repeat(self.cfg.num_samples, 1)
|
||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||
z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
|
||||
mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||
std = torch.full((self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, device=self.device)
|
||||
if not t0:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||
mean[:, :-1] = self._prev_mean[:, 1:]
|
||||
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
|
||||
|
||||
# Iterate MPPI
|
||||
for _ in range(self.cfg.iterations):
|
||||
|
||||
# Sample actions
|
||||
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
||||
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
|
||||
actions_sample = actions_sample.clamp(-1, 1)
|
||||
actions[:, self.cfg.num_pi_trajs:] = actions_sample
|
||||
# Sample new actions
|
||||
r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples - self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
||||
actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r
|
||||
actions[:, :, self.cfg.num_pi_trajs:] = actions_sample.clamp(-1, 1)
|
||||
if self.cfg.multitask:
|
||||
actions = actions * self.model._action_masks[task]
|
||||
|
||||
# Compute elite actions
|
||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||
elite_actions = actions.gather(
|
||||
dim=2,
|
||||
index=elite_idxs[:, None, :, None].expand(-1, self.cfg.horizon, self.cfg.num_elites, self.cfg.action_dim)
|
||||
)
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(0).values
|
||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||
score = score / score.sum(0)
|
||||
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
||||
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||
score = torch.exp(self.cfg.temperature * (elite_value - elite_value.max(1, keepdim=True).values))
|
||||
score = score / (score.sum(dim=1, keepdim=True) + 1e-9)
|
||||
score_exp = score.unsqueeze(1)
|
||||
mean = (score_exp * elite_actions).sum(dim=2) / (score_exp.sum(dim=2) + 1e-9)
|
||||
std = ((score_exp * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) /
|
||||
(score_exp.sum(dim=2) + 1e-9)).sqrt().clamp(self.cfg.min_std, self.cfg.max_std)
|
||||
if self.cfg.multitask:
|
||||
mean = mean * self.model._action_masks[task]
|
||||
std = std * self.model._action_masks[task]
|
||||
|
||||
# Select action
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
|
||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||
a, std = actions[0], std[0]
|
||||
logits = torch.log(score.squeeze(2) + 1e-9)
|
||||
rand_idx = math.gumbel_softmax_sample(logits, temperature=self.cfg.temperature, dim=1)
|
||||
selected_actions = elite_actions.gather(
|
||||
dim=2,
|
||||
index=rand_idx[:, None, None, None].expand(-1, self.cfg.horizon, 1, self.cfg.action_dim)
|
||||
).squeeze(2)
|
||||
action, std_out = selected_actions[:, 0], std[:, 0]
|
||||
if not eval_mode:
|
||||
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||
action = action + std_out * torch.randn_like(action)
|
||||
self._prev_mean.copy_(mean)
|
||||
return a.clamp(-1, 1)
|
||||
|
||||
return action.clamp(-1, 1)
|
||||
|
||||
def update_pi(self, zs, task):
|
||||
"""
|
||||
Update policy using a sequence of latent states.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from time import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from trainer.base import Trainer
|
||||
@@ -28,11 +27,11 @@ class OnlineTrainer(Trainer):
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
ep_rewards, ep_successes, ep_lengths = [], [], []
|
||||
for i in range(self.cfg.eval_episodes):
|
||||
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
||||
for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
|
||||
obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
|
||||
if self.cfg.save_video:
|
||||
self.logger.video.init(self.env, enabled=(i==0))
|
||||
while not done:
|
||||
while not done.any():
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
action = self.agent.act(obs, t0=t==0, eval_mode=True)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
@@ -40,15 +39,16 @@ class OnlineTrainer(Trainer):
|
||||
t += 1
|
||||
if self.cfg.save_video:
|
||||
self.logger.video.record(self.env)
|
||||
assert done.all(), 'Vectorized environments must reset all environments at once.'
|
||||
ep_rewards.append(ep_reward)
|
||||
ep_successes.append(info['success'])
|
||||
ep_lengths.append(t)
|
||||
if self.cfg.save_video:
|
||||
self.logger.video.save(self._step)
|
||||
return dict(
|
||||
episode_reward=np.nanmean(ep_rewards),
|
||||
episode_success=np.nanmean(ep_successes),
|
||||
episode_length= np.nanmean(ep_lengths),
|
||||
episode_reward=torch.cat(ep_rewards).mean(),
|
||||
episode_success=info['success'].mean(),
|
||||
episode_length= torch.tensor(ep_lengths, dtype=torch.float32).mean(),
|
||||
)
|
||||
|
||||
def to_td(self, obs, action=None, reward=None, terminated=None):
|
||||
@@ -60,27 +60,28 @@ class OnlineTrainer(Trainer):
|
||||
if action is None:
|
||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||
if reward is None:
|
||||
reward = torch.tensor(float('nan'))
|
||||
reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
|
||||
if terminated is None:
|
||||
terminated = torch.tensor(float('nan'))
|
||||
terminated = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
|
||||
td = TensorDict(
|
||||
obs=obs,
|
||||
action=action.unsqueeze(0),
|
||||
reward=reward.unsqueeze(0),
|
||||
terminated=terminated.unsqueeze(0),
|
||||
batch_size=(1,))
|
||||
batch_size=(1, self.cfg.num_envs,))
|
||||
return td
|
||||
|
||||
def train(self):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
train_metrics, done, eval_next = {}, True, False
|
||||
train_metrics, done, eval_next = {}, torch.tensor(True), True
|
||||
while self._step <= self.cfg.steps:
|
||||
# Evaluate agent periodically
|
||||
if self._step % self.cfg.eval_freq == 0:
|
||||
eval_next = True
|
||||
|
||||
# Reset environment
|
||||
if done:
|
||||
if done.any():
|
||||
assert done.all(), 'Vectorized environments must reset all environments at once.'
|
||||
if eval_next:
|
||||
eval_metrics = self.eval()
|
||||
eval_metrics.update(self.common_metrics())
|
||||
@@ -88,17 +89,19 @@ class OnlineTrainer(Trainer):
|
||||
eval_next = False
|
||||
|
||||
if self._step > 0:
|
||||
if info['terminated'] and not self.cfg.episodic:
|
||||
if info['terminated'].any() and not self.cfg.episodic:
|
||||
raise ValueError('Termination detected but you are not in episodic mode. ' \
|
||||
'Set `episodic=true` to enable support for terminations.')
|
||||
tds = torch.cat(self._tds)
|
||||
train_metrics.update(
|
||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||
episode_success=info['success'],
|
||||
episode_reward=tds['reward'].nansum(0).mean(),
|
||||
episode_success=info['success'].nanmean(),
|
||||
episode_length=len(self._tds),
|
||||
episode_terminated=info['terminated'])
|
||||
episode_terminated=info['terminated'].nanmean(),
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
self._ep_idx = self.buffer.add(tds)
|
||||
|
||||
obs = self.env.reset()
|
||||
self._tds = [self.to_td(obs)]
|
||||
@@ -114,14 +117,16 @@ class OnlineTrainer(Trainer):
|
||||
# Update agent
|
||||
if self._step >= self.cfg.seed_steps:
|
||||
if self._step == self.cfg.seed_steps:
|
||||
num_updates = self.cfg.seed_steps
|
||||
num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
|
||||
print('Pretraining agent on seed data...')
|
||||
else:
|
||||
num_updates = 1
|
||||
num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
|
||||
for _ in range(num_updates):
|
||||
_train_metrics = self.agent.update(self.buffer)
|
||||
train_metrics.update(_train_metrics)
|
||||
if self._step == self.cfg.seed_steps:
|
||||
print('Pretraining complete.')
|
||||
|
||||
self._step += 1
|
||||
|
||||
self._step += self.cfg.num_envs
|
||||
|
||||
self.logger.finish(self.agent)
|
||||
|
||||
Reference in New Issue
Block a user