update documentation
This commit is contained in:
@@ -10,7 +10,8 @@ from common.world_model import WorldModel
|
|||||||
class TDMPC2:
|
class TDMPC2:
|
||||||
"""
|
"""
|
||||||
TD-MPC2 agent. Implements training + inference.
|
TD-MPC2 agent. Implements training + inference.
|
||||||
Can be used for both single-task and multi-task experiments.
|
Can be used for both single-task and multi-task experiments,
|
||||||
|
and supports both state and pixel observations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
@@ -132,7 +133,7 @@ class TDMPC2:
|
|||||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||||
|
|
||||||
# Iterate MPPI
|
# Iterate MPPI
|
||||||
for i in range(self.cfg.iterations):
|
for _ in range(self.cfg.iterations):
|
||||||
|
|
||||||
# Sample actions
|
# Sample actions
|
||||||
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
|
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
|
||||||
|
|||||||
Reference in New Issue
Block a user