erased unnecessary lines

This commit is contained in:
NM512
2023-06-17 15:27:09 +09:00
parent 6c861ca7cb
commit f7c505579c
4 changed files with 12 additions and 8 deletions

View File

@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
if config.compile and os.name != 'nt': # compilation is not supported on windows
if (
config.compile and os.name != "nt"
): # compilation is not supported on windows
self._wm = torch.compile(self._wm)
self._task_behavior = torch.compile(self._task_behavior)
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
post, context, mets = self._wm._train(data)
metrics.update(mets)
start = post
# start['deter'] (16, 64, 512)
reward = lambda f, s, a: self._wm.heads["reward"](
self._wm.dynamics.get_feat(s)
).mode()