sampling from the replay buffer across episodes
This commit is contained in:
38
tools.py
38
tools.py
@@ -199,22 +199,38 @@ def from_generator(generator, batch_size):
|
||||
yield data
|
||||
|
||||
|
||||
def sample_episodes(episodes, length=None, balance=False, seed=0):
|
||||
def sample_episodes(episodes, length, seed=0):
|
||||
random = np.random.RandomState(seed)
|
||||
while True:
|
||||
episode = random.choice(list(episodes.values()))
|
||||
if length:
|
||||
size = 0
|
||||
ret = None
|
||||
p = np.array(
|
||||
[len(next(iter(episode.values()))) for episode in episodes.values()]
|
||||
)
|
||||
p = p / np.sum(p)
|
||||
while size < length:
|
||||
episode = random.choice(list(episodes.values()), p=p)
|
||||
total = len(next(iter(episode.values())))
|
||||
available = total - length
|
||||
if available < 1:
|
||||
# print(f"Skipped short episode of length {available}.")
|
||||
# make sure at least one transition included
|
||||
if total < 2:
|
||||
continue
|
||||
if balance:
|
||||
index = min(random.randint(0, total), available)
|
||||
if not ret:
|
||||
index = int(random.randint(0, total - 1))
|
||||
ret = {
|
||||
k: v[index : min(index + length, total)] for k, v in episode.items()
|
||||
}
|
||||
else:
|
||||
index = int(random.randint(0, available + 1))
|
||||
episode = {k: v[index : index + length] for k, v in episode.items()}
|
||||
yield episode
|
||||
# 'is_first' comes after 'is_last'
|
||||
index = 0
|
||||
possible = length - size
|
||||
ret = {
|
||||
k: np.append(
|
||||
ret[k], v[index : min(index + possible, total)], axis=0
|
||||
)
|
||||
for k, v in episode.items()
|
||||
}
|
||||
size = len(next(iter(ret.values())))
|
||||
yield ret
|
||||
|
||||
|
||||
def load_episodes(directory, limit=None, reverse=True):
|
||||
|
||||
Reference in New Issue
Block a user