added benchmark task Crafter
This commit is contained in:
@@ -347,7 +347,11 @@ class MultiEncoder(nn.Module):
|
||||
):
|
||||
super(MultiEncoder, self).__init__()
|
||||
excluded = ("is_first", "is_last", "is_terminal", "reward")
|
||||
shapes = {k: v for k, v in shapes.items() if k not in excluded}
|
||||
shapes = {
|
||||
k: v
|
||||
for k, v in shapes.items()
|
||||
if k not in excluded and not k.startswith("log_")
|
||||
}
|
||||
self.cnn_shapes = {
|
||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user