Files
tdmpc2/tdmpc2/trainer/base.py
2024-07-02 10:12:30 -07:00

20 lines
487 B
Python
Executable File

class Trainer:
    """Common base for TD-MPC2 trainers.

    Keeps the objects handed to it (config, env, agent, buffer, logger) as
    attributes of the same names so that subclasses can implement the actual
    ``train`` / ``eval`` loops. Constructing a trainer prints the agent's
    model and its ``total_params`` count to stdout.
    """

    def __init__(self, cfg, env, agent, buffer, logger):
        """Store the collaborators and print a short model summary.

        Args:
            cfg: stored unchanged as ``self.cfg``.
            env: stored unchanged as ``self.env``.
            agent: stored as ``self.agent``; must expose a ``model`` attribute
                which in turn exposes ``total_params``.
            buffer: stored unchanged as ``self.buffer``.
            logger: stored unchanged as ``self.logger``.
        """
        # Attributes are assigned left-to-right, in the same order as before.
        self.cfg, self.env, self.agent = cfg, env, agent
        self.buffer, self.logger = buffer, logger

        model = self.agent.model
        print('Architecture:', model)
        # The ',' format option inserts thousands separators (e.g. 1,234,567).
        print("Learnable parameters: {:,}".format(model.total_params))

    def eval(self):
        """Evaluate a TD-MPC2 agent; concrete trainers must override this."""
        raise NotImplementedError

    def train(self):
        """Train a TD-MPC2 agent; concrete trainers must override this."""
        raise NotImplementedError