diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9ee3ff5..3925359 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -10,7 +10,8 @@ from common.world_model import WorldModel class TDMPC2: """ TD-MPC2 agent. Implements training + inference. - Can be used for both single-task and multi-task experiments. + Can be used for both single-task and multi-task experiments, + and supports both state and pixel observations. """ def __init__(self, cfg): @@ -132,7 +133,7 @@ class TDMPC2: actions[:, :self.cfg.num_pi_trajs] = pi_actions # Iterate MPPI - for i in range(self.cfg.iterations): + for _ in range(self.cfg.iterations): # Sample actions actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \