erased unnecessary lines
This commit is contained in:
@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
|
||||
"feat": config.dyn_stoch + config.dyn_deter,
|
||||
}[self._config.disag_target]
|
||||
kw = dict(
|
||||
inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0, # pytorch version
|
||||
inp_dim=feat_size + config.num_actions
|
||||
if config.disag_action_cond
|
||||
else 0, # pytorch version
|
||||
shape=size,
|
||||
layers=config.disag_layers,
|
||||
units=config.disag_units,
|
||||
@@ -93,7 +95,9 @@ class Plan2Explore(nn.Module):
|
||||
}[self._config.disag_target]
|
||||
inputs = context["feat"]
|
||||
if self._config.disag_action_cond:
|
||||
inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
|
||||
inputs = torch.concat(
|
||||
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
|
||||
)
|
||||
metrics.update(self._train_ensemble(inputs, target))
|
||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||
return None, metrics
|
||||
|
||||
Reference in New Issue
Block a user