avoid ".to(device)"

This commit is contained in:
NM512
2024-09-28 07:58:15 +09:00
parent 669b7e1b43
commit 7433d1e877
5 changed files with 37 additions and 33 deletions

View File

@@ -16,19 +16,19 @@ class Random(nn.Module):
def actor(self, feat):
if self._config.actor["dist"] == "onehot":
return tools.OneHotDist(
torch.zeros(self._config.num_actions)
.repeat(self._config.envs, 1)
.to(self._config.device)
torch.zeros(
self._config.num_actions, device=self._config.device
).repeat(self._config.envs, 1)
)
else:
return torchd.independent.Independent(
torchd.uniform.Uniform(
torch.Tensor(self._act_space.low)
.repeat(self._config.envs, 1)
.to(self._config.device),
torch.Tensor(self._act_space.high)
.repeat(self._config.envs, 1)
.to(self._config.device),
torch.tensor(
self._act_space.low, device=self._config.device
).repeat(self._config.envs, 1),
torch.tensor(
self._act_space.high, device=self._config.device
).repeat(self._config.envs, 1),
),
1,
)
@@ -97,7 +97,8 @@ class Plan2Explore(nn.Module):
inputs = context["feat"]
if self._config.disag_action_cond:
inputs = torch.concat(
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
[inputs, torch.tensor(data["action"], device=self._config.device)],
-1,
)
metrics.update(self._train_ensemble(inputs, target))
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])