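Removed scheduling function: kl_free, dyn_scale, rep_scale, actor_entropy, actor_state_entropy and imag_gradient_mix are now read directly from the config as plain values instead of going through tools.schedule / being called as step-dependent schedules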
This commit is contained in:
24
models.py
24
models.py
@@ -128,9 +128,9 @@ class WorldModel(nn.Module):
|
||||
post, prior = self.dynamics.observe(
|
||||
embed, data["action"], data["is_first"]
|
||||
)
|
||||
kl_free = tools.schedule(self._config.kl_free, self._step)
|
||||
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
|
||||
rep_scale = tools.schedule(self._config.rep_scale, self._step)
|
||||
kl_free = self._config.kl_free
|
||||
dyn_scale = self._config.dyn_scale
|
||||
rep_scale = self._config.rep_scale
|
||||
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
|
||||
post, prior, kl_free, dyn_scale, rep_scale
|
||||
)
|
||||
@@ -393,10 +393,10 @@ class ImagBehavior(nn.Module):
|
||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||
else:
|
||||
discount = self._config.discount * torch.ones_like(reward)
|
||||
if self._config.future_entropy and self._config.actor_entropy() > 0:
|
||||
reward += self._config.actor_entropy() * actor_ent
|
||||
if self._config.future_entropy and self._config.actor_state_entropy() > 0:
|
||||
reward += self._config.actor_state_entropy() * state_ent
|
||||
if self._config.future_entropy and self._config.actor_entropy > 0:
|
||||
reward += self._config.actor_entropy * actor_ent
|
||||
if self._config.future_entropy and self._config.actor_state_entropy > 0:
|
||||
reward += self._config.actor_state_entropy * state_ent
|
||||
value = self.value(imag_feat).mode()
|
||||
target = tools.lambda_return(
|
||||
reward[1:],
|
||||
@@ -450,16 +450,16 @@ class ImagBehavior(nn.Module):
|
||||
policy.log_prob(imag_action)[:-1][:, :, None]
|
||||
* (target - self.value(imag_feat[:-1]).mode()).detach()
|
||||
)
|
||||
mix = self._config.imag_gradient_mix()
|
||||
mix = self._config.imag_gradient_mix
|
||||
actor_target = mix * target + (1 - mix) * actor_target
|
||||
metrics["imag_gradient_mix"] = mix
|
||||
else:
|
||||
raise NotImplementedError(self._config.imag_gradient)
|
||||
if not self._config.future_entropy and (self._config.actor_entropy() > 0):
|
||||
actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
|
||||
if not self._config.future_entropy and self._config.actor_entropy > 0:
|
||||
actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
|
||||
actor_target += actor_entropy
|
||||
if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
|
||||
state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
|
||||
if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
|
||||
state_entropy = self._config.actor_state_entropy * state_ent[:-1]
|
||||
actor_target += state_entropy
|
||||
metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
|
||||
actor_loss = -torch.mean(weights[:-1] * actor_target)
|
||||
|
||||
Reference in New Issue
Block a user