erased unnecessary lines

This commit is contained in:
NM512
2023-06-17 15:27:09 +09:00
parent 6c861ca7cb
commit f7c505579c
4 changed files with 12 additions and 8 deletions

View File

@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
"feat": config.dyn_stoch + config.dyn_deter,
}[self._config.disag_target]
kw = dict(
inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0, # pytorch version
inp_dim=feat_size + config.num_actions
if config.disag_action_cond
else 0, # pytorch version
shape=size,
layers=config.disag_layers,
units=config.disag_units,
@@ -93,7 +95,9 @@ class Plan2Explore(nn.Module):
}[self._config.disag_target]
inputs = context["feat"]
if self._config.disag_action_cond:
inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
inputs = torch.concat(
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
)
metrics.update(self._train_ensemble(inputs, target))
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
return None, metrics