Merge branch 'cudagraphs' of https://github.com/vmoens/tdmpc2 into cudagraphs

This commit is contained in:
Nicklas Hansen
2024-10-17 14:57:32 -07:00
2 changed files with 20 additions and 8 deletions

View File

@@ -7,20 +7,16 @@ class RunningScale(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self._value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles)
return dict(value=self.value, percentiles=self._percentiles)
def load_state_dict(self, state_dict):
self._value.copy_(state_dict['value'])
self.value.copy_(state_dict['value'])
self._percentiles.copy_(state_dict['percentiles'])
@property
def value(self):
return self._value
def _positions(self, x_shape):
positions = self._percentiles * (x_shape-1) / 100
floored = torch.floor(positions)
@@ -42,7 +38,7 @@ class RunningScale(torch.nn.Module):
def update(self, x):
percentiles = self._percentile(x.detach())
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
self._value.lerp_(value, self.cfg.tau)
self.value.data.lerp_(value, self.cfg.tau)
def forward(self, x, update=False):
if update:

View File

@@ -23,6 +23,22 @@ from typing import Any
from omegaconf import OmegaConf
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
def cfg_to_dataclass(cfg, frozen=False):
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
# Create the dataclass
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
def cfg_to_dataclass(cfg, frozen=False):
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks