Files
tdmpc2/tdmpc2/common/world_model.py
Nicklas Hansen b67b21c5c6 first commit
2023-10-25 18:26:00 -07:00

175 lines
5.6 KiB
Python

from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from common import layers, math, init
class WorldModel(nn.Module):
"""
TD-MPC2 implicit world model architecture.
Can be used for both single-task and multi-task experiments.
"""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min)
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
@property
def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def to(self, *args, **kwargs):
"""
Overriding `to` method to also move additional tensors to device.
"""
super().to(*args, **kwargs)
if self.cfg.multitask:
self._action_masks = self._action_masks.to(*args, **kwargs)
self.log_std_min = self.log_std_min.to(*args, **kwargs)
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
return self
def train(self, mode=True):
"""
Overriding `train` method to keep target Q-networks in eval mode.
"""
super().train(mode)
self._target_Qs.train(False)
return self
def track_q_grad(self, mode=True):
"""
Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings,
and sets the dropout probability to 0 if `mode` is False.
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.cfg.multitask:
for p in self._task_emb.parameters():
p.requires_grad_(mode)
for m in self._Qs.modules():
if isinstance(m, nn.Dropout):
m.p = self.cfg.dropout if mode else 0
def soft_update_target_Q(self):
"""
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
p_target.data.lerp_(p.data, self.cfg.tau)
def task_emb(self, x, task):
"""
Continuous task embedding for multi-task experiments.
Retrieves the task embedding for a given task ID `task`
and concatenates it to the input `x`.
"""
if isinstance(task, int):
task = torch.tensor([task], device=x.device)
emb = self._task_emb(task.long())
if x.ndim == 3:
emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)
elif emb.shape[0] == 1:
emb = emb.repeat(x.shape[0], 1)
return torch.cat([x, emb], dim=-1)
def encode(self, obs, task):
"""
Encodes an observation into its latent representation.
This implementation assumes a single state-based observation.
"""
if self.cfg.multitask:
obs = self.task_emb(obs, task)
return self._encoder['state'](obs)
def next(self, z, a, task):
"""
Predicts the next latent state given the current latent state and action.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
return self._dynamics(z)
def reward(self, z, a, task):
"""
Predicts instantaneous (single-step) reward.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def pi(self, z, task):
"""
Samples an action from the policy prior.
The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network.
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
# Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
eps = torch.randn_like(mu)
if self.cfg.multitask: # Mask out unused action dimensions
mu = mu * self._action_masks[task]
log_std = log_std * self._action_masks[task]
eps = eps * self._action_masks[task]
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
else: # No masking
action_dims = None
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
pi = mu + eps * log_std.exp()
mu, pi, log_pi = math.squash(mu, pi, log_pi)
return mu, pi, log_pi, log_std
def Q(self, z, a, task, return_type='min', target=False):
"""
Predict state-action value.
`return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values.
- `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not.
"""
assert return_type in {'min', 'avg', 'all'}
if self.cfg.multitask:
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
out = (self._target_Qs if target else self._Qs)(z)
if return_type == 'all':
return out
Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)]
Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg)
return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2