clean up prints
This commit is contained in:
@@ -1,34 +0,0 @@
|
||||
absl-py
|
||||
cython
|
||||
dm-control
|
||||
ffmpeg
|
||||
glfw
|
||||
hydra-core
|
||||
hydra-submitit-launcher
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
kornia
|
||||
moviepy
|
||||
mujoco
|
||||
mujoco-py
|
||||
numpy<2
|
||||
omegaconf
|
||||
open3d
|
||||
opencv-contrib-python
|
||||
opencv-python
|
||||
pandas
|
||||
sapien
|
||||
submitit
|
||||
setuptools
|
||||
patchelf
|
||||
protobuf
|
||||
pillow
|
||||
pyquaternion
|
||||
tensordict-nightly
|
||||
termcolor
|
||||
torchrl-nightly
|
||||
transforms3d
|
||||
trimesh
|
||||
tqdm
|
||||
wandb
|
||||
wheel
|
||||
@@ -5,7 +5,6 @@ import re
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from termcolor import colored
|
||||
from torchrl._utils import timeit
|
||||
from common import TASK_SET
|
||||
|
||||
|
||||
@@ -238,5 +237,3 @@ class Logger:
|
||||
self._log_dir / "eval.csv", header=keys, index=None
|
||||
)
|
||||
self._print(d, category)
|
||||
timeit.print()
|
||||
timeit.erase()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from common import layers, math, init
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictParams
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
@@ -48,6 +46,14 @@ class WorldModel(nn.Module):
|
||||
self._detach_Qs.params = self._detach_Qs_params
|
||||
self._target_Qs.params = self._target_Qs_params
|
||||
|
||||
def __repr__(self):
|
||||
repr = 'TD-MPC2 World Model\n'
|
||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
|
||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
|
||||
repr += f"{modules[i]}: {m}\n"
|
||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||
return repr
|
||||
|
||||
@property
|
||||
def total_params(self):
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import functools
|
||||
from torchrl._utils import timeit
|
||||
|
||||
from common import math
|
||||
from common.scale import RunningScale
|
||||
@@ -280,8 +279,7 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
dict: Dictionary of training statistics.
|
||||
"""
|
||||
with timeit("sample"):
|
||||
obs, action, reward, task = buffer.sample()
|
||||
obs, action, reward, task = buffer.sample()
|
||||
kwargs = {}
|
||||
if task is not None:
|
||||
kwargs["task"] = task
|
||||
|
||||
@@ -8,7 +8,6 @@ class Trainer:
|
||||
self.buffer = buffer
|
||||
self.logger = logger
|
||||
print('Architecture:', self.agent.model)
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from time import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from torchrl._utils import timeit
|
||||
from trainer.base import Trainer
|
||||
|
||||
|
||||
@@ -68,53 +67,49 @@ class OnlineTrainer(Trainer):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
train_metrics, done, eval_next = {}, True, False
|
||||
while self._step <= self.cfg.steps:
|
||||
with timeit("global-step"):
|
||||
# Evaluate agent periodically
|
||||
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
||||
eval_next = True
|
||||
# Evaluate agent periodically
|
||||
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
||||
eval_next = True
|
||||
|
||||
# Reset environment
|
||||
if done or (self._step == self.cfg.seed_steps + 1):
|
||||
if eval_next:
|
||||
eval_metrics = self.eval()
|
||||
eval_metrics.update(self.common_metrics())
|
||||
self.logger.log(eval_metrics, 'eval')
|
||||
eval_next = False
|
||||
# Reset environment
|
||||
if done or (self._step == self.cfg.seed_steps + 1):
|
||||
if eval_next:
|
||||
eval_metrics = self.eval()
|
||||
eval_metrics.update(self.common_metrics())
|
||||
self.logger.log(eval_metrics, 'eval')
|
||||
eval_next = False
|
||||
|
||||
if self._step > 0:
|
||||
train_metrics.update(
|
||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||
episode_success=info['success'],
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
train_metrics.update(timeit.todict())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
if self._step > 0:
|
||||
train_metrics.update(
|
||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||
episode_success=info['success'],
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
|
||||
obs = self.env.reset()
|
||||
self._tds = [self.to_td(obs)]
|
||||
obs = self.env.reset()
|
||||
self._tds = [self.to_td(obs)]
|
||||
|
||||
# Collect experience
|
||||
with timeit("act"):
|
||||
if self._step > self.cfg.seed_steps:
|
||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||
else:
|
||||
action = self.env.rand_act()
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._tds.append(self.to_td(obs, action, reward))
|
||||
# Collect experience
|
||||
if self._step > self.cfg.seed_steps:
|
||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||
else:
|
||||
action = self.env.rand_act()
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._tds.append(self.to_td(obs, action, reward))
|
||||
|
||||
# Update agent
|
||||
if self._step >= self.cfg.seed_steps:
|
||||
if self._step == self.cfg.seed_steps:
|
||||
num_updates = self.cfg.seed_steps
|
||||
print('Pretraining agent on seed data...')
|
||||
else:
|
||||
num_updates = 1
|
||||
for _ in range(num_updates):
|
||||
with timeit("update"):
|
||||
_train_metrics = self.agent.update(self.buffer)
|
||||
train_metrics.update(_train_metrics)
|
||||
# Update agent
|
||||
if self._step >= self.cfg.seed_steps:
|
||||
if self._step == self.cfg.seed_steps:
|
||||
num_updates = self.cfg.seed_steps
|
||||
print('Pretraining agent on seed data...')
|
||||
else:
|
||||
num_updates = 1
|
||||
for _ in range(num_updates):
|
||||
_train_metrics = self.agent.update(self.buffer)
|
||||
train_metrics.update(_train_metrics)
|
||||
|
||||
self._step += 1
|
||||
self._step += 1
|
||||
|
||||
self.logger.finish(self.agent)
|
||||
|
||||
Reference in New Issue
Block a user