diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index f403b3e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,34 +0,0 @@ -absl-py -cython -dm-control -ffmpeg -glfw -hydra-core -hydra-submitit-launcher -imageio -imageio-ffmpeg -kornia -moviepy -mujoco -mujoco-py -numpy<2 -omegaconf -open3d -opencv-contrib-python -opencv-python -pandas -sapien -submitit -setuptools -patchelf -protobuf -pillow -pyquaternion -tensordict-nightly -termcolor -torchrl-nightly -transforms3d -trimesh -tqdm -wandb -wheel diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index b3047f9..bdbb998 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -5,7 +5,6 @@ import re import numpy as np import pandas as pd from termcolor import colored -from torchrl._utils import timeit from common import TASK_SET @@ -238,5 +237,3 @@ class Logger: self._log_dir / "eval.csv", header=keys, index=None ) self._print(d, category) - timeit.print() - timeit.erase() diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 65572dd..eb9633d 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -1,11 +1,9 @@ from copy import deepcopy -import numpy as np import torch import torch.nn as nn from common import layers, math, init -from tensordict import TensorDict from tensordict.nn import TensorDictParams class WorldModel(nn.Module): @@ -48,6 +46,14 @@ class WorldModel(nn.Module): self._detach_Qs.params = self._detach_Qs_params self._target_Qs.params = self._target_Qs_params + def __repr__(self): + repr = 'TD-MPC2 World Model\n' + modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions'] + for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]): + repr += f"{modules[i]}: {m}\n" + repr += "Learnable parameters: {:,}".format(self.total_params) + return repr + @property def total_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 6ce3d7f..a4da1db 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,7 +1,6 @@ import torch import torch.nn.functional as F import functools -from torchrl._utils import timeit from common import math from common.scale import RunningScale @@ -280,8 +279,7 @@ class TDMPC2(torch.nn.Module): Returns: dict: Dictionary of training statistics. """ - with timeit("sample"): - obs, action, reward, task = buffer.sample() + obs, action, reward, task = buffer.sample() kwargs = {} if task is not None: kwargs["task"] = task diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index 27a328d..6d14783 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -8,7 +8,6 @@ class Trainer: self.buffer = buffer self.logger = logger print('Architecture:', self.agent.model) - print("Learnable parameters: {:,}".format(self.agent.model.total_params)) def eval(self): """Evaluate a TD-MPC2 agent.""" diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index f3072b5..103d129 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -3,7 +3,6 @@ from time import time import numpy as np import torch from tensordict.tensordict import TensorDict -from torchrl._utils import timeit from trainer.base import Trainer @@ -68,53 +67,49 @@ class OnlineTrainer(Trainer): """Train a TD-MPC2 agent.""" train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: - with timeit("global-step"): - # Evaluate agent periodically - if self._step > 0 and self._step % self.cfg.eval_freq == 0: - eval_next = True + # Evaluate agent periodically + if self._step > 0 and self._step % self.cfg.eval_freq == 0: + eval_next = True - # Reset environment - if done or (self._step == self.cfg.seed_steps + 1): - if eval_next: - eval_metrics = self.eval() - eval_metrics.update(self.common_metrics()) - self.logger.log(eval_metrics, 'eval') - eval_next = False + # Reset environment + if done or (self._step == self.cfg.seed_steps + 1): + if eval_next: + eval_metrics = self.eval() + eval_metrics.update(self.common_metrics()) + self.logger.log(eval_metrics, 'eval') + eval_next = False - if self._step > 0: - train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], - ) - train_metrics.update(self.common_metrics()) - train_metrics.update(timeit.todict()) - self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + if self._step > 0: + train_metrics.update( + episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), + episode_success=info['success'], + ) + train_metrics.update(self.common_metrics()) + self.logger.log(train_metrics, 'train') + self._ep_idx = self.buffer.add(torch.cat(self._tds)) - obs = self.env.reset() - self._tds = [self.to_td(obs)] + obs = self.env.reset() + self._tds = [self.to_td(obs)] - # Collect experience - with timeit("act"): - if self._step > self.cfg.seed_steps: - action = self.agent.act(obs, t0=len(self._tds)==1) - else: - action = self.env.rand_act() - obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + # Collect experience + if self._step > self.cfg.seed_steps: + action = self.agent.act(obs, t0=len(self._tds)==1) + else: + action = self.env.rand_act() + obs, reward, done, info = self.env.step(action) + self._tds.append(self.to_td(obs, action, reward)) - # Update agent - if self._step >= self.cfg.seed_steps: - if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps - print('Pretraining agent on seed data...') - else: - num_updates = 1 - for _ in range(num_updates): - with timeit("update"): - _train_metrics = self.agent.update(self.buffer) - train_metrics.update(_train_metrics) + # Update agent + if self._step >= self.cfg.seed_steps: + if self._step == self.cfg.seed_steps: + num_updates = self.cfg.seed_steps + print('Pretraining agent on seed data...') + else: + num_updates = 1 + for _ in range(num_updates): + _train_metrics = self.agent.update(self.buffer) + train_metrics.update(_train_metrics) - self._step += 1 + self._step += 1 self.logger.finish(self.agent)