modification of expl.

This commit is contained in:
NM512
2023-05-21 08:17:47 +09:00
parent b8ef214efa
commit 02c3d45fcf
3 changed files with 28 additions and 16 deletions

View File

@@ -36,7 +36,7 @@ class WorldModel(nn.Module):
self._config = config
shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
self.encoder = networks.MultiEncoder(shapes, **config.encoder)
embed_size = self.encoder.outdim
self.embed_size = self.encoder.outdim
self.dynamics = networks.RSSM(
config.dyn_stoch,
config.dyn_deter,
@@ -56,7 +56,7 @@ class WorldModel(nn.Module):
config.unimix_ratio,
config.initial,
config.num_actions,
embed_size,
self.embed_size,
config.device,
)
self.heads = nn.ModuleDict()
@@ -228,7 +228,7 @@ class ImagBehavior(nn.Module):
else:
feat_size = config.dyn_stoch + config.dyn_deter
self.actor = networks.ActionHead(
feat_size, # pytorch version
feat_size,
config.num_actions,
config.actor_layers,
config.units,
@@ -244,7 +244,7 @@ class ImagBehavior(nn.Module):
)
if config.value_head == "symlog_disc":
self.value = networks.MLP(
feat_size, # pytorch version
feat_size,
(255,),
config.value_layers,
config.units,
@@ -256,7 +256,7 @@ class ImagBehavior(nn.Module):
)
else:
self.value = networks.MLP(
feat_size, # pytorch version
feat_size,
[],
config.value_layers,
config.units,
@@ -356,7 +356,7 @@ class ImagBehavior(nn.Module):
)
else:
metrics.update(tools.tensorstats(imag_action, "imag_action"))
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
with tools.RequiresGrad(self):
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
metrics.update(self._value_opt(value_loss, self.value.parameters()))
@@ -462,7 +462,6 @@ class ImagBehavior(nn.Module):
if not self._config.future_entropy and (self._config.actor_entropy() > 0):
actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
actor_target += actor_entropy
metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
actor_target += state_entropy