bug fix for gym==0.19.0

This commit is contained in:
NM512
2023-05-18 21:30:08 +09:00
parent d3156ecb06
commit b8ef214efa
3 changed files with 5 additions and 2 deletions

View File

@@ -343,6 +343,8 @@ class MultiEncoder(nn.Module):
symlog_inputs,
):
super(MultiEncoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal", "reward")
shapes = {k: v for k, v in shapes.items() if k not in excluded}
self.cnn_shapes = {
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
}
@@ -402,6 +404,8 @@ class MultiDecoder(nn.Module):
vector_dist,
):
super(MultiDecoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal", "reward")
shapes = {k: v for k, v in shapes.items() if k not in excluded}
self.cnn_shapes = {
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
}