learnable initial state options for RSSM
This commit is contained in:
14
models.py
14
models.py
@@ -66,6 +66,7 @@ class WorldModel(nn.Module):
|
||||
config.dyn_min_std,
|
||||
config.dyn_cell,
|
||||
config.unimix_ratio,
|
||||
config.initial,
|
||||
config.num_actions,
|
||||
embed_size,
|
||||
config.device,
|
||||
@@ -95,6 +96,7 @@ class WorldModel(nn.Module):
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.heads["reward"] = networks.DenseHead(
|
||||
@@ -106,6 +108,7 @@ class WorldModel(nn.Module):
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
self.heads["cont"] = networks.DenseHead(
|
||||
feat_size, # pytorch version
|
||||
@@ -115,6 +118,7 @@ class WorldModel(nn.Module):
|
||||
config.act,
|
||||
config.norm,
|
||||
dist="binary",
|
||||
device=config.device,
|
||||
)
|
||||
for name in config.grad_heads:
|
||||
assert name in self.heads, name
|
||||
@@ -140,7 +144,9 @@ class WorldModel(nn.Module):
|
||||
with tools.RequiresGrad(self):
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
embed = self.encoder(data)
|
||||
post, prior = self.dynamics.observe(embed, data["action"])
|
||||
post, prior = self.dynamics.observe(
|
||||
embed, data["action"], data["is_first"]
|
||||
)
|
||||
kl_free = tools.schedule(self._config.kl_free, self._step)
|
||||
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
|
||||
rep_scale = tools.schedule(self._config.rep_scale, self._step)
|
||||
@@ -204,7 +210,9 @@ class WorldModel(nn.Module):
|
||||
data = self.preprocess(data)
|
||||
embed = self.encoder(data)
|
||||
|
||||
states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
|
||||
states, _ = self.dynamics.observe(
|
||||
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
|
||||
)
|
||||
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
init = {k: v[:, -1] for k, v in states.items()}
|
||||
@@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.value = networks.DenseHead(
|
||||
@@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
if config.slow_value_target:
|
||||
self._slow_value = copy.deepcopy(self.value)
|
||||
|
||||
Reference in New Issue
Block a user