update documentation

This commit is contained in:
Nicklas Hansen
2023-12-31 14:38:22 -08:00
parent e3c876670a
commit 1d224cec3a

View File

@@ -10,7 +10,8 @@ from common.world_model import WorldModel
class TDMPC2:
"""
TD-MPC2 agent. Implements training + inference.
Can be used for both single-task and multi-task experiments.
Can be used for both single-task and multi-task experiments,
and supports both state and pixel observations.
"""
def __init__(self, cfg):
@@ -132,7 +133,7 @@ class TDMPC2:
actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI
for i in range(self.cfg.iterations):
for _ in range(self.cfg.iterations):
# Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \