9 Commits

Author SHA1 Message Date
Nicklas Hansen
5e276e6aaa Merge remote-tracking branch 'origin/main' into discrete 2024-11-28 12:02:22 -08:00
Nicklas Hansen
4f07f1ced4 clean up discrete planning 2024-11-24 16:47:02 -08:00
Nicklas Hansen
4dcd933b8f maximum entropy discrete policy 2024-11-22 22:51:47 -08:00
Nicklas Hansen
d463268bd2 update readme 2024-11-12 13:46:04 -08:00
Nicklas Hansen
88ad0620ca add discrete planning 2024-11-12 00:13:08 -08:00
Nicklas Hansen
8280b82d5c argmax policy works 2024-11-11 22:36:40 -08:00
Nicklas Hansen
a9b5ad0ff8 cleanup 2024-11-11 19:09:09 -08:00
Nicklas Hansen
dc6720d322 fix 2024-11-11 18:20:09 -08:00
Nicklas Hansen
dee034070e init 2024-11-11 18:13:24 -08:00
9 changed files with 203 additions and 56 deletions

View File

@@ -12,6 +12,10 @@ Official implementation of
---- ----
**Discrete branch:** this branch is under active development and contains experimental support for discrete action spaces. We expect a stable release to be available in a few months. Please use the `main` branch for the time being.
----
**Announcement: training just got ~4.5x faster!** **Announcement: training just got ~4.5x faster!**
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility! Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!

View File

@@ -42,6 +42,16 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi return mu, pi, log_pi
def int_to_one_hot(x, num_classes):
"""
Converts an integer tensor to a one-hot tensor.
Supports batched inputs.
"""
one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
one_hot.scatter_(-1, x.unsqueeze(-1), 1)
return one_hot
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.

View File

@@ -2,9 +2,12 @@ from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tensordict.nn import TensorDictParams
from common import layers, math, init from common import layers, math, init
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
@@ -23,7 +26,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action_space == 'continuous' else cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
@@ -121,15 +124,12 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def pi(self, z, task): def _continuous_pi(self, z, task):
""" """
Samples an action from the policy prior. Samples an action from the policy prior.
The policy prior is a Gaussian distribution with The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network. mean and (log) std predicted by a neural network.
""" """
if self.cfg.multitask:
z = self.task_emb(z, task)
# Gaussian policy prior # Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1) mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
@@ -149,6 +149,41 @@ class WorldModel(nn.Module):
return mu, pi, log_pi, log_std return mu, pi, log_pi, log_std
def _discrete_pi(self, z, task):
"""
Samples an action from the policy prior.
The policy prior is a categorical distribution
with logits predicted by a neural network.
"""
assert task is None, "Discrete policy does not support multitask."
# Categorical policy prior
logits = self._pi(z)
policy_dist = Categorical(logits=logits)
action = policy_dist.sample()
action_probs = policy_dist.probs
log_prob = F.log_softmax(logits, dim=-1)
one_hot_action = math.int_to_one_hot(action, self.cfg.action_dim)
return action, one_hot_action, log_prob, action_probs
def pi(self, z, task):
"""
Samples an action from the policy prior.
Policy can be either continuous (Gaussian) or discrete (categorical).
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
if self.cfg.action_space == 'discrete':
return self._discrete_pi(z, task)
elif self.cfg.action_space == 'continuous':
return self._continuous_pi(z, task)
else:
raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
def Q(self, z, a, task, return_type='min', target=False, detach=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """
Predict state-action value. Predict state-action value.

View File

@@ -2,7 +2,7 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: dog-run task: discrete-cartpole-swingup
obs: state obs: state
# evaluation # evaluation
@@ -79,6 +79,7 @@ task_title: ???
multitask: ??? multitask: ???
tasks: ??? tasks: ???
obs_shape: ??? obs_shape: ???
action_space: ???
action_dim: ??? action_dim: ???
episode_length: ??? episode_length: ???
obs_shapes: ??? obs_shapes: ???

View File

@@ -3,6 +3,7 @@ import warnings
import gym import gym
from envs.wrappers.discrete import DiscreteWrapper
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
@@ -59,24 +60,34 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
if cfg.task.startswith('discrete-'):
discrete = True
cfg.task = cfg.task.replace('discrete-', '')
else:
discrete = False
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break
except ValueError: except ValueError:
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env) env = TensorWrapper(env)
if discrete:
env = DiscreteWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box except: # Box
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0] assert not isinstance(env.action_space, (gym.spaces.Dict, gym.spaces.MultiDiscrete)), \
'Dict and MultiDiscrete action spaces are not supported.'
cfg.action_space = 'discrete' if isinstance(env.action_space, gym.spaces.Discrete) else 'continuous'
cfg.action_dim = env.action_space.n if cfg.action_space == 'discrete' else env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length)
return env return env

View File

@@ -1,4 +1,4 @@
from collections import deque, defaultdict from collections import defaultdict
from typing import Any, NamedTuple from typing import Any, NamedTuple
import dm_env import dm_env
import numpy as np import numpy as np

View File

@@ -0,0 +1,33 @@
import gym
import torch
from common import math
class DiscreteWrapper(gym.Wrapper):
"""
Wrapper for converting continuous action spaces to discrete via binning.
"""
def __init__(self, env, bins_per_dim=5):
super().__init__(env)
self.bins_per_dim = bins_per_dim
self.continuous_dims = self.env.action_space.shape[0]
# Equally spaced bins along each dimension
self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims)
def rand_act(self):
action = torch.tensor(self.action_space.sample(), dtype=torch.int64)
return math.int_to_one_hot(action, self.action_space.n)
def _discrete_to_continuous(self, action):
# Convert a discrete action to a continuous action
action = torch.argmax(action)
action = action.item()
action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)]
action = torch.tensor(action, dtype=torch.float32)
return (action - 1) / 1
def step(self, action):
action = self._discrete_to_continuous(action)
return self.env.step(action)

View File

@@ -103,11 +103,12 @@ class TDMPC2(torch.nn.Module):
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
else: else:
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
a = self.model.pi(z, task)[int(not eval_mode)][0] select_idx = int(not eval_mode or self.cfg.action_space == 'discrete')
return a.cpu() action = self.model.pi(z, task)[select_idx][0]
return action.cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
@@ -119,7 +120,37 @@ class TDMPC2(torch.nn.Module):
G = G + discount * reward G = G + discount * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') pi = self.model.pi(z, task)[1]
if self.cfg.action_space == 'discrete':
pi = pi.squeeze(1) # TODO: this is a bit hacky
return G + discount * self.model.Q(z, pi, task, return_type='avg')
@torch.no_grad()
def _sample_policy(self, z, task):
"""Sample trajectories from the policy prior."""
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
for t in range(self.cfg.horizon-1):
action = self.model.pi(z, task)[1]
if self.cfg.action_space == 'discrete':
action = action.squeeze(1)
pi_actions[t] = action
z = self.model.next(z, pi_actions[t], task)
action = self.model.pi(z, task)[1]
if self.cfg.action_space == 'discrete':
action = action.squeeze(1)
pi_actions[-1] = action
return pi_actions
@torch.no_grad()
def _sample_actions(self, n, mean=None, std=None):
"""Sample actions from a Gaussian or Categorical distribution."""
if self.cfg.action_space == 'discrete':
actions = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, n), device=self.device)
actions = math.int_to_one_hot(actions, self.cfg.action_dim)
else:
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions = (mean.unsqueeze(1) + std.unsqueeze(1) * r).clamp(-1, 1)
return actions
@torch.no_grad() @torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
@@ -135,61 +166,61 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
# Sample policy trajectories # Encode observation
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1]
_z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1]
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) # Initialize parameters
if not t0: if self.cfg.action_space == 'continuous':
mean[:-1] = self._prev_mean[1:] mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0:
mean[:-1] = self._prev_mean[1:]
else:
mean, std = None, None
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
# Sample policy trajectories
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :self.cfg.num_pi_trajs] = self._sample_policy(z[:self.cfg.num_pi_trajs], task)
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample random actions
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) actions[:, self.cfg.num_pi_trajs:] = self._sample_actions(self.cfg.num_samples-self.cfg.num_pi_trajs, mean, std)
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
actions_sample = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Select elites and compute scores
value = self._estimate_value(z, actions, task).nan_to_num(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0).values max_value = elite_value.max(0).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score = score / score.sum(0) score = score / score.sum(0)
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask:
mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task]
# Update parameters
if self.cfg.action_space == 'continuous':
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask:
mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task]
else:
break
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) action = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)[0]
a, std = actions[0], std[0] if self.cfg.action_space == 'continuous':
if not eval_mode: if not eval_mode:
a = a + std * torch.randn(self.cfg.action_dim, device=std.device) action = action + std[0] * torch.randn(self.cfg.action_dim, device=std.device)
self._prev_mean.copy_(mean) self._prev_mean.copy_(mean)
return a.clamp(-1, 1) action = action.clamp(-1, 1)
return action
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """
@@ -202,14 +233,28 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
_, pis, log_pis, _ = self.model.pi(zs, task) _, actions, log_probs, action_probs = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
self.scale.update(qs[0]) if self.cfg.action_space == 'discrete':
actions = torch.eye(self.cfg.action_dim, device=zs.device).unsqueeze(0)
zs = zs.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
actions = actions.unsqueeze(0).repeat(zs.shape[0], zs.shape[1], 1, 1)
qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
if self.cfg.action_space == 'discrete':
qs = qs.squeeze(-1)
self.scale.update(torch.sum(action_probs*qs,dim=(1,2),keepdim=True)[0])
else:
self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
# Loss is a weighted sum of Q-values # Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() if self.cfg.action_space == 'discrete':
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
else:
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
@@ -231,6 +276,8 @@ class TDMPC2(torch.nn.Module):
torch.Tensor: TD-target. torch.Tensor: TD-target.
""" """
pi = self.model.pi(next_z, task)[1] pi = self.model.pi(next_z, task)[1]
if self.cfg.action_space == 'discrete':
pi = pi.squeeze(2) # TODO: this is a bit hacky
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)

View File

@@ -54,7 +54,8 @@ class OnlineTrainer(Trainer):
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action = torch.full_like(self.env.rand_act(), float('nan')) action_val = -1 if self.cfg.action_space == 'discrete' else float('nan')
action = torch.full_like(self.env.rand_act(), action_val)
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
td = TensorDict( td = TensorDict(
@@ -97,6 +98,11 @@ class OnlineTrainer(Trainer):
action = self.agent.act(obs, t0=len(self._tds)==1) action = self.agent.act(obs, t0=len(self._tds)==1)
else: else:
action = self.env.rand_act() action = self.env.rand_act()
if self.cfg.action_space == 'discrete':
# exploration schedule
# minimum 0.01, maximum 0.05, anneal over 20k steps
if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward)) self._tds.append(self.to_td(obs, action, reward))