Modification of expl. (the exploration classes Random and Plan2Explore).
This commit is contained in:
@@ -8,22 +8,35 @@ import tools
|
||||
|
||||
|
||||
class Random(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, act_space):
|
||||
super(Random, self).__init__()
|
||||
self._config = config
|
||||
self._act_space = act_space
|
||||
|
||||
def actor(self, feat):
|
||||
shape = feat.shape[:-1] + [self._config.num_actions]
|
||||
if self._config.actor_dist == "onehot":
|
||||
return tools.OneHotDist(torch.zeros(shape))
|
||||
return tools.OneHotDist(
|
||||
torch.zeros(self._config.num_actions)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device)
|
||||
)
|
||||
else:
|
||||
ones = torch.ones(shape)
|
||||
return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
|
||||
return torchd.independent.Independent(
|
||||
torchd.uniform.Uniform(
|
||||
torch.Tensor(self._act_space.low)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
torch.Tensor(self._act_space.high)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
def train(self, start, context):
|
||||
def train(self, start, context, data):
|
||||
return None, {}
|
||||
|
||||
|
||||
# class Plan2Explore(tools.Module):
|
||||
class Plan2Explore(nn.Module):
|
||||
def __init__(self, config, world_model, reward=None):
|
||||
super(Plan2Explore, self).__init__()
|
||||
@@ -39,7 +52,7 @@ class Plan2Explore(nn.Module):
|
||||
feat_size = config.dyn_stoch + config.dyn_deter
|
||||
stoch = config.dyn_stoch
|
||||
size = {
|
||||
"embed": 32 * config.cnn_depth,
|
||||
"embed": world_model.embed_size,
|
||||
"stoch": stoch,
|
||||
"deter": config.dyn_deter,
|
||||
"feat": config.dyn_stoch + config.dyn_deter,
|
||||
|
||||
Reference in New Issue
Block a user