avoid ".to(device)"

This commit is contained in:
NM512
2024-09-28 07:58:15 +09:00
parent 669b7e1b43
commit 7433d1e877
5 changed files with 37 additions and 33 deletions

View File

@@ -14,7 +14,7 @@ class RewardEMA:
def __init__(self, device, alpha=1e-2):
self.device = device
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
self.range = torch.tensor([0.05, 0.95], device=device)
def __call__(self, x, ema_vals):
flat_x = torch.flatten(x.detach())
@@ -172,18 +172,20 @@ class WorldModel(nn.Module):
# this function is called during both rollout and training
def preprocess(self, obs):
obs = obs.copy()
obs["image"] = torch.Tensor(obs["image"]) / 255.0
obs = {
k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
for k, v in obs.items()
}
obs["image"] = obs["image"] / 255.0
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
obs["discount"] = obs["discount"].unsqueeze(-1)
# 'is_first' is necessary to initialize hidden state at training
assert "is_first" in obs
# 'is_terminal' is necessary to train cont_head
assert "is_terminal" in obs
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
return obs
def video_pred(self, data):
@@ -277,7 +279,9 @@ class ImagBehavior(nn.Module):
)
if self._config.reward_EMA:
# register ema_vals to nn.Module for enabling torch.save and torch.load
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
self.register_buffer(
"ema_vals", torch.zeros((2,), device=self._config.device)
)
self.reward_ema = RewardEMA(device=self._config.device)
def _train(