Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env

This commit is contained in:
Nicklas Hansen
2024-11-10 13:04:54 -08:00

View File

@@ -99,7 +99,7 @@ class TDMPC2(torch.nn.Module):
Returns:
torch.Tensor: Action to take in the environment.
"""
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
obs = obs.to(self.device, non_blocking=True)
if task is not None:
task = torch.tensor([task], device=self.device)
if self.cfg.mpc: