Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env
This commit is contained in:
@@ -99,7 +99,7 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
||||
obs = obs.to(self.device, non_blocking=True)
|
||||
if task is not None:
|
||||
task = torch.tensor([task], device=self.device)
|
||||
if self.cfg.mpc:
|
||||
|
||||
Reference in New Issue
Block a user