Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env
This commit is contained in:
@@ -99,7 +99,7 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
||||
obs = obs.to(self.device, non_blocking=True)
|
||||
if task is not None:
|
||||
task = torch.tensor([task], device=self.device)
|
||||
if self.cfg.mpc:
|
||||
@@ -154,7 +154,7 @@ class TDMPC2(torch.nn.Module):
|
||||
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
|
||||
|
||||
|
||||
# Iterate MPPI
|
||||
for _ in range(self.cfg.iterations):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user