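Merged the actor's ActionHead into the generic MLP network and regrouped flat config keys into per-component dicts (reward_head, cont_head, actor, critic)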

This commit is contained in:
NM512
2024-01-05 10:26:48 +09:00
parent e0f2017e28
commit e0487f8206
5 changed files with 133 additions and 231 deletions

145
models.py
View File

@@ -67,39 +67,29 @@ class WorldModel(nn.Module):
self.heads["decoder"] = networks.MultiDecoder(
feat_size, shapes, **config.decoder
)
if config.reward_head == "symlog_disc":
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
(255,),
config.reward_layers,
config.units,
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
else:
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
[],
config.reward_layers,
config.units,
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
self.heads["reward"] = networks.MLP(
feat_size,
(255,) if config.reward_head["dist"] == "symlog_disc" else (),
config.reward_head["layers"],
config.units,
config.act,
config.norm,
dist=config.reward_head["dist"],
outscale=config.reward_head["outscale"],
device=config.device,
name="Reward",
)
self.heads["cont"] = networks.MLP(
feat_size, # pytorch version
[],
config.cont_layers,
feat_size,
(),
config.cont_head["layers"],
config.units,
config.act,
config.norm,
dist="binary",
outscale=config.cont_head["outscale"],
device=config.device,
name="Cont",
)
for name in config.grad_heads:
assert name in self.heads, name
@@ -113,7 +103,14 @@ class WorldModel(nn.Module):
opt=config.opt,
use_amp=self._use_amp,
)
self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
print(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
self._scales = dict(
reward=config.reward_head["scale"],
cont=config.cont_head["scale"],
image=1.0,
)
def _train(self, data):
# action (batch_size, batch_length, act_dim)
@@ -134,6 +131,7 @@ class WorldModel(nn.Module):
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
post, prior, kl_free, dyn_scale, rep_scale
)
assert kl_loss.shape == embed.shape[:2], kl_loss.shape
preds = {}
for name, head in self.heads.items():
grad_head = name in self._config.grad_heads
@@ -226,65 +224,60 @@ class ImagBehavior(nn.Module):
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
else:
feat_size = config.dyn_stoch + config.dyn_deter
self.actor = networks.ActionHead(
self.actor = networks.MLP(
feat_size,
config.num_actions,
config.actor_layers,
(config.num_actions,),
config.actor["layers"],
config.units,
config.act,
config.norm,
config.actor_dist,
config.actor_init_std,
config.actor_min_std,
config.actor_max_std,
config.actor_temp,
config.actor["dist"],
"learned",
config.actor["min_std"],
config.actor["max_std"],
config.actor["temp"],
unimix_ratio=config.actor["unimix_ratio"],
outscale=1.0,
unimix_ratio=config.action_unimix_ratio,
name="Actor",
)
if config.value_head == "symlog_disc":
self.value = networks.MLP(
feat_size,
(255,),
config.value_layers,
config.units,
config.act,
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
else:
self.value = networks.MLP(
feat_size,
[],
config.value_layers,
config.units,
config.act,
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
if config.slow_value_target:
self.value = networks.MLP(
feat_size,
(255,) if config.critic["dist"] == "symlog_disc" else (),
config.critic["layers"],
config.units,
config.act,
config.norm,
config.critic["dist"],
outscale=config.critic["outscale"],
device=config.device,
name="Value",
)
if config.critic["slow_target"]:
self._slow_value = copy.deepcopy(self.value)
self._updates = 0
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
self._actor_opt = tools.Optimizer(
"actor",
self.actor.parameters(),
config.actor_lr,
config.ac_opt_eps,
config.actor_grad_clip,
config.actor["lr"],
config.actor["eps"],
config.actor["grad_clip"],
**kw,
)
print(
f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
)
self._value_opt = tools.Optimizer(
"value",
self.value.parameters(),
config.value_lr,
config.ac_opt_eps,
config.value_grad_clip,
config.critic["lr"],
config.critic["eps"],
config.critic["grad_clip"],
**kw,
)
print(
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
)
if self._config.reward_EMA:
self.reward_ema = RewardEMA(device=self._config.device)
@@ -335,19 +328,15 @@ class ImagBehavior(nn.Module):
# (time, batch, 1), (time, batch, 1) -> (time, batch)
value_loss = -value.log_prob(target.detach())
slow_target = self._slow_value(value_input[:-1].detach())
if self._config.slow_value_target:
value_loss = value_loss - value.log_prob(
slow_target.mode().detach()
)
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
if self._config.critic["slow_target"]:
value_loss -= value.log_prob(slow_target.mode().detach())
# (time, batch, 1), (time, batch, 1) -> (1,)
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
metrics.update(tools.tensorstats(value.mode(), "value"))
metrics.update(tools.tensorstats(target, "target"))
metrics.update(tools.tensorstats(reward, "imag_reward"))
if self._config.actor_dist in ["onehot"]:
if self._config.actor["dist"] in ["onehot"]:
metrics.update(
tools.tensorstats(
torch.argmax(imag_action, dim=-1).float(), "imag_action"
@@ -466,9 +455,9 @@ class ImagBehavior(nn.Module):
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
if self._config.critic["slow_target"]:
if self._updates % self._config.critic["slow_target_update"] == 0:
mix = self._config.critic["slow_target_fraction"]
for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
d.data = mix * s.data + (1 - mix) * d.data
self._updates += 1