erased unused options
This commit is contained in:
79
models.py
79
models.py
@@ -1,8 +1,6 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from PIL import ImageColor, Image, ImageDraw, ImageFont
|
||||
|
||||
import networks
|
||||
import tools
|
||||
@@ -10,21 +8,21 @@ import tools
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
|
||||
class RewardEMA(object):
|
||||
class RewardEMA:
|
||||
"""running mean and std"""
|
||||
|
||||
def __init__(self, device, alpha=1e-2):
|
||||
self.device = device
|
||||
self.values = torch.zeros((2,)).to(device)
|
||||
self.alpha = alpha
|
||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, ema_vals):
|
||||
flat_x = torch.flatten(x.detach())
|
||||
x_quantile = torch.quantile(input=flat_x, q=self.range)
|
||||
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
|
||||
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
|
||||
offset = self.values[0]
|
||||
# this should be in-place operation
|
||||
ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
|
||||
scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
|
||||
offset = ema_vals[0]
|
||||
return offset.detach(), scale.detach()
|
||||
|
||||
|
||||
@@ -41,18 +39,13 @@ class WorldModel(nn.Module):
|
||||
config.dyn_stoch,
|
||||
config.dyn_deter,
|
||||
config.dyn_hidden,
|
||||
config.dyn_input_layers,
|
||||
config.dyn_output_layers,
|
||||
config.dyn_rec_depth,
|
||||
config.dyn_shared,
|
||||
config.dyn_discrete,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.dyn_mean_act,
|
||||
config.dyn_std_act,
|
||||
config.dyn_temp_post,
|
||||
config.dyn_min_std,
|
||||
config.dyn_cell,
|
||||
config.unimix_ratio,
|
||||
config.initial,
|
||||
config.num_actions,
|
||||
@@ -106,10 +99,10 @@ class WorldModel(nn.Module):
|
||||
print(
|
||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||
)
|
||||
# other losses are scaled by 1.0.
|
||||
self._scales = dict(
|
||||
reward=config.reward_head["scale"],
|
||||
cont=config.cont_head["scale"],
|
||||
image=1.0,
|
||||
reward=config.reward_head["loss_scale"],
|
||||
cont=config.cont_head["loss_scale"],
|
||||
)
|
||||
|
||||
def _train(self, data):
|
||||
@@ -148,7 +141,8 @@ class WorldModel(nn.Module):
|
||||
assert loss.shape == embed.shape[:2], (name, loss.shape)
|
||||
losses[name] = loss
|
||||
scaled = {
|
||||
key: value * self._scales[key] for key, value in losses.items()
|
||||
key: value * self._scales.get(key, 1.0)
|
||||
for key, value in losses.items()
|
||||
}
|
||||
model_loss = sum(scaled.values()) + kl_loss
|
||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||
@@ -217,13 +211,11 @@ class WorldModel(nn.Module):
|
||||
|
||||
|
||||
class ImagBehavior(nn.Module):
|
||||
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
|
||||
def __init__(self, config, world_model):
|
||||
super(ImagBehavior, self).__init__()
|
||||
self._use_amp = True if config.precision == 16 else False
|
||||
self._config = config
|
||||
self._world_model = world_model
|
||||
self._stop_grad_actor = stop_grad_actor
|
||||
self._reward = reward
|
||||
if config.dyn_discrete:
|
||||
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||
else:
|
||||
@@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
|
||||
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
||||
self.reward_ema = RewardEMA(device=self._config.device)
|
||||
|
||||
def _train(
|
||||
self,
|
||||
start,
|
||||
objective=None,
|
||||
action=None,
|
||||
reward=None,
|
||||
imagine=None,
|
||||
tape=None,
|
||||
repeats=None,
|
||||
objective,
|
||||
):
|
||||
objective = objective or self._reward
|
||||
self._update_slow_target()
|
||||
metrics = {}
|
||||
|
||||
with tools.RequiresGrad(self.actor):
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
imag_feat, imag_state, imag_action = self._imagine(
|
||||
start, self.actor, self._config.imag_horizon, repeats
|
||||
start, self.actor, self._config.imag_horizon
|
||||
)
|
||||
reward = objective(imag_feat, imag_state, imag_action)
|
||||
actor_ent = self.actor(imag_feat).entropy()
|
||||
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
|
||||
# this target is not scaled
|
||||
# slow is flag to indicate whether slow_target is used for lambda-return
|
||||
# this target is not scaled by ema or sym_log.
|
||||
target, weights, base = self._compute_target(
|
||||
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
||||
imag_feat, imag_state, reward
|
||||
)
|
||||
actor_loss, mets = self._compute_actor_loss(
|
||||
imag_feat,
|
||||
imag_state,
|
||||
imag_action,
|
||||
target,
|
||||
actor_ent,
|
||||
state_ent,
|
||||
weights,
|
||||
base,
|
||||
)
|
||||
@@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
|
||||
metrics.update(self._value_opt(value_loss, self.value.parameters()))
|
||||
return imag_feat, imag_state, imag_action, weights, metrics
|
||||
|
||||
def _imagine(self, start, policy, horizon, repeats=None):
|
||||
def _imagine(self, start, policy, horizon):
|
||||
dynamics = self._world_model.dynamics
|
||||
if repeats:
|
||||
raise NotImplemented("repeats is not implemented in this version")
|
||||
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
|
||||
start = {k: flatten(v) for k, v in start.items()}
|
||||
|
||||
def step(prev, _):
|
||||
state, _, _ = prev
|
||||
feat = dynamics.get_feat(state)
|
||||
inp = feat.detach() if self._stop_grad_actor else feat
|
||||
inp = feat.detach()
|
||||
action = policy(inp).sample()
|
||||
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
|
||||
succ = dynamics.img_step(state, action)
|
||||
return succ, feat, action
|
||||
|
||||
succ, feats, actions = tools.static_scan(
|
||||
step, [torch.arange(horizon)], (start, None, None)
|
||||
)
|
||||
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
|
||||
if repeats:
|
||||
raise NotImplemented("repeats is not implemented in this version")
|
||||
|
||||
return feats, states, actions
|
||||
|
||||
def _compute_target(
|
||||
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
||||
):
|
||||
def _compute_target(self, imag_feat, imag_state, reward):
|
||||
if "cont" in self._world_model.heads:
|
||||
inp = self._world_model.dynamics.get_feat(imag_state)
|
||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||
@@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
|
||||
def _compute_actor_loss(
|
||||
self,
|
||||
imag_feat,
|
||||
imag_state,
|
||||
imag_action,
|
||||
target,
|
||||
actor_ent,
|
||||
state_ent,
|
||||
weights,
|
||||
base,
|
||||
):
|
||||
metrics = {}
|
||||
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
|
||||
inp = imag_feat.detach()
|
||||
policy = self.actor(inp)
|
||||
actor_ent = policy.entropy()
|
||||
# Q-val for actor is not transformed using symlog
|
||||
target = torch.stack(target, dim=1)
|
||||
if self._config.reward_EMA:
|
||||
offset, scale = self.reward_ema(target)
|
||||
offset, scale = self.reward_ema(target, self.ema_vals)
|
||||
normed_target = (target - offset) / scale
|
||||
normed_base = (base - offset) / scale
|
||||
adv = normed_target - normed_base
|
||||
metrics.update(tools.tensorstats(normed_target, "normed_target"))
|
||||
values = self.reward_ema.values
|
||||
metrics["EMA_005"] = to_np(values[0])
|
||||
metrics["EMA_095"] = to_np(values[1])
|
||||
metrics["EMA_005"] = to_np(self.ema_vals[0])
|
||||
metrics["EMA_095"] = to_np(self.ema_vals[1])
|
||||
|
||||
if self._config.imag_gradient == "dynamics":
|
||||
actor_target = adv
|
||||
|
||||
Reference in New Issue
Block a user