applied formatter
This commit is contained in:
19
tools.py
19
tools.py
@@ -122,7 +122,18 @@ class Logger:
|
||||
self._writer.add_video(name, value, step, 16)
|
||||
|
||||
|
||||
def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
|
||||
def simulate(
|
||||
agent,
|
||||
envs,
|
||||
cache,
|
||||
directory,
|
||||
logger,
|
||||
is_eval=False,
|
||||
limit=None,
|
||||
steps=0,
|
||||
episodes=0,
|
||||
state=None,
|
||||
):
|
||||
# initialize or unpack simulation state
|
||||
if state is None:
|
||||
step, episode = 0, 0
|
||||
@@ -200,7 +211,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
|
||||
logger.scalar(f"train_episodes", len(cache))
|
||||
logger.write(step=logger.step)
|
||||
else:
|
||||
if not 'eval_lengths' in locals():
|
||||
if not "eval_lengths" in locals():
|
||||
eval_lengths = []
|
||||
eval_scores = []
|
||||
eval_done = False
|
||||
@@ -278,6 +289,7 @@ class CollectDataset:
|
||||
self.add_to_cache(transition)
|
||||
return obs
|
||||
|
||||
|
||||
def add_to_cache(cache, id, transition):
|
||||
if id not in cache:
|
||||
cache[id] = dict()
|
||||
@@ -292,6 +304,7 @@ def add_to_cache(cache, id, transition):
|
||||
else:
|
||||
cache[id][key].append(convert(val))
|
||||
|
||||
|
||||
def erase_over_episodes(cache, dataset_size):
|
||||
step_in_dataset = 0
|
||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||
@@ -304,6 +317,7 @@ def erase_over_episodes(cache, dataset_size):
|
||||
del cache[key]
|
||||
return step_in_dataset
|
||||
|
||||
|
||||
def convert(value, precision=32):
|
||||
value = np.array(value)
|
||||
if np.issubdtype(value.dtype, np.floating):
|
||||
@@ -318,6 +332,7 @@ def convert(value, precision=32):
|
||||
raise NotImplementedError(value.dtype)
|
||||
return value.astype(dtype)
|
||||
|
||||
|
||||
def save_episodes(directory, episodes):
|
||||
directory = pathlib.Path(directory).expanduser()
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user