applied formatter

This commit is contained in:
NM512
2023-07-23 22:02:06 +09:00
parent afa5ab988d
commit 12ed21e06d
10 changed files with 506 additions and 440 deletions

View File

@@ -122,7 +122,18 @@ class Logger:
self._writer.add_video(name, value, step, 16)
def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
def simulate(
agent,
envs,
cache,
directory,
logger,
is_eval=False,
limit=None,
steps=0,
episodes=0,
state=None,
):
# initialize or unpack simulation state
if state is None:
step, episode = 0, 0
@@ -200,7 +211,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
logger.scalar(f"train_episodes", len(cache))
logger.write(step=logger.step)
else:
if not 'eval_lengths' in locals():
if not "eval_lengths" in locals():
eval_lengths = []
eval_scores = []
eval_done = False
@@ -278,6 +289,7 @@ class CollectDataset:
self.add_to_cache(transition)
return obs
def add_to_cache(cache, id, transition):
if id not in cache:
cache[id] = dict()
@@ -292,6 +304,7 @@ def add_to_cache(cache, id, transition):
else:
cache[id][key].append(convert(val))
def erase_over_episodes(cache, dataset_size):
step_in_dataset = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
@@ -304,6 +317,7 @@ def erase_over_episodes(cache, dataset_size):
del cache[key]
return step_in_dataset
def convert(value, precision=32):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
@@ -318,6 +332,7 @@ def convert(value, precision=32):
raise NotImplementedError(value.dtype)
return value.astype(dtype)
def save_episodes(directory, episodes):
directory = pathlib.Path(directory).expanduser()
directory.mkdir(parents=True, exist_ok=True)