clean up prints
This commit is contained in:
@@ -1,34 +0,0 @@
|
|||||||
absl-py
|
|
||||||
cython
|
|
||||||
dm-control
|
|
||||||
ffmpeg
|
|
||||||
glfw
|
|
||||||
hydra-core
|
|
||||||
hydra-submitit-launcher
|
|
||||||
imageio
|
|
||||||
imageio-ffmpeg
|
|
||||||
kornia
|
|
||||||
moviepy
|
|
||||||
mujoco
|
|
||||||
mujoco-py
|
|
||||||
numpy<2
|
|
||||||
omegaconf
|
|
||||||
open3d
|
|
||||||
opencv-contrib-python
|
|
||||||
opencv-python
|
|
||||||
pandas
|
|
||||||
sapien
|
|
||||||
submitit
|
|
||||||
setuptools
|
|
||||||
patchelf
|
|
||||||
protobuf
|
|
||||||
pillow
|
|
||||||
pyquaternion
|
|
||||||
tensordict-nightly
|
|
||||||
termcolor
|
|
||||||
torchrl-nightly
|
|
||||||
transforms3d
|
|
||||||
trimesh
|
|
||||||
tqdm
|
|
||||||
wandb
|
|
||||||
wheel
|
|
||||||
@@ -5,7 +5,6 @@ import re
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torchrl._utils import timeit
|
|
||||||
from common import TASK_SET
|
from common import TASK_SET
|
||||||
|
|
||||||
|
|
||||||
@@ -238,5 +237,3 @@ class Logger:
|
|||||||
self._log_dir / "eval.csv", header=keys, index=None
|
self._log_dir / "eval.csv", header=keys, index=None
|
||||||
)
|
)
|
||||||
self._print(d, category)
|
self._print(d, category)
|
||||||
timeit.print()
|
|
||||||
timeit.erase()
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
from tensordict import TensorDict
|
|
||||||
from tensordict.nn import TensorDictParams
|
from tensordict.nn import TensorDictParams
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
@@ -48,6 +46,14 @@ class WorldModel(nn.Module):
|
|||||||
self._detach_Qs.params = self._detach_Qs_params
|
self._detach_Qs.params = self._detach_Qs_params
|
||||||
self._target_Qs.params = self._target_Qs_params
|
self._target_Qs.params = self._target_Qs_params
|
||||||
|
|
||||||
|
def __repr__(self):
    """Return a multi-line, human-readable summary of the world model.

    The summary starts with a title line, then lists each component
    (encoder, dynamics, reward, policy prior, Q-functions) using that
    component's own string form, and ends with the learnable-parameter
    count read from ``self.total_params``, formatted with thousands
    separators.

    Returns:
        str: The formatted summary, without a trailing newline.
    """
    # Pair every label with its component directly rather than indexing a
    # parallel list, and avoid a local named ``repr`` (it would shadow the
    # builtin of the same name inside this method).
    components = (
        ('Encoder', self._encoder),
        ('Dynamics', self._dynamics),
        ('Reward', self._reward),
        ('Policy prior', self._pi),
        ('Q-functions', self._Qs),
    )
    lines = ['TD-MPC2 World Model']
    lines.extend(f"{label}: {module}" for label, module in components)
    lines.append(f"Learnable parameters: {self.total_params:,}")
    # Joining with newlines reproduces the original layout exactly: one
    # entry per line and no newline after the final parameter count.
    return '\n'.join(lines)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_params(self):
|
def total_params(self):
|
||||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
from torchrl._utils import timeit
|
|
||||||
|
|
||||||
from common import math
|
from common import math
|
||||||
from common.scale import RunningScale
|
from common.scale import RunningScale
|
||||||
@@ -280,8 +279,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary of training statistics.
|
dict: Dictionary of training statistics.
|
||||||
"""
|
"""
|
||||||
with timeit("sample"):
|
obs, action, reward, task = buffer.sample()
|
||||||
obs, action, reward, task = buffer.sample()
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if task is not None:
|
if task is not None:
|
||||||
kwargs["task"] = task
|
kwargs["task"] = task
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ class Trainer:
|
|||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
print('Architecture:', self.agent.model)
|
print('Architecture:', self.agent.model)
|
||||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Evaluate a TD-MPC2 agent."""
|
"""Evaluate a TD-MPC2 agent."""
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from time import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
from torchrl._utils import timeit
|
|
||||||
from trainer.base import Trainer
|
from trainer.base import Trainer
|
||||||
|
|
||||||
|
|
||||||
@@ -68,53 +67,49 @@ class OnlineTrainer(Trainer):
|
|||||||
"""Train a TD-MPC2 agent."""
|
"""Train a TD-MPC2 agent."""
|
||||||
train_metrics, done, eval_next = {}, True, False
|
train_metrics, done, eval_next = {}, True, False
|
||||||
while self._step <= self.cfg.steps:
|
while self._step <= self.cfg.steps:
|
||||||
with timeit("global-step"):
|
# Evaluate agent periodically
|
||||||
# Evaluate agent periodically
|
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
||||||
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
eval_next = True
|
||||||
eval_next = True
|
|
||||||
|
|
||||||
# Reset environment
|
# Reset environment
|
||||||
if done or (self._step == self.cfg.seed_steps + 1):
|
if done or (self._step == self.cfg.seed_steps + 1):
|
||||||
if eval_next:
|
if eval_next:
|
||||||
eval_metrics = self.eval()
|
eval_metrics = self.eval()
|
||||||
eval_metrics.update(self.common_metrics())
|
eval_metrics.update(self.common_metrics())
|
||||||
self.logger.log(eval_metrics, 'eval')
|
self.logger.log(eval_metrics, 'eval')
|
||||||
eval_next = False
|
eval_next = False
|
||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
train_metrics.update(timeit.todict())
|
self.logger.log(train_metrics, 'train')
|
||||||
self.logger.log(train_metrics, 'train')
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
self._tds = [self.to_td(obs)]
|
self._tds = [self.to_td(obs)]
|
||||||
|
|
||||||
# Collect experience
|
# Collect experience
|
||||||
with timeit("act"):
|
if self._step > self.cfg.seed_steps:
|
||||||
if self._step > self.cfg.seed_steps:
|
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
else:
|
||||||
else:
|
action = self.env.rand_act()
|
||||||
action = self.env.rand_act()
|
obs, reward, done, info = self.env.step(action)
|
||||||
obs, reward, done, info = self.env.step(action)
|
self._tds.append(self.to_td(obs, action, reward))
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
|
||||||
|
|
||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
if self._step == self.cfg.seed_steps:
|
if self._step == self.cfg.seed_steps:
|
||||||
num_updates = self.cfg.seed_steps
|
num_updates = self.cfg.seed_steps
|
||||||
print('Pretraining agent on seed data...')
|
print('Pretraining agent on seed data...')
|
||||||
else:
|
else:
|
||||||
num_updates = 1
|
num_updates = 1
|
||||||
for _ in range(num_updates):
|
for _ in range(num_updates):
|
||||||
with timeit("update"):
|
_train_metrics = self.agent.update(self.buffer)
|
||||||
_train_metrics = self.agent.update(self.buffer)
|
train_metrics.update(_train_metrics)
|
||||||
train_metrics.update(_train_metrics)
|
|
||||||
|
|
||||||
self._step += 1
|
self._step += 1
|
||||||
|
|
||||||
self.logger.finish(self.agent)
|
self.logger.finish(self.agent)
|
||||||
|
|||||||
Reference in New Issue
Block a user