Files
tdmpc2/tdmpc2/tdmpc2.py
2024-01-07 11:52:53 -08:00

291 lines
10 KiB
Python
Executable File

import numpy as np
import torch
import torch.nn.functional as F
from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
class TDMPC2:
"""
TD-MPC2 agent. Implements training + inference.
Can be used for both single-task and multi-task experiments,
and supports both state and pixel observations.
"""
def __init__(self, cfg):
self.cfg = cfg
self.device = torch.device(cfg.rank)
self.model = WorldModel(cfg)
self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()},
{'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
], lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5)
self.model.eval()
self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
def _get_discount(self, episode_length):
"""
Returns discount factor for a given episode length.
Simple heuristic that scales discount linearly with episode length.
Default values should work well for most tasks, but can be changed as needed.
Args:
episode_length (int): Length of the episode. Assumes episodes are of fixed length.
Returns:
float: Discount factor for the task.
"""
frac = episode_length/self.cfg.discount_denom
return min(max((frac-1)/(frac), self.cfg.discount_min), self.cfg.discount_max)
def save(self, fp):
"""
Save state dict of the agent to filepath.
Args:
fp (str): Filepath to save state dict to.
"""
torch.save({"model": self.model.state_dict()}, fp)
def load(self, fp):
"""
Load a saved state dict from filepath (or dictionary) into current agent.
Args:
fp (str or dict): Filepath or state dict to load.
"""
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
self.model.load_state_dict(state_dict["model"])
@torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None):
"""
Select an action by planning in the latent space of the world model.
Args:
obs (torch.Tensor): Observation from the environment.
t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution.
task (int): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: Action to take in the environment.
"""
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
if task is not None:
task = torch.tensor([task], device=self.device)
z = self.model.encode(obs, task)
if self.cfg.mpc:
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
else:
a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu()
@torch.no_grad()
def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):
"""
Plan a sequence of actions using the learned world model.
Args:
z (torch.Tensor): Latent state from which to plan.
t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution.
task (Torch.Tensor): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: Action to take in the environment.
"""
# Sample policy trajectories
if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device)
if not t0:
mean[:-1] = self._prev_mean[1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI
for _ in range(self.cfg.iterations):
# Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
.clamp(-1, 1)
if self.cfg.multitask:
actions = actions * self.model._action_masks[task]
# Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score /= score.sum(0)
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \
.clamp_(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask:
mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task]
# Select action
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
a, std = actions[0], std[0]
if not eval_mode:
a += std * torch.randn(self.cfg.action_dim, device=std.device)
return a.clamp_(-1, 1)
def update_pi(self, zs, task):
"""
Update policy using a sequence of latent states.
Args:
zs (torch.Tensor): Sequence of latent states.
task (torch.Tensor): Task index (only used for multi-task experiments).
Returns:
float: Loss of the policy update.
"""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
_, pis, log_pis, _ = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg')
self.scale.update(qs[0])
qs = self.scale(qs)
# Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step()
self.model.track_q_grad(True)
return pi_loss.item()
@torch.no_grad()
def _td_target(self, next_z, reward, task):
"""
Compute the TD-target from a reward and the observation at the following time step.
Args:
next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments).
Returns:
torch.Tensor: TD-target.
"""
pi = self.model.pi(next_z, task)[1]
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
# Compute targets
with torch.no_grad():
next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task)
# Prepare for update
self.optim.zero_grad(set_to_none=True)
self.model.train()
# Latent rollout
zs = torch.empty(self.cfg.horizon+1, self.cfg.batch_size, self.cfg.latent_dim, device=self.device)
z = self.model.encode(obs[0], task)
zs[0] = z
consistency_loss = 0
for t in range(self.cfg.horizon):
z = self.model.next(z, action[t], task)
consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t
zs[t+1] = z
# Predictions
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
# Compute losses
reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.value_coef * value_loss
)
# Update model
total_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
self.optim.step()
# Update policy
pi_loss = self.update_pi(zs.detach(), task)
# Update target Q-functions
self.model.soft_update_target_Q()
# Return training statistics
self.model.eval()
return {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"value_loss": float(value_loss.mean().item()),
"pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()),
"grad_norm": float(grad_norm),
"pi_scale": float(self.scale.value),
}