avoid ".to(device)"
This commit is contained in:
20
networks.py
20
networks.py
@@ -97,22 +97,22 @@ class RSSM(nn.Module):
|
||||
)
|
||||
|
||||
def initial(self, batch_size):
|
||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
||||
deter = torch.zeros(batch_size, self._deter, device=self._device)
|
||||
if self._discrete:
|
||||
state = dict(
|
||||
logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
||||
self._device
|
||||
logit=torch.zeros(
|
||||
[batch_size, self._stoch, self._discrete], device=self._device
|
||||
),
|
||||
stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
||||
self._device
|
||||
stoch=torch.zeros(
|
||||
[batch_size, self._stoch, self._discrete], device=self._device
|
||||
),
|
||||
deter=deter,
|
||||
)
|
||||
else:
|
||||
state = dict(
|
||||
mean=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
std=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
mean=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
std=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
stoch=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
deter=deter,
|
||||
)
|
||||
if self._initial == "zeros":
|
||||
@@ -175,8 +175,8 @@ class RSSM(nn.Module):
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
prev_state = self.initial(len(is_first))
|
||||
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
|
||||
self._device
|
||||
prev_action = torch.zeros(
|
||||
(len(is_first), self._num_actions), device=self._device
|
||||
)
|
||||
# overwrite the prev_state only where is_first=True
|
||||
elif torch.sum(is_first) > 0:
|
||||
|
||||
Reference in New Issue
Block a user