Merge branch 'main' into memmaze
This commit is contained in:
10
dreamer.py
10
dreamer.py
@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
if config.compile and os.name != 'nt': # compilation is not supported on windows
|
||||
if (
|
||||
config.compile and os.name != "nt"
|
||||
): # compilation is not supported on windows
|
||||
self._wm = torch.compile(self._wm)
|
||||
self._task_behavior = torch.compile(self._task_behavior)
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
|
||||
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
|
||||
post, context, mets = self._wm._train(data)
|
||||
metrics.update(mets)
|
||||
start = post
|
||||
# start['deter'] (16, 64, 512)
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](
|
||||
self._wm.dynamics.get_feat(s)
|
||||
).mode()
|
||||
@@ -221,6 +222,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
||||
from envs.memorymaze import MemoryMaze
|
||||
env = MemoryMaze(env)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "crafter":
|
||||
import envs.crafter as crafter
|
||||
|
||||
env = crafter.Crafter(task, config.size)
|
||||
env = wrappers.OneHotAction(env)
|
||||
else:
|
||||
raise NotImplementedError(suite)
|
||||
env = wrappers.TimeLimit(env, config.time_limit)
|
||||
|
||||
Reference in New Issue
Block a user