Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env

This commit is contained in:
Nicklas Hansen
2024-11-10 13:04:54 -08:00

View File

@@ -99,7 +99,7 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
@@ -154,7 +154,7 @@ class TDMPC2(torch.nn.Module):
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):