separated cache management of episode from env

This commit is contained in:
NM512
2023-07-22 19:22:41 +09:00
parent 88514ec022
commit 9ca5082da3
3 changed files with 194 additions and 167 deletions

167
tools.py
View File

@@ -1,6 +1,7 @@
import datetime
import collections
import io
import os
import json
import pathlib
import re
@@ -121,7 +122,7 @@ class Logger:
self._writer.add_video(name, value, step, 16)
def simulate(agent, envs, steps=0, episodes=0, state=None):
def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
# Initialize or unpack simulation state.
if state is None:
step, episode = 0, 0
@@ -137,6 +138,14 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
if done.any():
indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
for i in indices:
t = results[i].copy()
t = {k: convert(v) for k, v in t.items()}
# action will be added to transition in add_to_cache
t["reward"] = 0.0
t["discount"] = 1.0
# initial state should be added to cache
add_to_cache(cache, envs[i].id, t)
for index, result in zip(indices, results):
obs[index] = result
reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
@@ -161,26 +170,165 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
length += 1
step += len(envs)
length *= 1 - done
# Add to cache
for a, result, env in zip(action, results, envs):
o, r, d, info = result
o = {k: convert(v) for k, v in o.items()}
transition = o.copy()
if isinstance(a, dict):
transition.update(a)
else:
transition["action"] = a
transition["reward"] = r
transition["discount"] = info.get("discount", np.array(1 - float(d)))
add_to_cache(cache, env.id, transition)
if done.any():
indices = [index for index, d in enumerate(done) if d]
# logging for done episode
for i in indices:
save_episodes(directory, {envs[i].id: cache[envs[i].id]})
length = len(cache[envs[i].id]["reward"]) - 1
score = float(np.array(cache[envs[i].id]["reward"]).sum())
video = cache[envs[i].id]["image"]
if not is_eval:
step_in_dataset = erase_over_episodes(cache, limit)
logger.scalar(f"dataset_size", step_in_dataset)
logger.scalar(f"train_return", score)
logger.scalar(f"train_length", length)
logger.scalar(f"train_episodes", len(cache))
logger.write(step=logger.step)
else:
if not 'eval_lengths' in locals():
eval_lengths = []
eval_scores = []
eval_done = False
# start counting scores for evaluation
eval_scores.append(score)
eval_lengths.append(length)
score = sum(eval_scores) / len(eval_scores)
length = sum(eval_lengths) / len(eval_lengths)
logger.video(f"eval_policy", np.array(video)[None])
if len(eval_scores) >= episodes and not eval_done:
logger.scalar(f"eval_return", score)
logger.scalar(f"eval_length", length)
logger.scalar(f"eval_episodes", len(eval_scores))
logger.write(step=logger.step)
eval_done = True
if is_eval:
# keep only last item for saving memory. this cache is used for video_pred later
while len(cache) > 1:
# FIFO
cache.popitem(last=False)
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
class CollectDataset:
def __init__(
self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
):
self._env = env
self._callbacks = callbacks or ()
self._precision = precision
self._episode = None
self._cache = dict(train=train_eps, eval=eval_eps)[mode]
self._temp_name = str(uuid.uuid4())
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs = {k: self._convert(v) for k, v in obs.items()}
transition = obs.copy()
if isinstance(action, dict):
transition.update(action)
else:
transition["action"] = action
transition["reward"] = reward
transition["discount"] = info.get("discount", np.array(1 - float(done)))
self._episode.append(transition)
self.add_to_cache(transition)
if done:
# detele transitions before whole episode is stored
del self._cache[self._temp_name]
self._temp_name = str(uuid.uuid4())
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = 0 * value
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info["episode"] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
transition = obs.copy()
# Missing keys will be filled with a zeroed out version of the first
# transition, because we do not know what action information the agent will
# pass yet.
transition["reward"] = 0.0
transition["discount"] = 1.0
self._episode = [transition]
self.add_to_cache(transition)
return obs
def add_to_cache(cache, id, transition):
if id not in cache:
cache[id] = dict()
for key, val in transition.items():
cache[id][key] = [convert(val)]
else:
for key, val in transition.items():
if key not in cache[id]:
# fill missing data(action, etc.) at second time
cache[id][key] = [convert(0 * val)]
cache[id][key].append(convert(val))
else:
cache[id][key].append(convert(val))
def erase_over_episodes(cache, dataset_size):
step_in_dataset = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if (
not dataset_size
or step_in_dataset + (len(ep["reward"]) - 1) <= dataset_size
):
step_in_dataset += len(ep["reward"]) - 1
else:
del cache[key]
return step_in_dataset
def convert(value, precision=32):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
elif np.issubdtype(value.dtype, bool):
dtype = bool
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
def save_episodes(directory, episodes):
directory = pathlib.Path(directory).expanduser()
directory.mkdir(parents=True, exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
filenames = []
for episode in episodes:
identifier = str(uuid.uuid4().hex)
for filename, episode in episodes.items():
length = len(episode["reward"])
filename = directory / f"{timestamp}-{identifier}-{length}.npz"
filename = directory / f"{filename}-{length}.npz"
with io.BytesIO() as f1:
np.savez_compressed(f1, **episode)
f1.seek(0)
with filename.open("wb") as f2:
f2.write(f1.read())
filenames.append(filename)
return filenames
return True
def from_generator(generator, batch_size):
@@ -244,7 +392,8 @@ def load_episodes(directory, limit=None, reverse=True):
except Exception as e:
print(f"Could not load episode: {e}")
continue
episodes[str(filename)] = episode
# extract only filename without extension
episodes[str(os.path.splitext(os.path.basename(filename))[0])] = episode
total += len(episode["reward"]) - 1
if limit and total >= limit:
break