clean up prints

This commit is contained in:
Nicklas Hansen
2024-10-18 15:31:25 -07:00
parent c3a912e10d
commit 970792e2b6
6 changed files with 47 additions and 86 deletions

View File

@@ -1,34 +0,0 @@
absl-py
cython
dm-control
ffmpeg
glfw
hydra-core
hydra-submitit-launcher
imageio
imageio-ffmpeg
kornia
moviepy
mujoco
mujoco-py
numpy<2
omegaconf
open3d
opencv-contrib-python
opencv-python
pandas
sapien
submitit
setuptools
patchelf
protobuf
pillow
pyquaternion
tensordict-nightly
termcolor
torchrl-nightly
transforms3d
trimesh
tqdm
wandb
wheel

View File

@@ -5,7 +5,6 @@ import re
import numpy as np
import pandas as pd
from termcolor import colored
from torchrl._utils import timeit
from common import TASK_SET
@@ -238,5 +237,3 @@ class Logger:
self._log_dir / "eval.csv", header=keys, index=None
)
self._print(d, category)
timeit.print()
timeit.erase()

View File

@@ -1,11 +1,9 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module):
@@ -48,6 +46,14 @@ class WorldModel(nn.Module):
self._detach_Qs.params = self._detach_Qs_params
self._target_Qs.params = self._target_Qs_params
def __repr__(self):
repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params)
return repr
@property
def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)

View File

@@ -1,7 +1,6 @@
import torch
import torch.nn.functional as F
import functools
from torchrl._utils import timeit
from common import math
from common.scale import RunningScale
@@ -280,8 +279,7 @@ class TDMPC2(torch.nn.Module):
Returns:
dict: Dictionary of training statistics.
"""
with timeit("sample"):
obs, action, reward, task = buffer.sample()
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task

View File

@@ -8,7 +8,6 @@ class Trainer:
self.buffer = buffer
self.logger = logger
print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self):
"""Evaluate a TD-MPC2 agent."""

View File

@@ -3,7 +3,6 @@ from time import time
import numpy as np
import torch
from tensordict.tensordict import TensorDict
from torchrl._utils import timeit
from trainer.base import Trainer
@@ -68,53 +67,49 @@ class OnlineTrainer(Trainer):
"""Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps:
with timeit("global-step"):
# Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
eval_next = True
# Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
eval_next = True
# Reset environment
if done or (self._step == self.cfg.seed_steps + 1):
if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
self.logger.log(eval_metrics, 'eval')
eval_next = False
# Reset environment
if done or (self._step == self.cfg.seed_steps + 1):
if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
self.logger.log(eval_metrics, 'eval')
eval_next = False
if self._step > 0:
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
)
train_metrics.update(self.common_metrics())
train_metrics.update(timeit.todict())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
if self._step > 0:
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset()
self._tds = [self.to_td(obs)]
obs = self.env.reset()
self._tds = [self.to_td(obs)]
# Collect experience
with timeit("act"):
if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1)
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward))
# Collect experience
if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1)
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward))
# Update agent
if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps
print('Pretraining agent on seed data...')
else:
num_updates = 1
for _ in range(num_updates):
with timeit("update"):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
# Update agent
if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps
print('Pretraining agent on seed data...')
else:
num_updates = 1
for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
self._step += 1
self._step += 1
self.logger.finish(self.agent)