Add learnable initial state options for RSSM

This commit is contained in:
NM512
2023-04-29 07:54:03 +09:00
parent 1328ff1088
commit 0eb66997fb
4 changed files with 60 additions and 12 deletions

View File

@@ -66,6 +66,7 @@ class WorldModel(nn.Module):
config.dyn_min_std,
config.dyn_cell,
config.unimix_ratio,
config.initial,
config.num_actions,
embed_size,
config.device,
@@ -95,6 +96,7 @@ class WorldModel(nn.Module):
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
else:
self.heads["reward"] = networks.DenseHead(
@@ -106,6 +108,7 @@ class WorldModel(nn.Module):
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
self.heads["cont"] = networks.DenseHead(
feat_size, # pytorch version
@@ -115,6 +118,7 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist="binary",
device=config.device,
)
for name in config.grad_heads:
assert name in self.heads, name
@@ -140,7 +144,9 @@ class WorldModel(nn.Module):
with tools.RequiresGrad(self):
with torch.cuda.amp.autocast(self._use_amp):
embed = self.encoder(data)
- post, prior = self.dynamics.observe(embed, data["action"])
+ post, prior = self.dynamics.observe(
+ embed, data["action"], data["is_first"]
+ )
kl_free = tools.schedule(self._config.kl_free, self._step)
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
rep_scale = tools.schedule(self._config.rep_scale, self._step)
@@ -204,7 +210,9 @@ class WorldModel(nn.Module):
data = self.preprocess(data)
embed = self.encoder(data)
- states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
+ states, _ = self.dynamics.observe(
+ embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
+ )
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
@@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
else:
self.value = networks.DenseHead(
@@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
if config.slow_value_target:
self._slow_value = copy.deepcopy(self.value)