erased unused options
This commit is contained in:
dreamer.py — 19 lines changed (2 added, 17 removed)
@@ -42,9 +42,7 @@ class Dreamer(nn.Module):
 self._update_count = 0
 self._dataset = dataset
 self._wm = models.WorldModel(obs_space, act_space, self._step, config)
-self._task_behavior = models.ImagBehavior(
-config, self._wm, config.behavior_stop_grad
-)
+self._task_behavior = models.ImagBehavior(config, self._wm)
 if (
 config.compile and os.name != "nt"
 ): # compilation is not supported on windows
@@ -92,9 +90,7 @@ class Dreamer(nn.Module):
 latent, action = state
 obs = self._wm.preprocess(obs)
 embed = self._wm.encoder(obs)
-latent, _ = self._wm.dynamics.obs_step(
-latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
-)
+latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
 if self._config.eval_state_mean:
 latent["stoch"] = latent["mean"]
 feat = self._wm.dynamics.get_feat(latent)
@@ -114,21 +110,10 @@ class Dreamer(nn.Module):
 action = torch.one_hot(
 torch.argmax(action, dim=-1), self._config.num_actions
 )
-action = self._exploration(action, training)
 policy_output = {"action": action, "logprob": logprob}
 state = (latent, action)
 return policy_output, state
 
-def _exploration(self, action, training):
-amount = self._config.expl_amount if training else self._config.eval_noise
-if amount == 0:
-return action
-if "onehot" in self._config.actor["dist"]:
-probs = amount / self._config.num_actions + (1 - amount) * action
-return tools.OneHotDist(probs=probs).sample()
-else:
-return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
-
 def _train(self, data):
 metrics = {}
 post, context, mets = self._wm._train(data)
Reference in New Issue
Block a user