clean code

This commit is contained in:
NM512
2024-09-24 00:16:12 +09:00
parent 4e50f302cd
commit 59939222d1
2 changed files with 7 additions and 7 deletions

View File

@@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
def sample(self, sample_shape=(), seed=None):
if seed is not None:
raise ValueError("need to check")
sample = super().sample(sample_shape)
sample = super().sample(sample_shape).detach()
probs = super().probs
while len(probs.shape) < len(sample.shape):
probs = probs[None]