minor updates to vectorization
This commit is contained in:
@@ -77,4 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||
cfg.task_dim = 0
|
||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||
|
||||
# Ensure that eval_episodes is divisible by num_envs and is at least 1*num_envs
|
||||
cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs)
|
||||
cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs)
|
||||
|
||||
return cfg_to_dataclass(cfg)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -21,8 +21,7 @@ class Vectorized():
|
||||
return env_fn(_cfg)
|
||||
|
||||
print(f'Creating {cfg.num_envs} environments...')
|
||||
# self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||
self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||
env = make()
|
||||
self.observation_space = env.observation_space
|
||||
self.action_space = env.action_space
|
||||
|
||||
@@ -124,6 +124,8 @@ class OnlineTrainer(Trainer):
|
||||
for _ in range(num_updates):
|
||||
_train_metrics = self.agent.update(self.buffer)
|
||||
train_metrics.update(_train_metrics)
|
||||
if self._step == self.cfg.seed_steps:
|
||||
print('Pretraining complete.')
|
||||
|
||||
self._step += self.cfg.num_envs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user