modification of expl.
This commit is contained in:
13
models.py
13
models.py
@@ -36,7 +36,7 @@ class WorldModel(nn.Module):
|
||||
self._config = config
|
||||
shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
|
||||
self.encoder = networks.MultiEncoder(shapes, **config.encoder)
|
||||
embed_size = self.encoder.outdim
|
||||
self.embed_size = self.encoder.outdim
|
||||
self.dynamics = networks.RSSM(
|
||||
config.dyn_stoch,
|
||||
config.dyn_deter,
|
||||
@@ -56,7 +56,7 @@ class WorldModel(nn.Module):
|
||||
config.unimix_ratio,
|
||||
config.initial,
|
||||
config.num_actions,
|
||||
embed_size,
|
||||
self.embed_size,
|
||||
config.device,
|
||||
)
|
||||
self.heads = nn.ModuleDict()
|
||||
@@ -228,7 +228,7 @@ class ImagBehavior(nn.Module):
|
||||
else:
|
||||
feat_size = config.dyn_stoch + config.dyn_deter
|
||||
self.actor = networks.ActionHead(
|
||||
feat_size, # pytorch version
|
||||
feat_size,
|
||||
config.num_actions,
|
||||
config.actor_layers,
|
||||
config.units,
|
||||
@@ -244,7 +244,7 @@ class ImagBehavior(nn.Module):
|
||||
)
|
||||
if config.value_head == "symlog_disc":
|
||||
self.value = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
feat_size,
|
||||
(255,),
|
||||
config.value_layers,
|
||||
config.units,
|
||||
@@ -256,7 +256,7 @@ class ImagBehavior(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.value = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
feat_size,
|
||||
[],
|
||||
config.value_layers,
|
||||
config.units,
|
||||
@@ -356,7 +356,7 @@ class ImagBehavior(nn.Module):
|
||||
)
|
||||
else:
|
||||
metrics.update(tools.tensorstats(imag_action, "imag_action"))
|
||||
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
|
||||
metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
|
||||
with tools.RequiresGrad(self):
|
||||
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
|
||||
metrics.update(self._value_opt(value_loss, self.value.parameters()))
|
||||
@@ -462,7 +462,6 @@ class ImagBehavior(nn.Module):
|
||||
if not self._config.future_entropy and (self._config.actor_entropy() > 0):
|
||||
actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
|
||||
actor_target += actor_entropy
|
||||
metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
|
||||
if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
|
||||
state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
|
||||
actor_target += state_entropy
|
||||
|
||||
Reference in New Issue
Block a user