merged action head into MLP and modified configs
This commit is contained in:
145
models.py
145
models.py
@@ -67,39 +67,29 @@ class WorldModel(nn.Module):
|
||||
self.heads["decoder"] = networks.MultiDecoder(
|
||||
feat_size, shapes, **config.decoder
|
||||
)
|
||||
if config.reward_head == "symlog_disc":
|
||||
self.heads["reward"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
(255,),
|
||||
config.reward_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.heads["reward"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.reward_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
self.heads["reward"] = networks.MLP(
|
||||
feat_size,
|
||||
(255,) if config.reward_head["dist"] == "symlog_disc" else (),
|
||||
config.reward_head["layers"],
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist=config.reward_head["dist"],
|
||||
outscale=config.reward_head["outscale"],
|
||||
device=config.device,
|
||||
name="Reward",
|
||||
)
|
||||
self.heads["cont"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.cont_layers,
|
||||
feat_size,
|
||||
(),
|
||||
config.cont_head["layers"],
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
dist="binary",
|
||||
outscale=config.cont_head["outscale"],
|
||||
device=config.device,
|
||||
name="Cont",
|
||||
)
|
||||
for name in config.grad_heads:
|
||||
assert name in self.heads, name
|
||||
@@ -113,7 +103,14 @@ class WorldModel(nn.Module):
|
||||
opt=config.opt,
|
||||
use_amp=self._use_amp,
|
||||
)
|
||||
self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
|
||||
print(
|
||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||
)
|
||||
self._scales = dict(
|
||||
reward=config.reward_head["scale"],
|
||||
cont=config.cont_head["scale"],
|
||||
image=1.0,
|
||||
)
|
||||
|
||||
def _train(self, data):
|
||||
# action (batch_size, batch_length, act_dim)
|
||||
@@ -134,6 +131,7 @@ class WorldModel(nn.Module):
|
||||
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
|
||||
post, prior, kl_free, dyn_scale, rep_scale
|
||||
)
|
||||
assert kl_loss.shape == embed.shape[:2], kl_loss.shape
|
||||
preds = {}
|
||||
for name, head in self.heads.items():
|
||||
grad_head = name in self._config.grad_heads
|
||||
@@ -226,65 +224,60 @@ class ImagBehavior(nn.Module):
|
||||
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||
else:
|
||||
feat_size = config.dyn_stoch + config.dyn_deter
|
||||
self.actor = networks.ActionHead(
|
||||
self.actor = networks.MLP(
|
||||
feat_size,
|
||||
config.num_actions,
|
||||
config.actor_layers,
|
||||
(config.num_actions,),
|
||||
config.actor["layers"],
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.actor_dist,
|
||||
config.actor_init_std,
|
||||
config.actor_min_std,
|
||||
config.actor_max_std,
|
||||
config.actor_temp,
|
||||
config.actor["dist"],
|
||||
"learned",
|
||||
config.actor["min_std"],
|
||||
config.actor["max_std"],
|
||||
config.actor["temp"],
|
||||
unimix_ratio=config.actor["unimix_ratio"],
|
||||
outscale=1.0,
|
||||
unimix_ratio=config.action_unimix_ratio,
|
||||
name="Actor",
|
||||
)
|
||||
if config.value_head == "symlog_disc":
|
||||
self.value = networks.MLP(
|
||||
feat_size,
|
||||
(255,),
|
||||
config.value_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.value = networks.MLP(
|
||||
feat_size,
|
||||
[],
|
||||
config.value_layers,
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
if config.slow_value_target:
|
||||
self.value = networks.MLP(
|
||||
feat_size,
|
||||
(255,) if config.critic["dist"] == "symlog_disc" else (),
|
||||
config.critic["layers"],
|
||||
config.units,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.critic["dist"],
|
||||
outscale=config.critic["outscale"],
|
||||
device=config.device,
|
||||
name="Value",
|
||||
)
|
||||
if config.critic["slow_target"]:
|
||||
self._slow_value = copy.deepcopy(self.value)
|
||||
self._updates = 0
|
||||
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
|
||||
self._actor_opt = tools.Optimizer(
|
||||
"actor",
|
||||
self.actor.parameters(),
|
||||
config.actor_lr,
|
||||
config.ac_opt_eps,
|
||||
config.actor_grad_clip,
|
||||
config.actor["lr"],
|
||||
config.actor["eps"],
|
||||
config.actor["grad_clip"],
|
||||
**kw,
|
||||
)
|
||||
print(
|
||||
f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
|
||||
)
|
||||
self._value_opt = tools.Optimizer(
|
||||
"value",
|
||||
self.value.parameters(),
|
||||
config.value_lr,
|
||||
config.ac_opt_eps,
|
||||
config.value_grad_clip,
|
||||
config.critic["lr"],
|
||||
config.critic["eps"],
|
||||
config.critic["grad_clip"],
|
||||
**kw,
|
||||
)
|
||||
print(
|
||||
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
self.reward_ema = RewardEMA(device=self._config.device)
|
||||
|
||||
@@ -335,19 +328,15 @@ class ImagBehavior(nn.Module):
|
||||
# (time, batch, 1), (time, batch, 1) -> (time, batch)
|
||||
value_loss = -value.log_prob(target.detach())
|
||||
slow_target = self._slow_value(value_input[:-1].detach())
|
||||
if self._config.slow_value_target:
|
||||
value_loss = value_loss - value.log_prob(
|
||||
slow_target.mode().detach()
|
||||
)
|
||||
if self._config.value_decay:
|
||||
value_loss += self._config.value_decay * value.mode()
|
||||
if self._config.critic["slow_target"]:
|
||||
value_loss -= value.log_prob(slow_target.mode().detach())
|
||||
# (time, batch, 1), (time, batch, 1) -> (1,)
|
||||
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
|
||||
|
||||
metrics.update(tools.tensorstats(value.mode(), "value"))
|
||||
metrics.update(tools.tensorstats(target, "target"))
|
||||
metrics.update(tools.tensorstats(reward, "imag_reward"))
|
||||
if self._config.actor_dist in ["onehot"]:
|
||||
if self._config.actor["dist"] in ["onehot"]:
|
||||
metrics.update(
|
||||
tools.tensorstats(
|
||||
torch.argmax(imag_action, dim=-1).float(), "imag_action"
|
||||
@@ -466,9 +455,9 @@ class ImagBehavior(nn.Module):
|
||||
return actor_loss, metrics
|
||||
|
||||
def _update_slow_target(self):
|
||||
if self._config.slow_value_target:
|
||||
if self._updates % self._config.slow_target_update == 0:
|
||||
mix = self._config.slow_target_fraction
|
||||
if self._config.critic["slow_target"]:
|
||||
if self._updates % self._config.critic["slow_target_update"] == 0:
|
||||
mix = self._config.critic["slow_target_fraction"]
|
||||
for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
|
||||
d.data = mix * s.data + (1 - mix) * d.data
|
||||
self._updates += 1
|
||||
|
||||
Reference in New Issue
Block a user