changed the discount head to predict terminal
This commit is contained in:
55
models.py
55
models.py
@@ -107,16 +107,15 @@ class WorldModel(nn.Module):
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
)
|
||||
if config.pred_discount:
|
||||
self.heads["discount"] = networks.DenseHead(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.discount_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist="binary",
|
||||
)
|
||||
self.heads["cont"] = networks.DenseHead(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.cont_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist="binary",
|
||||
)
|
||||
for name in config.grad_heads:
|
||||
assert name in self.heads, name
|
||||
self._model_opt = tools.Optimizer(
|
||||
@@ -129,7 +128,7 @@ class WorldModel(nn.Module):
|
||||
opt=config.opt,
|
||||
use_amp=self._use_amp,
|
||||
)
|
||||
self._scales = dict(reward=config.reward_scale, discount=config.discount_scale)
|
||||
self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
|
||||
|
||||
def _train(self, data):
|
||||
# action (batch_size, batch_length, act_dim)
|
||||
@@ -143,10 +142,10 @@ class WorldModel(nn.Module):
|
||||
embed = self.encoder(data)
|
||||
post, prior = self.dynamics.observe(embed, data["action"])
|
||||
kl_free = tools.schedule(self._config.kl_free, self._step)
|
||||
kl_lscale = tools.schedule(self._config.kl_lscale, self._step)
|
||||
kl_rscale = tools.schedule(self._config.kl_rscale, self._step)
|
||||
kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
|
||||
post, prior, self._config.kl_forward, kl_free, kl_lscale, kl_rscale
|
||||
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
|
||||
rep_scale = tools.schedule(self._config.rep_scale, self._step)
|
||||
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
|
||||
post, prior, kl_free, dyn_scale, rep_scale
|
||||
)
|
||||
losses = {}
|
||||
likes = {}
|
||||
@@ -163,10 +162,10 @@ class WorldModel(nn.Module):
|
||||
|
||||
metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
|
||||
metrics["kl_free"] = kl_free
|
||||
metrics["kl_lscale"] = kl_lscale
|
||||
metrics["kl_rscale"] = kl_rscale
|
||||
metrics["loss_lhs"] = to_np(loss_lhs)
|
||||
metrics["loss_rhs"] = to_np(loss_rhs)
|
||||
metrics["dyn_scale"] = dyn_scale
|
||||
metrics["rep_scale"] = rep_scale
|
||||
metrics["dyn_loss"] = to_np(dyn_loss)
|
||||
metrics["rep_loss"] = to_np(rep_loss)
|
||||
metrics["kl"] = to_np(torch.mean(kl_value))
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
metrics["prior_ent"] = to_np(
|
||||
@@ -193,6 +192,11 @@ class WorldModel(nn.Module):
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
||||
if "is_terminal" in obs:
|
||||
# this label is necessary to train cont_head
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
else:
|
||||
raise ValueError('"is_terminal" was not found in observation.')
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
return obs
|
||||
|
||||
@@ -347,7 +351,14 @@ class ImagBehavior(nn.Module):
|
||||
metrics.update(tools.tensorstats(value.mode(), "value"))
|
||||
metrics.update(tools.tensorstats(target, "target"))
|
||||
metrics.update(tools.tensorstats(reward, "imag_reward"))
|
||||
metrics.update(tools.tensorstats(imag_action, "imag_action"))
|
||||
if self._config.actor_dist in ["onehot"]:
|
||||
metrics.update(
|
||||
tools.tensorstats(
|
||||
torch.argmax(imag_action, dim=-1).float(), "imag_action"
|
||||
)
|
||||
)
|
||||
else:
|
||||
metrics.update(tools.tensorstats(imag_action, "imag_action"))
|
||||
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
|
||||
with tools.RequiresGrad(self):
|
||||
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
|
||||
@@ -390,9 +401,9 @@ class ImagBehavior(nn.Module):
|
||||
def _compute_target(
|
||||
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
||||
):
|
||||
if "discount" in self._world_model.heads:
|
||||
if "cont" in self._world_model.heads:
|
||||
inp = self._world_model.dynamics.get_feat(imag_state)
|
||||
discount = self._world_model.heads["discount"](inp).mean
|
||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||
else:
|
||||
discount = self._config.discount * torch.ones_like(reward)
|
||||
if self._config.future_entropy and self._config.actor_entropy() > 0:
|
||||
|
||||
Reference in New Issue
Block a user