diff --git a/tdmpc2/common/scale.py b/tdmpc2/common/scale.py index a9f4654..8fd1740 100644 --- a/tdmpc2/common/scale.py +++ b/tdmpc2/common/scale.py @@ -7,20 +7,16 @@ class RunningScale(torch.nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg - self._value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) + self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))) def state_dict(self): - return dict(value=self._value, percentiles=self._percentiles) + return dict(value=self.value, percentiles=self._percentiles) def load_state_dict(self, state_dict): - self._value.copy_(state_dict['value']) + self.value.copy_(state_dict['value']) self._percentiles.copy_(state_dict['percentiles']) - @property - def value(self): - return self._value - def _positions(self, x_shape): positions = self._percentiles * (x_shape-1) / 100 floored = torch.floor(positions) @@ -42,7 +38,7 @@ class RunningScale(torch.nn.Module): def update(self, x): percentiles = self._percentile(x.detach()) value = torch.clamp(percentiles[1] - percentiles[0], min=1.) - self._value.lerp_(value, self.cfg.tau) + self.value.data.lerp_(value, self.cfg.tau) def forward(self, x, update=False): if update: diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 6afb648..48206ec 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -23,6 +23,22 @@ from typing import Any from omegaconf import OmegaConf torch.backends.cudnn.benchmark = True +torch.set_float32_matmul_precision('high') + +def cfg_to_dataclass(cfg, frozen=False): + # Converts an OmegaConf config to a dataclass, which will not cause graph breaks + cfg_dict = OmegaConf.to_container(cfg) + fields = [] + for key, value in cfg_dict.items(): + fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) + + # Create the dataclass + dataclass_name = "Config" + dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) + def get(self, val, default=None): + return getattr(self, val, default) + dataclass.get = get + return dataclass() def cfg_to_dataclass(cfg, frozen=False): # Converts an OmegaConf config to a dataclass, which will not cause graph breaks