put running episode into replay buffer
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import uuid
|
||||
|
||||
|
||||
class CollectDataset:
|
||||
def __init__(self, env, callbacks=None, precision=32):
|
||||
def __init__(
|
||||
self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
|
||||
):
|
||||
self._env = env
|
||||
self._callbacks = callbacks or ()
|
||||
self._precision = precision
|
||||
self._episode = None
|
||||
self._cache = dict(train=train_eps, eval=eval_eps)[mode]
|
||||
self._temp_name = str(uuid.uuid4())
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
@@ -23,7 +28,11 @@ class CollectDataset:
|
||||
transition["reward"] = reward
|
||||
transition["discount"] = info.get("discount", np.array(1 - float(done)))
|
||||
self._episode.append(transition)
|
||||
self.add_to_cache(transition)
|
||||
if done:
|
||||
# delete transitions before whole episode is stored
|
||||
del self._cache[self._temp_name]
|
||||
self._temp_name = str(uuid.uuid4())
|
||||
for key, value in self._episode[1].items():
|
||||
if key not in self._episode[0]:
|
||||
self._episode[0][key] = 0 * value
|
||||
@@ -43,8 +52,23 @@ class CollectDataset:
|
||||
transition["reward"] = 0.0
|
||||
transition["discount"] = 1.0
|
||||
self._episode = [transition]
|
||||
self.add_to_cache(transition)
|
||||
return obs
|
||||
|
||||
def add_to_cache(self, transition):
    """Append one transition to the running episode held in the cache.

    The running episode is stored in ``self._cache`` under
    ``self._temp_name`` as a dict that maps each transition key to a list
    of values passed through ``self._convert``.

    Args:
        transition: Mapping of key -> value for a single step.
    """
    name = self._temp_name

    # First transition of the episode: start a new per-key list for it.
    if name not in self._cache:
        self._cache[name] = dict()
        started = self._cache[name]
        for field, item in transition.items():
            started[field] = [self._convert(item)]
        return

    running = self._cache[name]
    for field, item in transition.items():
        if field not in running:
            # Key absent from earlier steps (e.g. action): seed its list
            # with one zero-valued placeholder before the real value.
            running[field] = [self._convert(0 * item)]
        running[field].append(self._convert(item))
|
||||
|
||||
def _convert(self, value):
|
||||
value = np.array(value)
|
||||
if np.issubdtype(value.dtype, np.floating):
|
||||
|
||||
Reference in New Issue
Block a user