avoid ".to(device)"
This commit is contained in:
18
models.py
18
models.py
@@ -14,7 +14,7 @@ class RewardEMA:
|
||||
def __init__(self, device, alpha=1e-2):
|
||||
self.device = device
|
||||
self.alpha = alpha
|
||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
||||
self.range = torch.tensor([0.05, 0.95], device=device)
|
||||
|
||||
def __call__(self, x, ema_vals):
|
||||
flat_x = torch.flatten(x.detach())
|
||||
@@ -172,18 +172,20 @@ class WorldModel(nn.Module):
|
||||
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
obs = {
|
||||
k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
|
||||
for k, v in obs.items()
|
||||
}
|
||||
obs["image"] = obs["image"] / 255.0
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
||||
obs["discount"] = obs["discount"].unsqueeze(-1)
|
||||
# 'is_first' is necesarry to initialize hidden state at training
|
||||
assert "is_first" in obs
|
||||
# 'is_terminal' is necesarry to train cont_head
|
||||
assert "is_terminal" in obs
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
return obs
|
||||
|
||||
def video_pred(self, data):
|
||||
@@ -277,7 +279,9 @@ class ImagBehavior(nn.Module):
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
||||
self.register_buffer(
|
||||
"ema_vals", torch.zeros((2,), device=self._config.device)
|
||||
)
|
||||
self.reward_ema = RewardEMA(device=self._config.device)
|
||||
|
||||
def _train(
|
||||
|
||||
Reference in New Issue
Block a user