separated cache management of episode from env

This commit is contained in:
NM512
2023-07-22 19:22:41 +09:00
parent 88514ec022
commit 9ca5082da3
3 changed files with 194 additions and 167 deletions

View File

@@ -1,89 +1,9 @@
import datetime
import gym
import numpy as np
import uuid
class CollectDataset:
    """Environment wrapper that records transitions into an episode cache.

    Every transition produced by ``reset`` and ``step`` is appended both to an
    in-memory episode list and to ``cache[temp_name]``, where ``temp_name`` is
    a random UUID string identifying the episode in progress. When the wrapped
    env reports ``done``, the temporary cache entry is removed, the full
    episode is stacked into arrays, placed in ``info["episode"]`` and handed
    to each callback.
    """

    def __init__(
        self, env, mode, train_eps, eval_eps=None, callbacks=None, precision=32
    ):
        """Wrap ``env`` and select the cache for ``mode``.

        Args:
            env: Wrapped environment; must provide ``reset()`` and
                ``step(action)``.
            mode: ``"train"`` or ``"eval"``; selects which cache is written.
            train_eps: Dict-like cache used when ``mode == "train"``.
            eval_eps: Dict-like cache used when ``mode == "eval"``. Defaults
                to a fresh empty dict for each instance.
            callbacks: Optional iterable of callables invoked with the
                finished episode dict.
            precision: Bit width (16, 32 or 64) used by ``_convert`` for
                floating point and signed integer values.

        Raises:
            KeyError: If ``mode`` is neither ``"train"`` nor ``"eval"``.
        """
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        self._episode = None
        # A `dict()` default argument is created once and would be shared by
        # every instance; build a new one per instance instead.
        if eval_eps is None:
            eval_eps = {}
        self._cache = {"train": train_eps, "eval": eval_eps}[mode]
        self._temp_name = str(uuid.uuid4())

    def __getattr__(self, name):
        """Delegate attributes not found on the wrapper to the wrapped env."""
        # `__getattr__` only runs when normal lookup fails. If `_env` itself is
        # missing (e.g. instance created via `__new__` during copy/unpickle),
        # delegating would recurse forever, so fail with AttributeError.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped env and record the resulting transition.

        Args:
            action: Either a dict of entries merged into the transition, or a
                single value stored under the ``"action"`` key.

        Returns:
            The ``(obs, reward, done, info)`` tuple of the wrapped env, with
            the values of ``obs`` converted by ``_convert``. When ``done`` is
            true, ``info["episode"]`` holds the stacked episode.
        """
        obs, reward, done, info = self._env.step(action)
        obs = {key: self._convert(value) for key, value in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition["action"] = action
        transition["reward"] = reward
        transition["discount"] = info.get("discount", np.array(1 - float(done)))
        self._episode.append(transition)
        self.add_to_cache(transition)
        if done:
            # Delete the partial transitions before the whole episode is
            # stored, and start a new temporary name for the next episode.
            del self._cache[self._temp_name]
            self._temp_name = str(uuid.uuid4())
            # The reset transition lacks the keys that only appear once an
            # action was taken; backfill them with zeros of matching shape.
            first = self._episode[0]
            for key, value in self._episode[1].items():
                if key not in first:
                    first[key] = 0 * value
            episode = {
                key: self._convert([t[key] for t in self._episode]) for key in first
            }
            info["episode"] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env and start recording a new episode.

        Returns:
            The observation returned by the wrapped env, unmodified.
        """
        obs = self._env.reset()
        transition = obs.copy()
        # Missing keys will be filled with a zeroed out version of the first
        # transition, because we do not know what action information the agent
        # will pass yet.
        transition["reward"] = 0.0
        transition["discount"] = 1.0
        self._episode = [transition]
        self.add_to_cache(transition)
        return obs

    def add_to_cache(self, transition):
        """Append ``transition`` to the cache entry of the running episode.

        Args:
            transition: Dict mapping keys to values accepted by ``_convert``.
        """
        if self._temp_name not in self._cache:
            self._cache[self._temp_name] = {
                key: [self._convert(val)] for key, val in transition.items()
            }
            return
        entry = self._cache[self._temp_name]
        for key, val in transition.items():
            if key not in entry:
                # Fill missing data (action) for the earlier transition.
                entry[key] = [self._convert(0 * val)]
            entry[key].append(self._convert(val))

    def _convert(self, value):
        """Return ``value`` as a NumPy array cast to the configured precision.

        Floating point and signed integer inputs are cast to the type of
        ``self._precision`` bits; ``uint8`` and ``bool`` inputs keep their
        dtype.

        Raises:
            KeyError: If ``self._precision`` is not 16, 32 or 64 and the
                value is floating point or a signed integer.
            NotImplementedError: For any other dtype.
        """
        value = np.array(value)
        if np.issubdtype(value.dtype, np.floating):
            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
        elif np.issubdtype(value.dtype, np.signedinteger):
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        elif np.issubdtype(value.dtype, np.uint8):
            dtype = np.uint8
        elif np.issubdtype(value.dtype, bool):
            dtype = bool
        else:
            raise NotImplementedError(value.dtype)
        return value.astype(dtype)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
@@ -208,3 +128,17 @@ class SelectAction:
def step(self, action):
    """Step the wrapped env with the entry of ``action`` stored under ``self._key``."""
    # Only the selected entry is forwarded; the rest of `action` is dropped.
    selected = action[self._key]
    return self._env.step(selected)
class UUID:
    """Environment wrapper that exposes a unique ``id`` string.

    The id has the form ``<YYYYmmddTHHMMSS>-<32 hex chars>`` and is
    regenerated on construction and on every ``reset``.
    """

    def __init__(self, env):
        """Wrap ``env`` and assign an initial id."""
        self._env = env
        self.id = self._new_id()

    def __getattr__(self, name):
        """Delegate attributes not found on the wrapper to the wrapped env."""
        # If `_env` itself is missing (e.g. instance created via `__new__`
        # during copy/unpickle), delegating would recurse forever.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def reset(self):
        """Assign a new id, then reset the wrapped env and return its result."""
        self.id = self._new_id()
        return self._env.reset()

    @staticmethod
    def _new_id():
        """Build an id from the current local time and a random UUID4."""
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        return f"{timestamp}-{uuid.uuid4().hex}"