modified based on author's implementation

This commit is contained in:
NM512
2023-03-18 08:38:23 +09:00
parent a678a509b9
commit 6273444394
6 changed files with 371 additions and 229 deletions

158
models.py
View File

@@ -10,30 +10,22 @@ import tools
to_np = lambda x: x.detach().cpu().numpy()
def symlog(x):
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
def symexp(x):
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
class RewardEMA(object):
"""running mean and std"""
def __init__(self, device, alpha=1e-2):
self.device = device
self.scale = torch.zeros((1,)).to(device)
self.values = torch.zeros((2,)).to(device)
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
def __call__(self, x):
flat_x = torch.flatten(x.detach())
x_quantile = torch.quantile(input=flat_x, q=self.range)
scale = x_quantile[1] - x_quantile[0]
new_scale = self.alpha * scale + (1 - self.alpha) * self.scale
self.scale = new_scale
return x / torch.clip(self.scale, min=1.0)
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
offset = self.values[0]
return offset.detach(), scale.detach()
class WorldModel(nn.Module):
@@ -93,7 +85,7 @@ class WorldModel(nn.Module):
shape,
config.decoder_kernels,
)
if config.reward_head == "twohot":
if config.reward_head == "twohot_symlog":
self.heads["reward"] = networks.DenseHead(
feat_size, # pytorch version
(255,),
@@ -102,6 +94,7 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
)
else:
self.heads["reward"] = networks.DenseHead(
@@ -112,9 +105,8 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
)
# added this
self.heads["reward"].apply(tools.weight_init)
if config.pred_discount:
self.heads["discount"] = networks.DenseHead(
feat_size, # pytorch version
@@ -163,8 +155,6 @@ class WorldModel(nn.Module):
feat = self.dynamics.get_feat(post)
feat = feat if grad_head else feat.detach()
pred = head(feat)
# if name == 'image':
# losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum')
like = pred.log_prob(data[name])
likes[name] = like
losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
@@ -196,24 +186,9 @@ class WorldModel(nn.Module):
def preprocess(self, obs):
obs = obs.copy()
if self._config.obs_trans == "normalize":
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
elif self._config.obs_trans == "identity":
obs["image"] = torch.Tensor(obs["image"])
elif self._config.obs_trans == "symlog":
obs["image"] = symlog(torch.Tensor(obs["image"]))
else:
raise NotImplemented(f"{self._config.reward_trans} is not implemented")
if self._config.reward_trans == "tanh":
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1)
elif self._config.reward_trans == "identity":
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
elif self._config.reward_trans == "symlog":
obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1)
else:
raise NotImplemented(f"{self._config.reward_trans} is not implemented")
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
@@ -234,13 +209,9 @@ class WorldModel(nn.Module):
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed image is given until 5 steps
model = torch.cat([recon[:, :5], openl], 1)
if self._config.obs_trans == "normalize":
truth = data["image"][:6] + 0.5
model += 0.5
elif self._config.obs_trans == "symlog":
truth = symexp(data["image"][:6]) / 255.0
model = symexp(model) / 255.0
error = (model - truth + 1) / 2
truth = data["image"][:6] + 0.5
model = model + 0.5
error = (model - truth + 1.0) / 2.0
return torch.cat([truth, model, error], 2)
@@ -267,11 +238,11 @@ class ImagBehavior(nn.Module):
config.actor_dist,
config.actor_init_std,
config.actor_min_std,
config.actor_dist,
config.actor_max_std,
config.actor_temp,
config.actor_outscale,
outscale=1.0,
) # action_dist -> action_disc?
if config.value_head == "twohot":
if config.value_head == "twohot_symlog":
self.value = networks.DenseHead(
feat_size, # pytorch version
(255,),
@@ -280,6 +251,7 @@ class ImagBehavior(nn.Module):
config.act,
config.norm,
config.value_head,
outscale=0.0,
)
else:
self.value = networks.DenseHead(
@@ -290,9 +262,9 @@ class ImagBehavior(nn.Module):
config.act,
config.norm,
config.value_head,
outscale=0.0,
)
self.value.apply(tools.weight_init)
if config.slow_value_target or config.slow_actor_target:
if config.slow_value_target:
self._slow_value = copy.deepcopy(self.value)
self._updates = 0
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
@@ -335,21 +307,12 @@ class ImagBehavior(nn.Module):
start, self.actor, self._config.imag_horizon, repeats
)
reward = objective(imag_feat, imag_state, imag_action)
if self._config.reward_trans == "symlog":
# rescale predicted reward by head['reward']
reward = symexp(reward)
actor_ent = self.actor(imag_feat).entropy()
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
# this target is not scaled
# slow is flag to indicate whether slow_target is used for lambda-return
target, weights = self._compute_target(
imag_feat,
imag_state,
imag_action,
reward,
actor_ent,
state_ent,
self._config.slow_actor_target,
target, weights, base = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
)
actor_loss, mets = self._compute_actor_loss(
imag_feat,
@@ -359,42 +322,31 @@ class ImagBehavior(nn.Module):
actor_ent,
state_ent,
weights,
base,
)
metrics.update(mets)
if self._config.slow_value_target != self._config.slow_actor_target:
target, weights = self._compute_target(
imag_feat,
imag_state,
imag_action,
reward,
actor_ent,
state_ent,
self._config.slow_value_target,
)
value_input = imag_feat
with tools.RequiresGrad(self.value):
with torch.cuda.amp.autocast(self._use_amp):
value = self.value(value_input[:-1].detach())
target = torch.stack(target, dim=1)
# only critic target is processed using symlog(not actor)
if self._config.critic_trans == "symlog":
metrics["unscaled_target_mean"] = to_np(torch.mean(target))
target = symlog(target)
# (time, batch, 1), (time, batch, 1) -> (time, batch)
value_loss = -value.log_prob(target.detach())
slow_target = self._slow_value(value_input[:-1].detach())
if self._config.slow_value_target:
value_loss = value_loss - value.log_prob(
slow_target.mode().detach()
)
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
# (time, batch, 1), (time, batch, 1) -> (1,)
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
metrics["value_mean"] = to_np(torch.mean(value.mode()))
metrics["value_max"] = to_np(torch.max(value.mode()))
metrics["value_min"] = to_np(torch.min(value.mode()))
metrics["value_std"] = to_np(torch.std(value.mode()))
metrics["target_mean"] = to_np(torch.mean(target))
metrics["reward_mean"] = to_np(torch.mean(reward))
metrics["reward_std"] = to_np(torch.std(reward))
metrics.update(tools.tensorstats(value.mode(), "value"))
metrics.update(tools.tensorstats(target, "target"))
metrics.update(tools.tensorstats(reward, "imag_reward"))
metrics.update(tools.tensorstats(imag_action, "imag_action"))
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
with tools.RequiresGrad(self):
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
@@ -402,6 +354,11 @@ class ImagBehavior(nn.Module):
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
# horizon: 15
# start = dict(stoch, deter, logit)
# start["stoch"] (16, 63, 32, 32)
# start["deter"] (16, 63, 512)
# start["logit"] (16, 63, 32, 32)
dynamics = self._world_model.dynamics
if repeats:
raise NotImplemented("repeats is not implemented in this version")
@@ -418,6 +375,8 @@ class ImagBehavior(nn.Module):
feat = 0 * dynamics.get_feat(start)
action = policy(feat).mode()
# Is this action deterministic or stochastic?
# action = policy(feat).sample()
succ, feats, actions = tools.static_scan(
step, [torch.arange(horizon)], (start, feat, action)
)
@@ -428,7 +387,7 @@ class ImagBehavior(nn.Module):
return feats, states, actions
def _compute_target(
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
):
if "discount" in self._world_model.heads:
inp = self._world_model.dynamics.get_feat(imag_state)
@@ -439,13 +398,10 @@ class ImagBehavior(nn.Module):
reward += self._config.actor_entropy() * actor_ent
if self._config.future_entropy and self._config.actor_state_entropy() > 0:
reward += self._config.actor_state_entropy() * state_ent
if slow:
value = self._slow_value(imag_feat).mode()
else:
value = self.value(imag_feat).mode()
if self._config.critic_trans == "symlog":
# After adding this line there is issue
value = symexp(value)
value = self.value(imag_feat).mode()
# value(15, 960, ch)
# action(15, 960, ch)
# discount(15, 960, ch)
target = tools.lambda_return(
reward[:-1],
value[:-1],
@@ -457,10 +413,18 @@ class ImagBehavior(nn.Module):
weights = torch.cumprod(
torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
).detach()
return target, weights
return target, weights, value[:-1]
def _compute_actor_loss(
self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights
self,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
):
metrics = {}
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
@@ -469,11 +433,17 @@ class ImagBehavior(nn.Module):
# Q-val for actor is not transformed using symlog
target = torch.stack(target, dim=1)
if self._config.reward_EMA:
target = self.reward_ema(target)
metrics["EMA_scale"] = to_np(self.reward_ema.scale)
offset, scale = self.reward_ema(target)
normed_target = (target - offset) / scale
normed_base = (base - offset) / scale
adv = normed_target - normed_base
metrics.update(tools.tensorstats(normed_target, "normed_target"))
values = self.reward_ema.values
metrics["EMA_005"] = to_np(values[0])
metrics["EMA_095"] = to_np(values[1])
if self._config.imag_gradient == "dynamics":
actor_target = target
actor_target = adv
elif self._config.imag_gradient == "reinforce":
actor_target = (
policy.log_prob(imag_action)[:-1][:, :, None]
@@ -501,7 +471,7 @@ class ImagBehavior(nn.Module):
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target or self._config.slow_actor_target:
if self._config.slow_value_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
for s, d in zip(self.value.parameters(), self._slow_value.parameters()):