added the option for a deterministic run
This commit is contained in:
17
dreamer.py
17
dreamer.py
@@ -186,7 +186,9 @@ def make_env(config, mode):
|
||||
if suite == "dmc":
|
||||
import envs.dmc as dmc
|
||||
|
||||
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
|
||||
env = dmc.DeepMindControl(
|
||||
task, config.action_repeat, config.size, seed=config.seed
|
||||
)
|
||||
env = wrappers.NormalizeActions(env)
|
||||
elif suite == "atari":
|
||||
import envs.atari as atari
|
||||
@@ -201,24 +203,28 @@ def make_env(config, mode):
|
||||
sticky=config.stickey,
|
||||
actions=config.actions,
|
||||
resize=config.resize,
|
||||
seed=config.seed,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "dmlab":
|
||||
import envs.dmlab as dmlab
|
||||
|
||||
env = dmlab.DeepMindLabyrinth(
|
||||
task, mode if "train" in mode else "test", config.action_repeat
|
||||
task,
|
||||
mode if "train" in mode else "test",
|
||||
config.action_repeat,
|
||||
seed=config.seed,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "MemoryMaze":
|
||||
from envs.memorymaze import MemoryMaze
|
||||
|
||||
env = MemoryMaze(task)
|
||||
env = MemoryMaze(task, seed=config.seed)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "crafter":
|
||||
import envs.crafter as crafter
|
||||
|
||||
env = crafter.Crafter(task, config.size)
|
||||
env = crafter.Crafter(task, config.size, seed=config.seed)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "minecraft":
|
||||
import envs.minecraft as minecraft
|
||||
@@ -236,6 +242,9 @@ def make_env(config, mode):
|
||||
|
||||
|
||||
def main(config):
|
||||
tools.set_seed_everywhere(config.seed)
|
||||
if config.deterministic_run:
|
||||
tools.enable_deterministic_run()
|
||||
logdir = pathlib.Path(config.logdir).expanduser()
|
||||
config.traindir = config.traindir or logdir / "train_eps"
|
||||
config.evaldir = config.evaldir or logdir / "eval_eps"
|
||||
|
||||
Reference in New Issue
Block a user