added state input capability
This commit is contained in:
63
models.py
63
models.py
@@ -29,26 +29,14 @@ class RewardEMA(object):
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
def __init__(self, step, config):
|
||||
def __init__(self, obs_space, act_space, step, config):
|
||||
super(WorldModel, self).__init__()
|
||||
self._step = step
|
||||
self._use_amp = True if config.precision == 16 else False
|
||||
self._config = config
|
||||
self.encoder = networks.ConvEncoder(
|
||||
config.grayscale,
|
||||
config.cnn_depth,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.encoder_kernels,
|
||||
)
|
||||
if config.size[0] == 64 and config.size[1] == 64:
|
||||
embed_size = (
|
||||
(64 // 2 ** (len(config.encoder_kernels))) ** 2
|
||||
* config.cnn_depth
|
||||
* 2 ** (len(config.encoder_kernels) - 1)
|
||||
)
|
||||
else:
|
||||
raise NotImplemented(f"{config.size} is not applicable now")
|
||||
shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
|
||||
self.encoder = networks.MultiEncoder(shapes, **config.encoder)
|
||||
embed_size = self.encoder.outdim
|
||||
self.dynamics = networks.RSSM(
|
||||
config.dyn_stoch,
|
||||
config.dyn_deter,
|
||||
@@ -72,22 +60,15 @@ class WorldModel(nn.Module):
|
||||
config.device,
|
||||
)
|
||||
self.heads = nn.ModuleDict()
|
||||
channels = 1 if config.grayscale else 3
|
||||
shape = (channels,) + config.size
|
||||
if config.dyn_discrete:
|
||||
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||
else:
|
||||
feat_size = config.dyn_stoch + config.dyn_deter
|
||||
self.heads["image"] = networks.ConvDecoder(
|
||||
feat_size, # pytorch version
|
||||
config.cnn_depth,
|
||||
config.act,
|
||||
config.norm,
|
||||
shape,
|
||||
config.decoder_kernels,
|
||||
self.heads["decoder"] = networks.MultiDecoder(
|
||||
feat_size, shapes, **config.decoder
|
||||
)
|
||||
if config.reward_head == "twohot_symlog":
|
||||
self.heads["reward"] = networks.DenseHead(
|
||||
if config.reward_head == "symlog_disc":
|
||||
self.heads["reward"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
(255,),
|
||||
config.reward_layers,
|
||||
@@ -99,7 +80,7 @@ class WorldModel(nn.Module):
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.heads["reward"] = networks.DenseHead(
|
||||
self.heads["reward"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.reward_layers,
|
||||
@@ -110,7 +91,7 @@ class WorldModel(nn.Module):
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
self.heads["cont"] = networks.DenseHead(
|
||||
self.heads["cont"] = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.cont_layers,
|
||||
@@ -153,15 +134,19 @@ class WorldModel(nn.Module):
|
||||
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
|
||||
post, prior, kl_free, dyn_scale, rep_scale
|
||||
)
|
||||
losses = {}
|
||||
likes = {}
|
||||
preds = {}
|
||||
for name, head in self.heads.items():
|
||||
grad_head = name in self._config.grad_heads
|
||||
feat = self.dynamics.get_feat(post)
|
||||
feat = feat if grad_head else feat.detach()
|
||||
pred = head(feat)
|
||||
if type(pred) is dict:
|
||||
preds.update(pred)
|
||||
else:
|
||||
preds[name] = pred
|
||||
losses = {}
|
||||
for name, pred in preds.items():
|
||||
like = pred.log_prob(data[name])
|
||||
likes[name] = like
|
||||
losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
|
||||
model_loss = sum(losses.values()) + kl_loss
|
||||
metrics = self._model_opt(model_loss, self.parameters())
|
||||
@@ -213,11 +198,13 @@ class WorldModel(nn.Module):
|
||||
states, _ = self.dynamics.observe(
|
||||
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
|
||||
)
|
||||
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
|
||||
:6
|
||||
]
|
||||
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
init = {k: v[:, -1] for k, v in states.items()}
|
||||
prior = self.dynamics.imagine(data["action"][:6, 5:], init)
|
||||
openl = self.heads["image"](self.dynamics.get_feat(prior)).mode()
|
||||
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
|
||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||
# observed image is given until 5 steps
|
||||
model = torch.cat([recon[:, :5], openl], 1)
|
||||
@@ -254,9 +241,9 @@ class ImagBehavior(nn.Module):
|
||||
config.actor_temp,
|
||||
outscale=1.0,
|
||||
unimix_ratio=config.action_unimix_ratio,
|
||||
) # action_dist -> action_disc?
|
||||
if config.value_head == "twohot_symlog":
|
||||
self.value = networks.DenseHead(
|
||||
)
|
||||
if config.value_head == "symlog_disc":
|
||||
self.value = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
(255,),
|
||||
config.value_layers,
|
||||
@@ -268,7 +255,7 @@ class ImagBehavior(nn.Module):
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.value = networks.DenseHead(
|
||||
self.value = networks.MLP(
|
||||
feat_size, # pytorch version
|
||||
[],
|
||||
config.value_layers,
|
||||
|
||||
Reference in New Issue
Block a user