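Minor updates to vectorization: round eval_episodes down to a multiple of num_envs (minimum num_envs), switch from SyncVectorEnv to AsyncVectorEnv, and print a message when pretraining completes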
This commit is contained in:
@@ -77,4 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
cfg.task_dim = 0
|
cfg.task_dim = 0
|
||||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||||
|
|
||||||
|
# Ensure eval_episodes is at least num_envs, then round it down to a multiple of num_envs
|
||||||
|
cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs)
|
||||||
|
cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs)
|
||||||
|
|
||||||
return cfg_to_dataclass(cfg)
|
return cfg_to_dataclass(cfg)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
from gymnasium.vector import AsyncVectorEnv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,8 +21,7 @@ class Vectorized():
|
|||||||
return env_fn(_cfg)
|
return env_fn(_cfg)
|
||||||
|
|
||||||
print(f'Creating {cfg.num_envs} environments...')
|
print(f'Creating {cfg.num_envs} environments...')
|
||||||
# self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
|
||||||
self.env = SyncVectorEnv([make for _ in range(cfg.num_envs)])
|
|
||||||
env = make()
|
env = make()
|
||||||
self.observation_space = env.observation_space
|
self.observation_space = env.observation_space
|
||||||
self.action_space = env.action_space
|
self.action_space = env.action_space
|
||||||
|
|||||||
@@ -124,6 +124,8 @@ class OnlineTrainer(Trainer):
|
|||||||
for _ in range(num_updates):
|
for _ in range(num_updates):
|
||||||
_train_metrics = self.agent.update(self.buffer)
|
_train_metrics = self.agent.update(self.buffer)
|
||||||
train_metrics.update(_train_metrics)
|
train_metrics.update(_train_metrics)
|
||||||
|
if self._step == self.cfg.seed_steps:
|
||||||
|
print('Pretraining complete.')
|
||||||
|
|
||||||
self._step += self.cfg.num_envs
|
self._step += self.cfg.num_envs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user