Merge branch 'cudagraphs' of https://github.com/vmoens/tdmpc2 into cudagraphs
This commit is contained in:
@@ -7,20 +7,16 @@ class RunningScale(torch.nn.Module):
|
|||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
||||||
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
|
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return dict(value=self._value, percentiles=self._percentiles)
|
return dict(value=self.value, percentiles=self._percentiles)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self._value.copy_(state_dict['value'])
|
self.value.copy_(state_dict['value'])
|
||||||
self._percentiles.copy_(state_dict['percentiles'])
|
self._percentiles.copy_(state_dict['percentiles'])
|
||||||
|
|
||||||
@property
|
|
||||||
def value(self):
|
|
||||||
return self._value
|
|
||||||
|
|
||||||
def _positions(self, x_shape):
|
def _positions(self, x_shape):
|
||||||
positions = self._percentiles * (x_shape-1) / 100
|
positions = self._percentiles * (x_shape-1) / 100
|
||||||
floored = torch.floor(positions)
|
floored = torch.floor(positions)
|
||||||
@@ -42,7 +38,7 @@ class RunningScale(torch.nn.Module):
|
|||||||
def update(self, x):
|
def update(self, x):
|
||||||
percentiles = self._percentile(x.detach())
|
percentiles = self._percentile(x.detach())
|
||||||
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
|
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
|
||||||
self._value.lerp_(value, self.cfg.tau)
|
self.value.data.lerp_(value, self.cfg.tau)
|
||||||
|
|
||||||
def forward(self, x, update=False):
|
def forward(self, x, update=False):
|
||||||
if update:
|
if update:
|
||||||
|
|||||||
@@ -23,6 +23,22 @@ from typing import Any
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
|
def cfg_to_dataclass(cfg, frozen=False):
|
||||||
|
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||||
|
cfg_dict = OmegaConf.to_container(cfg)
|
||||||
|
fields = []
|
||||||
|
for key, value in cfg_dict.items():
|
||||||
|
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||||
|
|
||||||
|
# Create the dataclass
|
||||||
|
dataclass_name = "Config"
|
||||||
|
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||||
|
def get(self, val, default=None):
|
||||||
|
return getattr(self, val, default)
|
||||||
|
dataclass.get = get
|
||||||
|
return dataclass()
|
||||||
|
|
||||||
def cfg_to_dataclass(cfg, frozen=False):
|
def cfg_to_dataclass(cfg, frozen=False):
|
||||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||||
|
|||||||
Reference in New Issue
Block a user