Changed the discount head to predict terminal states

This commit is contained in:
NM512
2023-04-22 09:34:23 +09:00
parent 16151efb3c
commit 628b856c63
4 changed files with 50 additions and 50 deletions

View File

@@ -155,16 +155,11 @@ class Dreamer(nn.Module):
metrics.update(mets)
start = post
# start['deter'] (16, 64, 512)
if self._config.pred_discount: # Last step could be terminal.
start = {k: v[:, :-1] for k, v in post.items()}
context = {k: v[:, :-1] for k, v in context.items()}
reward = lambda f, s, a: self._wm.heads["reward"](
self._wm.dynamics.get_feat(s)
).mode()
metrics.update(self._task_behavior._train(start, reward)[-1])
if self._config.expl_behavior != "greedy":
if self._config.pred_discount:
data = {k: v[:, :-1] for k, v in data.items()}
mets = self._expl_behavior.train(start, context, data)[-1]
metrics.update({"expl_" + key: value for key, value in mets.items()})
for name, value in metrics.items():