minor updates to vectorization

This commit is contained in:
Nicklas Hansen
2025-05-21 16:06:45 -07:00
parent a586d8f393
commit 97c1447199
3 changed files with 8 additions and 3 deletions

View File

@@ -77,4 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Ensure that eval_episodes is divisible by num_envs and is at least 1*num_envs
cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs)
cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs)
return cfg_to_dataclass(cfg)

View File

@@ -1,6 +1,6 @@
from copy import deepcopy
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from gymnasium.vector import AsyncVectorEnv
import numpy as np
import torch
@@ -21,8 +21,7 @@ class Vectorized():
return env_fn(_cfg)
print(f'Creating {cfg.num_envs} environments...')
# self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)])
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
env = make()
self.observation_space = env.observation_space
self.action_space = env.action_space

View File

@@ -124,6 +124,8 @@ class OnlineTrainer(Trainer):
for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
if self._step == self.cfg.seed_steps:
print('Pretraining complete.')
self._step += self.cfg.num_envs