Added state input capability (encoder and decoder are now built from the observation-space shapes)

This commit is contained in:
NM512
2023-05-14 23:38:46 +09:00
parent 3ebb8ad617
commit b984e69b6e
8 changed files with 369 additions and 142 deletions

View File

@@ -29,26 +29,14 @@ class RewardEMA(object):
class WorldModel(nn.Module):
def __init__(self, step, config):
def __init__(self, obs_space, act_space, step, config):
super(WorldModel, self).__init__()
self._step = step
self._use_amp = True if config.precision == 16 else False
self._config = config
self.encoder = networks.ConvEncoder(
config.grayscale,
config.cnn_depth,
config.act,
config.norm,
config.encoder_kernels,
)
if config.size[0] == 64 and config.size[1] == 64:
embed_size = (
(64 // 2 ** (len(config.encoder_kernels))) ** 2
* config.cnn_depth
* 2 ** (len(config.encoder_kernels) - 1)
)
else:
raise NotImplemented(f"{config.size} is not applicable now")
shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
self.encoder = networks.MultiEncoder(shapes, **config.encoder)
embed_size = self.encoder.outdim
self.dynamics = networks.RSSM(
config.dyn_stoch,
config.dyn_deter,
@@ -72,22 +60,15 @@ class WorldModel(nn.Module):
config.device,
)
self.heads = nn.ModuleDict()
channels = 1 if config.grayscale else 3
shape = (channels,) + config.size
if config.dyn_discrete:
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
else:
feat_size = config.dyn_stoch + config.dyn_deter
self.heads["image"] = networks.ConvDecoder(
feat_size, # pytorch version
config.cnn_depth,
config.act,
config.norm,
shape,
config.decoder_kernels,
self.heads["decoder"] = networks.MultiDecoder(
feat_size, shapes, **config.decoder
)
if config.reward_head == "twohot_symlog":
self.heads["reward"] = networks.DenseHead(
if config.reward_head == "symlog_disc":
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
(255,),
config.reward_layers,
@@ -99,7 +80,7 @@ class WorldModel(nn.Module):
device=config.device,
)
else:
self.heads["reward"] = networks.DenseHead(
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
[],
config.reward_layers,
@@ -110,7 +91,7 @@ class WorldModel(nn.Module):
outscale=0.0,
device=config.device,
)
self.heads["cont"] = networks.DenseHead(
self.heads["cont"] = networks.MLP(
feat_size, # pytorch version
[],
config.cont_layers,
@@ -153,15 +134,19 @@ class WorldModel(nn.Module):
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
post, prior, kl_free, dyn_scale, rep_scale
)
losses = {}
likes = {}
preds = {}
for name, head in self.heads.items():
grad_head = name in self._config.grad_heads
feat = self.dynamics.get_feat(post)
feat = feat if grad_head else feat.detach()
pred = head(feat)
if type(pred) is dict:
preds.update(pred)
else:
preds[name] = pred
losses = {}
for name, pred in preds.items():
like = pred.log_prob(data[name])
likes[name] = like
losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
model_loss = sum(losses.values()) + kl_loss
metrics = self._model_opt(model_loss, self.parameters())
@@ -213,11 +198,13 @@ class WorldModel(nn.Module):
states, _ = self.dynamics.observe(
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
)
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
:6
]
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
prior = self.dynamics.imagine(data["action"][:6, 5:], init)
openl = self.heads["image"](self.dynamics.get_feat(prior)).mode()
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed images are given for the first 5 steps; the remaining steps are predicted from the prior
model = torch.cat([recon[:, :5], openl], 1)
@@ -254,9 +241,9 @@ class ImagBehavior(nn.Module):
config.actor_temp,
outscale=1.0,
unimix_ratio=config.action_unimix_ratio,
) # action_dist -> action_disc?
if config.value_head == "twohot_symlog":
self.value = networks.DenseHead(
)
if config.value_head == "symlog_disc":
self.value = networks.MLP(
feat_size, # pytorch version
(255,),
config.value_layers,
@@ -268,7 +255,7 @@ class ImagBehavior(nn.Module):
device=config.device,
)
else:
self.value = networks.DenseHead(
self.value = networks.MLP(
feat_size, # pytorch version
[],
config.value_layers,