avoid ".to(device)"

This commit is contained in:
NM512
2024-09-28 07:58:15 +09:00
parent 669b7e1b43
commit 7433d1e877
5 changed files with 37 additions and 33 deletions

View File

@@ -97,22 +97,22 @@ class RSSM(nn.Module):
)
def initial(self, batch_size):
deter = torch.zeros(batch_size, self._deter).to(self._device)
deter = torch.zeros(batch_size, self._deter, device=self._device)
if self._discrete:
state = dict(
logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(
self._device
logit=torch.zeros(
[batch_size, self._stoch, self._discrete], device=self._device
),
stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(
self._device
stoch=torch.zeros(
[batch_size, self._stoch, self._discrete], device=self._device
),
deter=deter,
)
else:
state = dict(
mean=torch.zeros([batch_size, self._stoch]).to(self._device),
std=torch.zeros([batch_size, self._stoch]).to(self._device),
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
mean=torch.zeros([batch_size, self._stoch], device=self._device),
std=torch.zeros([batch_size, self._stoch], device=self._device),
stoch=torch.zeros([batch_size, self._stoch], device=self._device),
deter=deter,
)
if self._initial == "zeros":
@@ -175,8 +175,8 @@ class RSSM(nn.Module):
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
prev_state = self.initial(len(is_first))
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
self._device
prev_action = torch.zeros(
(len(is_first), self._num_actions), device=self._device
)
# overwrite the prev_state only where is_first=True
elif torch.sum(is_first) > 0: