added the option for a deterministic run
This commit is contained in:
27
tools.py
27
tools.py
@@ -6,7 +6,7 @@ import json
|
||||
import pathlib
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
|
||||
|
||||
|
||||
def sample_episodes(episodes, length, seed=0):
|
||||
random = np.random.RandomState(seed)
|
||||
np_random = np.random.RandomState(seed)
|
||||
while True:
|
||||
size = 0
|
||||
ret = None
|
||||
@@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
|
||||
)
|
||||
p = p / np.sum(p)
|
||||
while size < length:
|
||||
episode = random.choice(list(episodes.values()), p=p)
|
||||
episode = np_random.choice(list(episodes.values()), p=p)
|
||||
total = len(next(iter(episode.values())))
|
||||
# make sure at least one transition included
|
||||
if total < 2:
|
||||
continue
|
||||
if not ret:
|
||||
index = int(random.randint(0, total - 1))
|
||||
index = int(np_random.randint(0, total - 1))
|
||||
ret = {
|
||||
k: v[index : min(index + length, total)] for k, v in episode.items()
|
||||
k: v[index : min(index + length, total)]
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
}
|
||||
if "is_first" in ret:
|
||||
ret["is_first"][0] = True
|
||||
@@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
|
||||
ret[k], v[index : min(index + possible, total)], axis=0
|
||||
)
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
}
|
||||
if "is_first" in ret:
|
||||
ret["is_first"][size] = True
|
||||
@@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
|
||||
if prefix:
|
||||
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
|
||||
return metrics
|
||||
|
||||
|
||||
def set_seed_everywhere(seed):
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def enable_deterministic_run():
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
Reference in New Issue
Block a user