changed the discount head to predict terminal
This commit is contained in:
26
networks.py
26
networks.py
@@ -273,28 +273,24 @@ class RSSM(nn.Module):
|
||||
std = std + self._min_std
|
||||
return {"mean": mean, "std": std}
|
||||
|
||||
def kl_loss(self, post, prior, forward, free, lscale, rscale):
|
||||
def kl_loss(self, post, prior, free, dyn_scale, rep_scale):
|
||||
kld = torchd.kl.kl_divergence
|
||||
dist = lambda x: self.get_dist(x)
|
||||
sg = lambda x: {k: v.detach() for k, v in x.items()}
|
||||
# forward == false -> (post, prior)
|
||||
lhs, rhs = (prior, post) if forward else (post, prior)
|
||||
|
||||
# forward == false -> Lrep
|
||||
value_lhs = value = kld(
|
||||
dist(lhs) if self._discrete else dist(lhs)._dist,
|
||||
dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
|
||||
rep_loss = value = kld(
|
||||
dist(post) if self._discrete else dist(post)._dist,
|
||||
dist(sg(prior)) if self._discrete else dist(sg(prior))._dist,
|
||||
)
|
||||
# forward == false -> Ldyn
|
||||
value_rhs = kld(
|
||||
dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
|
||||
dist(rhs) if self._discrete else dist(rhs)._dist,
|
||||
dyn_loss = kld(
|
||||
dist(sg(post)) if self._discrete else dist(sg(post))._dist,
|
||||
dist(prior) if self._discrete else dist(prior)._dist,
|
||||
)
|
||||
loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
|
||||
loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
|
||||
loss = lscale * loss_lhs + rscale * loss_rhs
|
||||
rep_loss = torch.mean(torch.clip(rep_loss, min=free))
|
||||
dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
|
||||
loss = dyn_scale * dyn_loss + rep_scale * rep_loss
|
||||
|
||||
return loss, value, loss_lhs, loss_rhs
|
||||
return loss, value, dyn_loss, rep_loss
|
||||
|
||||
|
||||
class ConvEncoder(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user