Merge branch 'main' into memmaze

This commit is contained in:
NM512
2023-06-18 16:27:05 +09:00
committed by GitHub
9 changed files with 141 additions and 23 deletions

View File

@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
if config.compile and os.name != 'nt': # compilation is not supported on windows
if (
config.compile and os.name != "nt"
): # compilation is not supported on windows
self._wm = torch.compile(self._wm)
self._task_behavior = torch.compile(self._task_behavior)
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
post, context, mets = self._wm._train(data)
metrics.update(mets)
start = post
# start['deter'] (16, 64, 512)
reward = lambda f, s, a: self._wm.heads["reward"](
self._wm.dynamics.get_feat(s)
).mode()
@@ -221,6 +222,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
from envs.memorymaze import MemoryMaze
env = MemoryMaze(env)
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)