modified loss calculation

This commit is contained in:
NM512
2024-01-05 10:44:04 +09:00
parent e0487f8206
commit 78e86703f4
3 changed files with 21 additions and 25 deletions

View File

@@ -327,8 +327,9 @@ class RSSM(nn.Module):
dist(sg(post)) if self._discrete else dist(sg(post))._dist,
dist(prior) if self._discrete else dist(prior)._dist,
)
rep_loss = torch.mean(torch.clip(rep_loss, min=free))
dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
# this is implemented using maximum at the original repo as the gradients are not backpropagated for the out of limits.
rep_loss = torch.clip(rep_loss, min=free)
dyn_loss = torch.clip(dyn_loss, min=free)
loss = dyn_scale * dyn_loss + rep_scale * rep_loss
return loss, value, dyn_loss, rep_loss