clean code
This commit is contained in:
2
tools.py
2
tools.py
@@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
||||
def sample(self, sample_shape=(), seed=None):
|
||||
if seed is not None:
|
||||
raise ValueError("need to check")
|
||||
sample = super().sample(sample_shape)
|
||||
sample = super().sample(sample_shape).detach()
|
||||
probs = super().probs
|
||||
while len(probs.shape) < len(sample.shape):
|
||||
probs = probs[None]
|
||||
|
||||
Reference in New Issue
Block a user