Merge branch 'cudagraphs' of https://github.com/vmoens/tdmpc2 into cudagraphs
This commit is contained in:
@@ -7,20 +7,16 @@ class RunningScale(torch.nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self._value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
||||
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
||||
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self._value, percentiles=self._percentiles)
|
||||
return dict(value=self.value, percentiles=self._percentiles)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._value.copy_(state_dict['value'])
|
||||
self.value.copy_(state_dict['value'])
|
||||
self._percentiles.copy_(state_dict['percentiles'])
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def _positions(self, x_shape):
|
||||
positions = self._percentiles * (x_shape-1) / 100
|
||||
floored = torch.floor(positions)
|
||||
@@ -42,7 +38,7 @@ class RunningScale(torch.nn.Module):
|
||||
def update(self, x):
|
||||
percentiles = self._percentile(x.detach())
|
||||
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
|
||||
self._value.lerp_(value, self.cfg.tau)
|
||||
self.value.data.lerp_(value, self.cfg.tau)
|
||||
|
||||
def forward(self, x, update=False):
|
||||
if update:
|
||||
|
||||
@@ -23,6 +23,22 @@ from typing import Any
|
||||
from omegaconf import OmegaConf
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def cfg_to_dataclass(cfg, frozen=False):
|
||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||
cfg_dict = OmegaConf.to_container(cfg)
|
||||
fields = []
|
||||
for key, value in cfg_dict.items():
|
||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||
|
||||
# Create the dataclass
|
||||
dataclass_name = "Config"
|
||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||
def get(self, val, default=None):
|
||||
return getattr(self, val, default)
|
||||
dataclass.get = get
|
||||
return dataclass()
|
||||
|
||||
def cfg_to_dataclass(cfg, frozen=False):
|
||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||
|
||||
Reference in New Issue
Block a user