diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index a8d9f25..451dac1 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -77,4 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) + # Ensure that eval_episodes is divisible by num_envs and is at least 1*num_envs + cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs) + cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs) + return cfg_to_dataclass(cfg) diff --git a/tdmpc2/envs/wrappers/vectorized.py b/tdmpc2/envs/wrappers/vectorized.py index 4dadc9b..560c8ce 100644 --- a/tdmpc2/envs/wrappers/vectorized.py +++ b/tdmpc2/envs/wrappers/vectorized.py @@ -1,6 +1,6 @@ from copy import deepcopy -from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.vector import AsyncVectorEnv import numpy as np import torch @@ -21,8 +21,7 @@ class Vectorized(): return env_fn(_cfg) print(f'Creating {cfg.num_envs} environments...') - # self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)]) - self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)]) + self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)]) env = make() self.observation_space = env.observation_space self.action_space = env.action_space diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 3cee3f9..71773ca 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -124,6 +124,8 @@ class OnlineTrainer(Trainer): for _ in range(num_updates): _train_metrics = self.agent.update(self.buffer) train_metrics.update(_train_metrics) + if self._step == self.cfg.seed_steps: + print('Pretraining complete.') self._step += self.cfg.num_envs