erased unused options

This commit is contained in:
NM512
2024-01-05 23:23:09 +09:00
parent a27711ab96
commit 7f66ed5333
6 changed files with 84 additions and 211 deletions

View File

@@ -42,9 +42,7 @@ class Dreamer(nn.Module):
self._update_count = 0
self._dataset = dataset
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
self._task_behavior = models.ImagBehavior(config, self._wm)
if (
config.compile and os.name != "nt"
): # compilation is not supported on windows
@@ -92,9 +90,7 @@ class Dreamer(nn.Module):
latent, action = state
obs = self._wm.preprocess(obs)
embed = self._wm.encoder(obs)
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
)
latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
if self._config.eval_state_mean:
latent["stoch"] = latent["mean"]
feat = self._wm.dynamics.get_feat(latent)
@@ -114,21 +110,10 @@ class Dreamer(nn.Module):
action = torch.one_hot(
torch.argmax(action, dim=-1), self._config.num_actions
)
action = self._exploration(action, training)
policy_output = {"action": action, "logprob": logprob}
state = (latent, action)
return policy_output, state
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
if "onehot" in self._config.actor["dist"]:
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
else:
return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
def _train(self, data):
metrics = {}
post, context, mets = self._wm._train(data)