avoid ".to(device)"
This commit is contained in:
@@ -16,19 +16,19 @@ class Random(nn.Module):
|
||||
def actor(self, feat):
|
||||
if self._config.actor["dist"] == "onehot":
|
||||
return tools.OneHotDist(
|
||||
torch.zeros(self._config.num_actions)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device)
|
||||
torch.zeros(
|
||||
self._config.num_actions, device=self._config.device
|
||||
).repeat(self._config.envs, 1)
|
||||
)
|
||||
else:
|
||||
return torchd.independent.Independent(
|
||||
torchd.uniform.Uniform(
|
||||
torch.Tensor(self._act_space.low)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
torch.Tensor(self._act_space.high)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
torch.tensor(
|
||||
self._act_space.low, device=self._config.device
|
||||
).repeat(self._config.envs, 1),
|
||||
torch.tensor(
|
||||
self._act_space.high, device=self._config.device
|
||||
).repeat(self._config.envs, 1),
|
||||
),
|
||||
1,
|
||||
)
|
||||
@@ -97,7 +97,8 @@ class Plan2Explore(nn.Module):
|
||||
inputs = context["feat"]
|
||||
if self._config.disag_action_cond:
|
||||
inputs = torch.concat(
|
||||
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
|
||||
[inputs, torch.tensor(data["action"], device=self._config.device)],
|
||||
-1,
|
||||
)
|
||||
metrics.update(self._train_ensemble(inputs, target))
|
||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||
|
||||
Reference in New Issue
Block a user