From 1328ff10881891dd3f5ad1984f15a4897a8d7a83 Mon Sep 17 00:00:00 2001
From: NM512
Date: Sat, 29 Apr 2023 07:43:02 +0900
Subject: [PATCH] sampling from the replay buffer across episodes

tools.sample_episodes now always yields chunks built from one or more
episodes instead of a slice of a single episode:

* an episode is drawn with probability proportional to its length;
* the first piece of a chunk starts at a random index of that episode,
  every following piece starts at index 0 of a newly drawn episode and
  is appended until `length` steps have been collected;
* episodes with fewer than 2 steps are skipped.

The `balance` argument of sample_episodes and the `oversample_ends`
config entry that fed it are removed, `length` becomes a required
argument, and make_dataset is updated accordingly. The `compile`
default in configs.yaml is switched to True.

---
Review note (not part of the commit message): the inner `while` loop in
sample_episodes cannot make progress if every stored episode has fewer
than 2 steps, because such episodes are skipped without changing `size`.

 configs.yaml |  3 +--
 dreamer.py   |  4 +---
 tools.py     | 38 +++++++++++++++++++++++++++-----------
 3 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/configs.yaml b/configs.yaml
index 30943cb..8168060 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -12,7 +12,7 @@ defaults:
   log_every: 1e4
   reset_every: 0
   device: 'cuda:0'
-  compile: False
+  compile: True
   precision: 16
   debug: False
   expl_gifs: False
@@ -78,7 +78,6 @@ defaults:
   value_grad_clip: 100
   actor_grad_clip: 100
   dataset_size: 1000000
-  oversample_ends: True
   slow_value_target: True
   slow_target_update: 1
   slow_target_fraction: 0.02
diff --git a/dreamer.py b/dreamer.py
index a84105a..a41854d 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -174,9 +174,7 @@ def count_steps(folder):
 
 
 def make_dataset(episodes, config):
-    generator = tools.sample_episodes(
-        episodes, config.batch_length, config.oversample_ends
-    )
+    generator = tools.sample_episodes(episodes, config.batch_length)
     dataset = tools.from_generator(generator, config.batch_size)
     return dataset
 
diff --git a/tools.py b/tools.py
index de7ccc4..3b8a912 100644
--- a/tools.py
+++ b/tools.py
@@ -199,22 +199,38 @@ def from_generator(generator, batch_size):
         yield data
 
 
-def sample_episodes(episodes, length=None, balance=False, seed=0):
+def sample_episodes(episodes, length, seed=0):
     random = np.random.RandomState(seed)
     while True:
-        episode = random.choice(list(episodes.values()))
-        if length:
+        size = 0
+        ret = None
+        p = np.array(
+            [len(next(iter(episode.values()))) for episode in episodes.values()]
+        )
+        p = p / np.sum(p)
+        while size < length:
+            episode = random.choice(list(episodes.values()), p=p)
             total = len(next(iter(episode.values())))
-            available = total - length
-            if available < 1:
-                # print(f"Skipped short episode of length {available}.")
+            # make sure at least one transition included
+            if total < 2:
                 continue
-            if balance:
-                index = min(random.randint(0, total), available)
+            if not ret:
+                index = int(random.randint(0, total - 1))
+                ret = {
+                    k: v[index : min(index + length, total)] for k, v in episode.items()
+                }
             else:
-                index = int(random.randint(0, available + 1))
-            episode = {k: v[index : index + length] for k, v in episode.items()}
-        yield episode
+                # 'is_first' comes after 'is_last'
+                index = 0
+                possible = length - size
+                ret = {
+                    k: np.append(
+                        ret[k], v[index : min(index + possible, total)], axis=0
+                    )
+                    for k, v in episode.items()
+                }
+            size = len(next(iter(ret.values())))
+        yield ret
 
 
 def load_episodes(directory, limit=None, reverse=True):