step-based counting
This commit is contained in:
@@ -337,6 +337,7 @@ def main(config):
|
||||
acts = train_envs[0].action_space
|
||||
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
||||
|
||||
+ state = None
|
||||
if not config.offline_traindir:
|
||||
prefill = max(0, config.prefill - count_steps(config.traindir))
|
||||
print(f"Prefill dataset ({prefill} steps).")
|
||||
@@ -358,7 +359,7 @@ def main(config):
|
||||
logprob = random_actor.log_prob(action)
|
||||
return {"action": action, "logprob": logprob}, None
|
||||
|
||||
- tools.simulate(random_agent, train_envs, prefill)
|
||||
+ state = tools.simulate(random_agent, train_envs, prefill)
|
||||
logger.step = config.action_repeat * count_steps(config.traindir)
|
||||
|
||||
print("Simulate agent.")
|
||||
@@ -376,7 +377,6 @@ def main(config):
|
||||
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
||||
agent._should_pretrain._once = False
|
||||
|
||||
- state = None
|
||||
while agent._step < config.steps:
|
||||
logger.write()
|
||||
print("Start evaluation.")
|
||||
|
||||
Reference in New Issue
Block a user