erased unused options

This commit is contained in:
NM512
2024-01-05 23:23:09 +09:00
parent a27711ab96
commit 7f66ed5333
6 changed files with 84 additions and 211 deletions

View File

@@ -1,8 +1,6 @@
import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont
import networks
import tools
@@ -10,21 +8,21 @@ import tools
to_np = lambda x: x.detach().cpu().numpy()
class RewardEMA(object):
class RewardEMA:
"""running mean and std"""
def __init__(self, device, alpha=1e-2):
self.device = device
self.values = torch.zeros((2,)).to(device)
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
def __call__(self, x):
def __call__(self, x, ema_vals):
flat_x = torch.flatten(x.detach())
x_quantile = torch.quantile(input=flat_x, q=self.range)
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
offset = self.values[0]
# this should be in-place operation
ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
offset = ema_vals[0]
return offset.detach(), scale.detach()
@@ -41,18 +39,13 @@ class WorldModel(nn.Module):
config.dyn_stoch,
config.dyn_deter,
config.dyn_hidden,
config.dyn_input_layers,
config.dyn_output_layers,
config.dyn_rec_depth,
config.dyn_shared,
config.dyn_discrete,
config.act,
config.norm,
config.dyn_mean_act,
config.dyn_std_act,
config.dyn_temp_post,
config.dyn_min_std,
config.dyn_cell,
config.unimix_ratio,
config.initial,
config.num_actions,
@@ -106,10 +99,10 @@ class WorldModel(nn.Module):
print(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
# other losses are scaled by 1.0.
self._scales = dict(
reward=config.reward_head["scale"],
cont=config.cont_head["scale"],
image=1.0,
reward=config.reward_head["loss_scale"],
cont=config.cont_head["loss_scale"],
)
def _train(self, data):
@@ -148,7 +141,8 @@ class WorldModel(nn.Module):
assert loss.shape == embed.shape[:2], (name, loss.shape)
losses[name] = loss
scaled = {
key: value * self._scales[key] for key, value in losses.items()
key: value * self._scales.get(key, 1.0)
for key, value in losses.items()
}
model_loss = sum(scaled.values()) + kl_loss
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
@@ -217,13 +211,11 @@ class WorldModel(nn.Module):
class ImagBehavior(nn.Module):
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
def __init__(self, config, world_model):
super(ImagBehavior, self).__init__()
self._use_amp = True if config.precision == 16 else False
self._config = config
self._world_model = world_model
self._stop_grad_actor = stop_grad_actor
self._reward = reward
if config.dyn_discrete:
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
else:
@@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
)
if self._config.reward_EMA:
# register ema_vals to nn.Module for enabling torch.save and torch.load
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
self.reward_ema = RewardEMA(device=self._config.device)
def _train(
self,
start,
objective=None,
action=None,
reward=None,
imagine=None,
tape=None,
repeats=None,
objective,
):
objective = objective or self._reward
self._update_slow_target()
metrics = {}
with tools.RequiresGrad(self.actor):
with torch.cuda.amp.autocast(self._use_amp):
imag_feat, imag_state, imag_action = self._imagine(
start, self.actor, self._config.imag_horizon, repeats
start, self.actor, self._config.imag_horizon
)
reward = objective(imag_feat, imag_state, imag_action)
actor_ent = self.actor(imag_feat).entropy()
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
# this target is not scaled
# slow is flag to indicate whether slow_target is used for lambda-return
# this target is not scaled by ema or sym_log.
target, weights, base = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
imag_feat, imag_state, reward
)
actor_loss, mets = self._compute_actor_loss(
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
)
@@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
metrics.update(self._value_opt(value_loss, self.value.parameters()))
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
def _imagine(self, start, policy, horizon):
dynamics = self._world_model.dynamics
if repeats:
raise NotImplemented("repeats is not implemented in this version")
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = dynamics.get_feat(state)
inp = feat.detach() if self._stop_grad_actor else feat
inp = feat.detach()
action = policy(inp).sample()
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
succ = dynamics.img_step(state, action)
return succ, feat, action
succ, feats, actions = tools.static_scan(
step, [torch.arange(horizon)], (start, None, None)
)
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
if repeats:
raise NotImplemented("repeats is not implemented in this version")
return feats, states, actions
def _compute_target(
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
):
def _compute_target(self, imag_feat, imag_state, reward):
if "cont" in self._world_model.heads:
inp = self._world_model.dynamics.get_feat(imag_state)
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
@@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
def _compute_actor_loss(
self,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
):
metrics = {}
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
inp = imag_feat.detach()
policy = self.actor(inp)
actor_ent = policy.entropy()
# Q-val for actor is not transformed using symlog
target = torch.stack(target, dim=1)
if self._config.reward_EMA:
offset, scale = self.reward_ema(target)
offset, scale = self.reward_ema(target, self.ema_vals)
normed_target = (target - offset) / scale
normed_base = (base - offset) / scale
adv = normed_target - normed_base
metrics.update(tools.tensorstats(normed_target, "normed_target"))
values = self.reward_ema.values
metrics["EMA_005"] = to_np(values[0])
metrics["EMA_095"] = to_np(values[1])
metrics["EMA_005"] = to_np(self.ema_vals[0])
metrics["EMA_095"] = to_np(self.ema_vals[1])
if self._config.imag_gradient == "dynamics":
actor_target = adv