changed the discount head to predict terminal

This commit is contained in:
NM512
2023-04-22 09:34:23 +09:00
parent 16151efb3c
commit 628b856c63
4 changed files with 50 additions and 50 deletions

View File

@@ -273,28 +273,24 @@ class RSSM(nn.Module):
         std = std + self._min_std
         return {"mean": mean, "std": std}
-    def kl_loss(self, post, prior, forward, free, lscale, rscale):
+    def kl_loss(self, post, prior, free, dyn_scale, rep_scale):
         kld = torchd.kl.kl_divergence
         dist = lambda x: self.get_dist(x)
         sg = lambda x: {k: v.detach() for k, v in x.items()}
-        # forward == false -> (post, prior)
-        lhs, rhs = (prior, post) if forward else (post, prior)
-        # forward == false -> Lrep
-        value_lhs = value = kld(
-            dist(lhs) if self._discrete else dist(lhs)._dist,
-            dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
+        rep_loss = value = kld(
+            dist(post) if self._discrete else dist(post)._dist,
+            dist(sg(prior)) if self._discrete else dist(sg(prior))._dist,
         )
-        # forward == false -> Ldyn
-        value_rhs = kld(
-            dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
-            dist(rhs) if self._discrete else dist(rhs)._dist,
+        dyn_loss = kld(
+            dist(sg(post)) if self._discrete else dist(sg(post))._dist,
+            dist(prior) if self._discrete else dist(prior)._dist,
         )
-        loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
-        loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
-        loss = lscale * loss_lhs + rscale * loss_rhs
+        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
+        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        loss = dyn_scale * dyn_loss + rep_scale * rep_loss
-        return loss, value, loss_lhs, loss_rhs
+        return loss, value, dyn_loss, rep_loss
 class ConvEncoder(nn.Module):