replaced all tf function to torch

This commit is contained in:
NM512
2023-04-03 08:06:34 +09:00
parent 8bd69bfcd4
commit 57ac1c11d3
3 changed files with 56 additions and 45 deletions

View File

@@ -62,7 +62,7 @@ class Dreamer(nn.Module):
greedy=lambda: self._task_behavior,
random=lambda: expl.Random(config),
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
)[config.expl_behavior]()
)[config.expl_behavior]().to(self._config.device)
def __call__(self, obs, reset, state=None, reward=None, training=True):
step = self._step