cleaned up envs

This commit is contained in:
NM512
2023-04-15 23:16:43 +09:00
parent fba87a33e0
commit 1e070a3daf
8 changed files with 507 additions and 431 deletions

View File

@@ -16,7 +16,7 @@ sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import wrappers
import envs.wrappers as wrappers
import torch
from torch import nn
@@ -189,21 +189,29 @@ def make_dataset(episodes, config):
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split("_", 1)
if suite == "dmc":
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
import envs.dmc as dmc
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == "atari":
env = wrappers.Atari(
import envs.atari as atari
env = atari.Atari(
task,
config.action_repeat,
config.size,
grayscale=config.grayscale,
life_done=False and ("train" in mode),
sticky_actions=False,
all_actions=False,
gray=config.grayscale,
noops=config.noops,
lives=config.lives,
sticky=config.stickey,
actions=config.actions,
resize=config.resize,
)
env = wrappers.OneHotAction(env)
elif suite == "dmlab":
env = wrappers.DeepMindLabyrinth(
import envs.dmlab as dmlab
env = dmlab.DeepMindLabyrinth(
task, mode if "train" in mode else "test", config.action_repeat
)
env = wrappers.OneHotAction(env)
@@ -326,7 +334,7 @@ def main(config):
print(f"Prefill dataset ({prefill} steps).")
if hasattr(acts, "discrete"):
random_actor = tools.OneHotDist(
torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1)
torch.zeros(config.num_actions).repeat(config.envs, 1)
)
else:
random_actor = torchd.independent.Independent(