modified loss calculation
This commit is contained in:
27
models.py
27
models.py
@@ -144,10 +144,14 @@ class WorldModel(nn.Module):
|
||||
preds[name] = pred
|
||||
losses = {}
|
||||
for name, pred in preds.items():
|
||||
like = pred.log_prob(data[name])
|
||||
losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
|
||||
model_loss = sum(losses.values()) + kl_loss
|
||||
metrics = self._model_opt(model_loss, self.parameters())
|
||||
loss = -pred.log_prob(data[name])
|
||||
assert loss.shape == embed.shape[:2], (name, loss.shape)
|
||||
losses[name] = loss
|
||||
scaled = {
|
||||
key: value * self._scales[key] for key, value in losses.items()
|
||||
}
|
||||
model_loss = sum(scaled.values()) + kl_loss
|
||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||
|
||||
metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
|
||||
metrics["kl_free"] = kl_free
|
||||
@@ -318,6 +322,8 @@ class ImagBehavior(nn.Module):
|
||||
weights,
|
||||
base,
|
||||
)
|
||||
actor_loss -= self._config.actor["entropy"] * actor_ent[:-1, ..., None]
|
||||
actor_loss = torch.mean(actor_loss)
|
||||
metrics.update(mets)
|
||||
value_input = imag_feat
|
||||
|
||||
@@ -382,10 +388,6 @@ class ImagBehavior(nn.Module):
|
||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||
else:
|
||||
discount = self._config.discount * torch.ones_like(reward)
|
||||
if self._config.future_entropy and self._config.actor_entropy > 0:
|
||||
reward += self._config.actor_entropy * actor_ent
|
||||
if self._config.future_entropy and self._config.actor_state_entropy > 0:
|
||||
reward += self._config.actor_state_entropy * state_ent
|
||||
value = self.value(imag_feat).mode()
|
||||
target = tools.lambda_return(
|
||||
reward[1:],
|
||||
@@ -444,14 +446,7 @@ class ImagBehavior(nn.Module):
|
||||
metrics["imag_gradient_mix"] = mix
|
||||
else:
|
||||
raise NotImplementedError(self._config.imag_gradient)
|
||||
if not self._config.future_entropy and self._config.actor_entropy > 0:
|
||||
actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
|
||||
actor_target += actor_entropy
|
||||
if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
|
||||
state_entropy = self._config.actor_state_entropy * state_ent[:-1]
|
||||
actor_target += state_entropy
|
||||
metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
|
||||
actor_loss = -torch.mean(weights[:-1] * actor_target)
|
||||
actor_loss = -weights[:-1] * actor_target
|
||||
return actor_loss, metrics
|
||||
|
||||
def _update_slow_target(self):
|
||||
|
||||
Reference in New Issue
Block a user