modification of expl.

This commit is contained in:
NM512
2023-05-21 08:17:47 +09:00
parent b8ef214efa
commit 02c3d45fcf
3 changed files with 28 additions and 16 deletions

View File

@@ -8,22 +8,35 @@ import tools
class Random(nn.Module):
def __init__(self, config):
def __init__(self, config, act_space):
super(Random, self).__init__()
self._config = config
self._act_space = act_space
def actor(self, feat):
shape = feat.shape[:-1] + [self._config.num_actions]
if self._config.actor_dist == "onehot":
return tools.OneHotDist(torch.zeros(shape))
return tools.OneHotDist(
torch.zeros(self._config.num_actions)
.repeat(self._config.envs, 1)
.to(self._config.device)
)
else:
ones = torch.ones(shape)
return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
return torchd.independent.Independent(
torchd.uniform.Uniform(
torch.Tensor(self._act_space.low)
.repeat(self._config.envs, 1)
.to(self._config.device),
torch.Tensor(self._act_space.high)
.repeat(self._config.envs, 1)
.to(self._config.device),
),
1,
)
def train(self, start, context):
def train(self, start, context, data):
return None, {}
# class Plan2Explore(tools.Module):
class Plan2Explore(nn.Module):
def __init__(self, config, world_model, reward=None):
super(Plan2Explore, self).__init__()
@@ -39,7 +52,7 @@ class Plan2Explore(nn.Module):
feat_size = config.dyn_stoch + config.dyn_deter
stoch = config.dyn_stoch
size = {
"embed": 32 * config.cnn_depth,
"embed": world_model.embed_size,
"stoch": stoch,
"deter": config.dyn_deter,
"feat": config.dyn_stoch + config.dyn_deter,