init
This commit is contained in:
@@ -41,7 +41,7 @@ class Buffer():
|
|||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=self._sampler,
|
sampler=self._sampler,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
prefetch=1,
|
prefetch=int(self.cfg.num_envs / self.cfg.steps_per_update),
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,11 +82,13 @@ class Buffer():
|
|||||||
|
|
||||||
def add(self, td):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer."""
|
"""Add an episode to the buffer."""
|
||||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
|
||||||
|
td = td.permute(1, 0)
|
||||||
if self._num_eps == 0:
|
if self._num_eps == 0:
|
||||||
self._buffer = self._init(td)
|
self._buffer = self._init(td[0])
|
||||||
self._buffer.extend(td)
|
for i in range(self.cfg.num_envs):
|
||||||
self._num_eps += 1
|
self._buffer.extend(td[i])
|
||||||
|
self._num_eps += self.cfg.num_envs
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ defaults:
|
|||||||
# environment
|
# environment
|
||||||
task: dog-run
|
task: dog-run
|
||||||
obs: state
|
obs: state
|
||||||
|
num_envs: 1
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
@@ -13,6 +14,7 @@ eval_freq: 50000
|
|||||||
# training
|
# training
|
||||||
steps: 10_000_000
|
steps: 10_000_000
|
||||||
batch_size: 256
|
batch_size: 256
|
||||||
|
steps_per_update: 1
|
||||||
reward_coef: 0.1
|
reward_coef: 0.1
|
||||||
value_coef: 0.1
|
value_coef: 0.1
|
||||||
consistency_coef: 20
|
consistency_coef: 20
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import gym
|
|||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
from envs.wrappers.pixels import PixelWrapper
|
from envs.wrappers.pixels import PixelWrapper
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
|
from envs.wrappers.vectorized import Vectorized
|
||||||
|
|
||||||
|
|
||||||
def missing_dependencies(task):
|
def missing_dependencies(task):
|
||||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||||
@@ -59,16 +61,19 @@ def make_env(cfg):
|
|||||||
gym.logger.set_level(40)
|
gym.logger.set_level(40)
|
||||||
if cfg.multitask:
|
if cfg.multitask:
|
||||||
env = make_multitask_env(cfg)
|
env = make_multitask_env(cfg)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
env = None
|
env = None
|
||||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||||
try:
|
try:
|
||||||
env = fn(cfg)
|
env = fn(cfg)
|
||||||
|
break
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
if env is None:
|
if env is None:
|
||||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||||
|
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
|
||||||
|
'Vectorized environments only support state observations.'
|
||||||
|
env = Vectorized(cfg, fn)
|
||||||
env = TensorWrapper(env)
|
env = TensorWrapper(env)
|
||||||
if cfg.get('obs', 'state') == 'rgb':
|
if cfg.get('obs', 'state') == 'rgb':
|
||||||
env = PixelWrapper(cfg, env)
|
env = PixelWrapper(cfg, env)
|
||||||
@@ -78,5 +83,5 @@ def make_env(cfg):
|
|||||||
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
||||||
cfg.action_dim = env.action_space.shape[0]
|
cfg.action_dim = env.action_space.shape[0]
|
||||||
cfg.episode_length = env.max_episode_steps
|
cfg.episode_length = env.max_episode_steps
|
||||||
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -177,6 +177,9 @@ class TimeStepToGymWrapper:
|
|||||||
camera_id = dict(quadruped=2).get(self.domain, camera_id)
|
camera_id = dict(quadruped=2).get(self.domain, camera_id)
|
||||||
return self.env.physics.render(height, width, camera_id)
|
return self.env.physics.render(height, width, camera_id)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.env.close()
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg):
|
def make_env(cfg):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -12,8 +12,11 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
|
||||||
|
|
||||||
def rand_act(self):
|
def rand_act(self):
|
||||||
|
if self._wrapped_vectorized:
|
||||||
|
return self.env.rand_act()
|
||||||
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
||||||
|
|
||||||
def _try_f32_tensor(self, x):
|
def _try_f32_tensor(self, x):
|
||||||
@@ -30,11 +33,23 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
obs = self._try_f32_tensor(obs)
|
obs = self._try_f32_tensor(obs)
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def reset(self, task_idx=None):
|
def reset(self, task_idx=None, **kwargs):
|
||||||
return self._obs_to_tensor(self.env.reset())
|
if self._wrapped_vectorized:
|
||||||
|
obs = self.env.reset(**kwargs)
|
||||||
|
else:
|
||||||
|
obs = self.env.reset()
|
||||||
|
return self._obs_to_tensor(obs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action, **kwargs):
|
||||||
|
if self._wrapped_vectorized:
|
||||||
|
obs, reward, done, info = self.env.step(action.numpy(), **kwargs)
|
||||||
|
else:
|
||||||
obs, reward, done, info = self.env.step(action.numpy())
|
obs, reward, done, info = self.env.step(action.numpy())
|
||||||
|
if isinstance(info, tuple):
|
||||||
|
info = {key: torch.stack([torch.tensor(d[key]) for d in info]) for key in info[0].keys()}
|
||||||
|
if 'success' not in info.keys():
|
||||||
|
info['success'] = torch.zeros(len(done))
|
||||||
|
else:
|
||||||
info = defaultdict(float, info)
|
info = defaultdict(float, info)
|
||||||
info['success'] = float(info['success'])
|
info['success'] = float(info['success'])
|
||||||
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
||||||
|
|||||||
40
tdmpc2/envs/wrappers/vectorized.py
Normal file
40
tdmpc2/envs/wrappers/vectorized.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from gym.vector import AsyncVectorEnv
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Vectorized():
|
||||||
|
"""
|
||||||
|
Vectorized environment for TD-MPC2 online training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, env_fn):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def make():
|
||||||
|
_cfg = deepcopy(cfg)
|
||||||
|
_cfg.num_envs = 1
|
||||||
|
_cfg.seed = cfg.seed + np.random.randint(1000)
|
||||||
|
return env_fn(_cfg)
|
||||||
|
|
||||||
|
print(f'Creating {cfg.num_envs} environments...')
|
||||||
|
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||||
|
env = make()
|
||||||
|
self.observation_space = env.observation_space
|
||||||
|
self.action_space = env.action_space
|
||||||
|
self.max_episode_steps = env.max_episode_steps
|
||||||
|
|
||||||
|
def rand_act(self):
|
||||||
|
return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.env.reset()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
return self.env.step(action)
|
||||||
|
|
||||||
|
def render(self, *args, **kwargs):
|
||||||
|
return self.env.render(*args, **kwargs)
|
||||||
@@ -81,23 +81,23 @@ class TDMPC2:
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action to take in the environment.
|
torch.Tensor: Action to take in the environment.
|
||||||
"""
|
"""
|
||||||
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
obs = obs.to(self.device, non_blocking=True)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
task = torch.tensor([task], device=self.device)
|
task = torch.tensor([task], device=self.device)
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
|
action = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
|
||||||
else:
|
else:
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
action = self.model.pi(z, task)[int(not eval_mode)]
|
||||||
return a.cpu()
|
return action.cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||||
G, discount = 0, 1
|
G, discount = 0, 1
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[:, t], task)
|
||||||
G += discount * reward
|
G += discount * reward
|
||||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||||
@@ -118,57 +118,69 @@ class TDMPC2:
|
|||||||
"""
|
"""
|
||||||
# Sample policy trajectories
|
# Sample policy trajectories
|
||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
_z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1)
|
||||||
for t in range(self.cfg.horizon-1):
|
for t in range(self.cfg.horizon-1):
|
||||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
pi_actions[:,t] = self.model.pi(_z, task)[1]
|
||||||
_z = self.model.next(_z, pi_actions[t], task)
|
_z = self.model.next(_z, pi_actions[:,t], task)
|
||||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
pi_actions[:,-1] = self.model.pi(_z, task)[1]
|
||||||
|
|
||||||
# Initialize state and parameters
|
# Initialize state and parameters
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
|
||||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||||
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
std = self.cfg.max_std*torch.ones(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||||
if not t0:
|
if not t0:
|
||||||
mean[:-1] = self._prev_mean[1:]
|
mean[:, :-1] = self._prev_mean[:, 1:]
|
||||||
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
|
||||||
|
|
||||||
# Iterate MPPI
|
# Iterate MPPI
|
||||||
for _ in range(self.cfg.iterations):
|
for _ in range(self.cfg.iterations):
|
||||||
|
|
||||||
# Sample actions
|
# Sample actions
|
||||||
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
|
actions[:, :, self.cfg.num_pi_trajs:] = (mean.unsqueeze(2) + std.unsqueeze(2) * \
|
||||||
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
|
torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
|
||||||
.clamp(-1, 1)
|
.clamp(-1, 1)
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
actions = actions * self.model._action_masks[task]
|
actions = actions * self.model._action_masks[task]
|
||||||
|
|
||||||
# Compute elite actions
|
# Compute elite actions
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
||||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
|
||||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||||
|
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||||
|
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
||||||
|
|
||||||
|
# vectorized version
|
||||||
|
# elite_value, elite_actions = [], []
|
||||||
|
# for i in range(self.cfg.num_envs):
|
||||||
|
# elite_value.append(value[i, elite_idxs[i]])
|
||||||
|
# elite_actions.append(actions[i, elite_idxs[i]])
|
||||||
|
# elite_value = torch.stack(elite_value, dim=0)
|
||||||
|
|
||||||
# Update parameters
|
# Update parameters
|
||||||
max_value = elite_value.max(0)[0]
|
max_value = elite_value.max(1)[0]
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
||||||
score /= score.sum(0)
|
score /= score.sum(1, keepdim=True)
|
||||||
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
|
mean = torch.sum(score.unsqueeze(1) * elite_actions, dim=2) / (score.sum(1, keepdim=True) + 1e-9)
|
||||||
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \
|
std = torch.sqrt(torch.sum(score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2, dim=2) / (score.sum(1, keepdim=True) + 1e-9)) \
|
||||||
.clamp_(self.cfg.min_std, self.cfg.max_std)
|
.clamp_(self.cfg.min_std, self.cfg.max_std)
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
mean = mean * self.model._action_masks[task]
|
mean = mean * self.model._action_masks[task]
|
||||||
std = std * self.model._action_masks[task]
|
std = std * self.model._action_masks[task]
|
||||||
|
|
||||||
# Select action
|
# Select action sequence with probability `score`
|
||||||
score = score.squeeze(1).cpu().numpy()
|
score = score.squeeze(1).squeeze(-1).cpu().numpy()
|
||||||
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
|
actions = torch.stack([
|
||||||
|
elite_actions[i, :, np.random.choice(np.arange(score.shape[1]), p=score[i])] \
|
||||||
|
for i in range(score.shape[0])], dim=0)
|
||||||
|
|
||||||
self._prev_mean = mean
|
self._prev_mean = mean
|
||||||
a, std = actions[0], std[0]
|
action, std = actions[:, 0], std[:, 0]
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
a += std * torch.randn(self.cfg.action_dim, device=std.device)
|
action += std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||||
return a.clamp_(-1, 1)
|
return action.clamp_(-1, 1)
|
||||||
|
|
||||||
def update_pi(self, zs, task):
|
def update_pi(self, zs, task):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
|
|
||||||
@@ -26,25 +25,25 @@ class OnlineTrainer(Trainer):
|
|||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Evaluate a TD-MPC2 agent."""
|
"""Evaluate a TD-MPC2 agent."""
|
||||||
ep_rewards, ep_successes = [], []
|
ep_rewards = []
|
||||||
for i in range(self.cfg.eval_episodes):
|
for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
|
||||||
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
self.logger.video.init(self.env, enabled=(i==0))
|
self.logger.video.init(self.env, enabled=(i==0))
|
||||||
while not done:
|
while not done.any():
|
||||||
action = self.agent.act(obs, t0=t==0, eval_mode=True)
|
action = self.agent.act(obs, t0=t==0, eval_mode=True)
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
ep_reward += reward
|
ep_reward += reward
|
||||||
t += 1
|
t += 1
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
self.logger.video.record(self.env)
|
self.logger.video.record(self.env)
|
||||||
|
assert done.all(), 'Vectorized environments must reset all environments at once.'
|
||||||
ep_rewards.append(ep_reward)
|
ep_rewards.append(ep_reward)
|
||||||
ep_successes.append(info['success'])
|
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
self.logger.video.save(self._step)
|
self.logger.video.save(self._step)
|
||||||
return dict(
|
return dict(
|
||||||
episode_reward=np.nanmean(ep_rewards),
|
episode_reward=torch.cat(ep_rewards).mean(),
|
||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=info['success'].mean(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None):
|
def to_td(self, obs, action=None, reward=None):
|
||||||
@@ -56,17 +55,17 @@ class OnlineTrainer(Trainer):
|
|||||||
if action is None:
|
if action is None:
|
||||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
|
||||||
td = TensorDict(dict(
|
td = TensorDict(dict(
|
||||||
obs=obs,
|
obs=obs,
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
), batch_size=(1,))
|
), batch_size=(1, self.cfg.num_envs,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Train a TD-MPC2 agent."""
|
||||||
train_metrics, done, eval_next = {}, True, True
|
train_metrics, done, eval_next = {}, torch.tensor(True), True
|
||||||
while self._step <= self.cfg.steps:
|
while self._step <= self.cfg.steps:
|
||||||
|
|
||||||
# Evaluate agent periodically
|
# Evaluate agent periodically
|
||||||
@@ -74,7 +73,8 @@ class OnlineTrainer(Trainer):
|
|||||||
eval_next = True
|
eval_next = True
|
||||||
|
|
||||||
# Reset environment
|
# Reset environment
|
||||||
if done:
|
if done.any():
|
||||||
|
assert done.all(), 'Vectorized environments must reset all environments at once.'
|
||||||
if eval_next:
|
if eval_next:
|
||||||
eval_metrics = self.eval()
|
eval_metrics = self.eval()
|
||||||
eval_metrics.update(self.common_metrics())
|
eval_metrics.update(self.common_metrics())
|
||||||
@@ -82,13 +82,14 @@ class OnlineTrainer(Trainer):
|
|||||||
eval_next = False
|
eval_next = False
|
||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
|
tds = torch.cat(self._tds)
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
episode_reward=tds['reward'].nansum(0).mean(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'].nanmean(),
|
||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx = self.buffer.add(tds)
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
self._tds = [self.to_td(obs)]
|
self._tds = [self.to_td(obs)]
|
||||||
@@ -104,14 +105,14 @@ class OnlineTrainer(Trainer):
|
|||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
if self._step == self.cfg.seed_steps:
|
if self._step == self.cfg.seed_steps:
|
||||||
num_updates = self.cfg.seed_steps
|
num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
|
||||||
print('Pretraining agent on seed data...')
|
print('Pretraining agent on seed data...')
|
||||||
else:
|
else:
|
||||||
num_updates = 1
|
num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
|
||||||
for _ in range(num_updates):
|
for _ in range(num_updates):
|
||||||
_train_metrics = self.agent.update(self.buffer)
|
_train_metrics = self.agent.update(self.buffer)
|
||||||
train_metrics.update(_train_metrics)
|
train_metrics.update(_train_metrics)
|
||||||
|
|
||||||
self._step += 1
|
self._step += self.cfg.num_envs
|
||||||
|
|
||||||
self.logger.finish(self.agent)
|
self.logger.finish(self.agent)
|
||||||
|
|||||||
Reference in New Issue
Block a user