added state input capability
This commit is contained in:
304
networks.py
304
networks.py
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -20,8 +21,8 @@ class RSSM(nn.Module):
|
||||
rec_depth=1,
|
||||
shared=False,
|
||||
discrete=False,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
mean_act="none",
|
||||
std_act="softplus",
|
||||
temp_post=True,
|
||||
@@ -43,8 +44,8 @@ class RSSM(nn.Module):
|
||||
self._rec_depth = rec_depth
|
||||
self._shared = shared
|
||||
self._discrete = discrete
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._mean_act = mean_act
|
||||
self._std_act = std_act
|
||||
self._temp_post = temp_post
|
||||
@@ -62,8 +63,8 @@ class RSSM(nn.Module):
|
||||
inp_dim += self._embed
|
||||
for i in range(self._layers_input):
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
inp_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(self._act())
|
||||
inp_layers.append(norm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._inp_layers = nn.Sequential(*inp_layers)
|
||||
@@ -82,8 +83,8 @@ class RSSM(nn.Module):
|
||||
inp_dim = self._deter
|
||||
for i in range(self._layers_output):
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
img_out_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(self._act())
|
||||
img_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._img_out_layers = nn.Sequential(*img_out_layers)
|
||||
@@ -96,8 +97,8 @@ class RSSM(nn.Module):
|
||||
inp_dim = self._embed
|
||||
for i in range(self._layers_output):
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(self._act())
|
||||
obs_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
||||
@@ -327,28 +328,156 @@ class RSSM(nn.Module):
|
||||
return loss, value, dyn_loss, rep_loss
|
||||
|
||||
|
||||
class ConvEncoder(nn.Module):
|
||||
class MultiEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
grayscale=False,
|
||||
depth=32,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
kernels=(3, 3, 3, 3),
|
||||
shapes,
|
||||
mlp_keys,
|
||||
cnn_keys,
|
||||
act,
|
||||
norm,
|
||||
cnn_depth,
|
||||
cnn_kernels,
|
||||
mlp_layers,
|
||||
mlp_units,
|
||||
symlog_inputs,
|
||||
):
|
||||
super(MultiEncoder, self).__init__()
|
||||
self.cnn_shapes = {
|
||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||
}
|
||||
self.mlp_shapes = {
|
||||
k: v
|
||||
for k, v in shapes.items()
|
||||
if len(v) in (1, 2) and re.match(mlp_keys, k)
|
||||
}
|
||||
print("Encoder CNN shapes:", self.cnn_shapes)
|
||||
print("Encoder MLP shapes:", self.mlp_shapes)
|
||||
|
||||
self.outdim = 0
|
||||
if self.cnn_shapes:
|
||||
input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
|
||||
self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
|
||||
self.outdim += self._cnn.outdim
|
||||
if self.mlp_shapes:
|
||||
input_size = sum([sum(v) for v in self.mlp_shapes.values()])
|
||||
self._mlp = MLP(
|
||||
input_size,
|
||||
None,
|
||||
mlp_layers,
|
||||
mlp_units,
|
||||
act,
|
||||
norm,
|
||||
symlog_inputs=symlog_inputs,
|
||||
)
|
||||
self.outdim += mlp_units
|
||||
|
||||
def forward(self, obs):
|
||||
outputs = []
|
||||
if self.cnn_shapes:
|
||||
inputs = torch.cat([obs[k] for k in self.cnn_shapes], -1)
|
||||
outputs.append(self._cnn(inputs))
|
||||
if self.mlp_shapes:
|
||||
inputs = torch.cat([obs[k] for k in self.mlp_shapes], -1)
|
||||
outputs.append(self._mlp(inputs))
|
||||
outputs = torch.cat(outputs, -1)
|
||||
return outputs
|
||||
|
||||
|
||||
class MultiDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_size,
|
||||
shapes,
|
||||
mlp_keys,
|
||||
cnn_keys,
|
||||
act,
|
||||
norm,
|
||||
cnn_depth,
|
||||
cnn_kernels,
|
||||
mlp_layers,
|
||||
mlp_units,
|
||||
cnn_sigmoid,
|
||||
image_dist,
|
||||
vector_dist,
|
||||
):
|
||||
super(MultiDecoder, self).__init__()
|
||||
self.cnn_shapes = {
|
||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||
}
|
||||
self.mlp_shapes = {
|
||||
k: v
|
||||
for k, v in shapes.items()
|
||||
if len(v) in (1, 2) and re.match(mlp_keys, k)
|
||||
}
|
||||
print("Decoder CNN shapes:", self.cnn_shapes)
|
||||
print("Decoder MLP shapes:", self.mlp_shapes)
|
||||
|
||||
if self.cnn_shapes:
|
||||
some_shape = list(self.cnn_shapes.values())[0]
|
||||
shape = (sum(x[-1] for x in self.cnn_shapes.values()),) + some_shape[:-1]
|
||||
self._cnn = ConvDecoder(
|
||||
feat_size,
|
||||
shape,
|
||||
cnn_depth,
|
||||
act,
|
||||
norm,
|
||||
cnn_kernels,
|
||||
cnn_sigmoid=cnn_sigmoid,
|
||||
)
|
||||
if self.mlp_shapes:
|
||||
self._mlp = MLP(
|
||||
feat_size,
|
||||
self.mlp_shapes,
|
||||
mlp_layers,
|
||||
mlp_units,
|
||||
act,
|
||||
norm,
|
||||
vector_dist,
|
||||
)
|
||||
self._image_dist = image_dist
|
||||
|
||||
def forward(self, features):
|
||||
dists = {}
|
||||
if self.cnn_shapes:
|
||||
feat = features
|
||||
outputs = self._cnn(feat)
|
||||
split_sizes = [v[-1] for v in self.cnn_shapes.values()]
|
||||
outputs = torch.split(outputs, split_sizes, -1)
|
||||
dists.update(
|
||||
{
|
||||
key: self._make_image_dist(output)
|
||||
for key, output in zip(self.cnn_shapes.keys(), outputs)
|
||||
}
|
||||
)
|
||||
if self.mlp_shapes:
|
||||
dists.update(self._mlp(features))
|
||||
return dists
|
||||
|
||||
def _make_image_dist(self, mean):
|
||||
if self._image_dist == "normal":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(torchd.normal.Normal(mean, 1), 3)
|
||||
)
|
||||
if self._image_dist == "mse":
|
||||
return tools.MSEDist(mean)
|
||||
raise NotImplementedError(self._image_dist)
|
||||
|
||||
|
||||
class ConvEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
|
||||
):
|
||||
super(ConvEncoder, self).__init__()
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._depth = depth
|
||||
self._kernels = kernels
|
||||
h, w = 64, 64
|
||||
layers = []
|
||||
for i, kernel in enumerate(self._kernels):
|
||||
if i == 0:
|
||||
if grayscale:
|
||||
inp_dim = 1
|
||||
else:
|
||||
inp_dim = 3
|
||||
inp_dim = input_ch
|
||||
else:
|
||||
inp_dim = 2 ** (i - 1) * self._depth
|
||||
depth = 2**i * self._depth
|
||||
@@ -365,37 +494,42 @@ class ConvEncoder(nn.Module):
|
||||
layers.append(act())
|
||||
h, w = h // 2, w // 2
|
||||
|
||||
self.outdim = depth * h * w
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
def __call__(self, obs):
|
||||
x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
|
||||
def forward(self, obs):
|
||||
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
|
||||
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
|
||||
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
# prod: product of all elements
|
||||
# (batch * time, ...) -> (batch * time, -1)
|
||||
x = x.reshape([x.shape[0], np.prod(x.shape[1:])])
|
||||
shape = list(obs["image"].shape[:-3]) + [x.shape[-1]]
|
||||
return x.reshape(shape)
|
||||
# (batch * time, -1) -> (batch, time, -1)
|
||||
return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]])
|
||||
|
||||
|
||||
class ConvDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inp_depth,
|
||||
shape=(3, 64, 64),
|
||||
depth=32,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
shape=(3, 64, 64),
|
||||
kernels=(3, 3, 3, 3),
|
||||
outscale=1.0,
|
||||
cnn_sigmoid=False,
|
||||
):
|
||||
super(ConvDecoder, self).__init__()
|
||||
self._inp_depth = inp_depth
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._depth = depth
|
||||
self._shape = shape
|
||||
self._kernels = kernels
|
||||
self._cnn_sigmoid = cnn_sigmoid
|
||||
self._embed_size = (
|
||||
(64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
|
||||
)
|
||||
@@ -407,7 +541,6 @@ class ConvDecoder(nn.Module):
|
||||
h, w = 4, 4
|
||||
for i, kernel in enumerate(self._kernels):
|
||||
depth = self._embed_size // 16 // (2 ** (i + 1))
|
||||
act = self._act
|
||||
bias = False
|
||||
initializer = tools.weight_init
|
||||
if i == len(self._kernels) - 1:
|
||||
@@ -447,88 +580,125 @@ class ConvDecoder(nn.Module):
|
||||
outpad = pad * 2 - val
|
||||
return pad, outpad
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
def forward(self, features, dtype=None):
|
||||
x = self._linear_layer(features)
|
||||
# (batch, time, -1) -> (batch * time, h, w, ch)
|
||||
x = x.reshape([-1, 4, 4, self._embed_size // 16])
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
|
||||
mean = x.reshape(features.shape[:-1] + self._shape)
|
||||
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
|
||||
mean = mean.permute(0, 1, 3, 4, 2)
|
||||
return tools.SymlogDist(mean)
|
||||
if self._cnn_sigmoid:
|
||||
mean = F.sigmoid(mean) - 0.5
|
||||
return mean
|
||||
|
||||
|
||||
class DenseHead(nn.Module):
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inp_dim,
|
||||
shape,
|
||||
layers,
|
||||
units,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
outscale=1.0,
|
||||
symlog_inputs=False,
|
||||
device="cuda",
|
||||
):
|
||||
super(DenseHead, self).__init__()
|
||||
super(MLP, self).__init__()
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
if len(self._shape) == 0:
|
||||
if self._shape is not None and len(self._shape) == 0:
|
||||
self._shape = (1,)
|
||||
self._layers = layers
|
||||
self._units = units
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
|
||||
layers = []
|
||||
for index in range(self._layers):
|
||||
layers.append(nn.Linear(inp_dim, self._units, bias=False))
|
||||
layers.append(norm(self._units, eps=1e-03))
|
||||
layers.append(nn.Linear(inp_dim, units, bias=False))
|
||||
layers.append(norm(units, eps=1e-03))
|
||||
layers.append(act())
|
||||
if index == 0:
|
||||
inp_dim = self._units
|
||||
inp_dim = units
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
if isinstance(self._shape, dict):
|
||||
self.mean_layer = nn.ModuleDict()
|
||||
for name, shape in self._shape.items():
|
||||
self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
if self._std == "learned":
|
||||
self.std_layer = nn.ModuleDict()
|
||||
for name, shape in self._shape.items():
|
||||
self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
|
||||
self.std_layer.apply(tools.uniform_weight_init(outscale))
|
||||
elif self._shape is not None:
|
||||
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
if self._std == "learned":
|
||||
self.std_layer = nn.Linear(units, np.prod(self._shape))
|
||||
self.std_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
if self._std == "learned":
|
||||
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
|
||||
self.std_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
def forward(self, features, dtype=None):
|
||||
x = features
|
||||
if self._symlog_inputs:
|
||||
x = tools.symlog(x)
|
||||
out = self.layers(x)
|
||||
mean = self.mean_layer(out)
|
||||
if self._std == "learned":
|
||||
std = self.std_layer(out)
|
||||
if self._shape is None:
|
||||
return out
|
||||
if isinstance(self._shape, dict):
|
||||
dists = {}
|
||||
for name, shape in self._shape.items():
|
||||
mean = self.mean_layer[name](out)
|
||||
if self._std == "learned":
|
||||
std = self.std_layer[name](out)
|
||||
else:
|
||||
std = self._std
|
||||
dists.update({name: self.dist(self._dist, mean, std, shape)})
|
||||
return dists
|
||||
else:
|
||||
std = self._std
|
||||
if self._dist == "normal":
|
||||
mean = self.mean_layer(out)
|
||||
if self._std == "learned":
|
||||
std = self.std_layer(out)
|
||||
else:
|
||||
std = self._std
|
||||
return self.dist(self._dist, mean, std, self._shape)
|
||||
|
||||
def dist(self, dist, mean, std, shape):
|
||||
if dist == "normal":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
torchd.normal.Normal(mean, std), len(self._shape)
|
||||
torchd.normal.Normal(mean, std), len(shape)
|
||||
)
|
||||
)
|
||||
if self._dist == "huber":
|
||||
if dist == "huber":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape)
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
)
|
||||
)
|
||||
if self._dist == "binary":
|
||||
if dist == "binary":
|
||||
return tools.Bernoulli(
|
||||
torchd.independent.Independent(
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
|
||||
)
|
||||
)
|
||||
if self._dist == "twohot_symlog":
|
||||
return tools.TwoHotDistSymlog(logits=mean, device=self._device)
|
||||
raise NotImplementedError(self._dist)
|
||||
if dist == "symlog_disc":
|
||||
return tools.DiscDist(logits=mean, device=self._device)
|
||||
if dist == "symlog_mse":
|
||||
return tools.SymlogDist(mean)
|
||||
raise NotImplementedError(dist)
|
||||
|
||||
|
||||
class ActionHead(nn.Module):
|
||||
@@ -553,8 +723,8 @@ class ActionHead(nn.Module):
|
||||
self._layers = layers
|
||||
self._units = units
|
||||
self._dist = dist
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
@@ -579,7 +749,7 @@ class ActionHead(nn.Module):
|
||||
self._dist_layer = nn.Linear(self._units, self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
def forward(self, features, dtype=None):
|
||||
x = features
|
||||
x = self._pre_layers(x)
|
||||
if self._dist == "tanh_normal":
|
||||
|
||||
Reference in New Issue
Block a user