clean up prints

This commit is contained in:
Nicklas Hansen
2024-10-18 15:31:25 -07:00
parent c3a912e10d
commit 970792e2b6
6 changed files with 47 additions and 86 deletions

View File

@@ -1,34 +0,0 @@
absl-py
cython
dm-control
ffmpeg
glfw
hydra-core
hydra-submitit-launcher
imageio
imageio-ffmpeg
kornia
moviepy
mujoco
mujoco-py
numpy<2
omegaconf
open3d
opencv-contrib-python
opencv-python
pandas
sapien
submitit
setuptools
patchelf
protobuf
pillow
pyquaternion
tensordict-nightly
termcolor
torchrl-nightly
transforms3d
trimesh
tqdm
wandb
wheel

View File

@@ -5,7 +5,6 @@ import re
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from termcolor import colored from termcolor import colored
from torchrl._utils import timeit
from common import TASK_SET from common import TASK_SET
@@ -238,5 +237,3 @@ class Logger:
self._log_dir / "eval.csv", header=keys, index=None self._log_dir / "eval.csv", header=keys, index=None
) )
self._print(d, category) self._print(d, category)
timeit.print()
timeit.erase()

View File

@@ -1,11 +1,9 @@
from copy import deepcopy from copy import deepcopy
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from common import layers, math, init from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
@@ -48,6 +46,14 @@ class WorldModel(nn.Module):
self._detach_Qs.params = self._detach_Qs_params self._detach_Qs.params = self._detach_Qs_params
self._target_Qs.params = self._target_Qs_params self._target_Qs.params = self._target_Qs_params
def __repr__(self):
repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params)
return repr
@property @property
def total_params(self): def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters() if p.requires_grad)

View File

@@ -1,7 +1,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import functools import functools
from torchrl._utils import timeit
from common import math from common import math
from common.scale import RunningScale from common.scale import RunningScale
@@ -280,7 +279,6 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
with timeit("sample"):
obs, action, reward, task = buffer.sample() obs, action, reward, task = buffer.sample()
kwargs = {} kwargs = {}
if task is not None: if task is not None:

View File

@@ -8,7 +8,6 @@ class Trainer:
self.buffer = buffer self.buffer = buffer
self.logger = logger self.logger = logger
print('Architecture:', self.agent.model) print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""

View File

@@ -3,7 +3,6 @@ from time import time
import numpy as np import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl._utils import timeit
from trainer.base import Trainer from trainer.base import Trainer
@@ -68,7 +67,6 @@ class OnlineTrainer(Trainer):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
with timeit("global-step"):
# Evaluate agent periodically # Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0: if self._step > 0 and self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True
@@ -87,7 +85,6 @@ class OnlineTrainer(Trainer):
episode_success=info['success'], episode_success=info['success'],
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
train_metrics.update(timeit.todict())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(torch.cat(self._tds))
@@ -95,7 +92,6 @@ class OnlineTrainer(Trainer):
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
# Collect experience # Collect experience
with timeit("act"):
if self._step > self.cfg.seed_steps: if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1) action = self.agent.act(obs, t0=len(self._tds)==1)
else: else:
@@ -111,7 +107,6 @@ class OnlineTrainer(Trainer):
else: else:
num_updates = 1 num_updates = 1
for _ in range(num_updates): for _ in range(num_updates):
with timeit("update"):
_train_metrics = self.agent.update(self.buffer) _train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics) train_metrics.update(_train_metrics)