Merge branch 'episodic-rl' of github.com:nicklashansen/tdmpc2 into episodic-rl

This commit is contained in:
Nicklas Hansen
2024-01-08 10:55:10 -08:00
7 changed files with 40 additions and 10 deletions

View File

@@ -78,7 +78,7 @@ class Buffer():
obs = td['obs'] obs = td['obs']
action = td['action'][1:] action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1) reward = td['reward'][1:].unsqueeze(-1)
terminated = td['terminated'][1:].unsqueeze(-1) terminated = td['terminated'][-1].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, terminated, task) return self._to_device(obs, action, reward, terminated, task)

View File

@@ -2,8 +2,9 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: dog-run task: cartpole-balance-sparse
obs: state obs: state
episodic: true
# evaluation # evaluation
checkpoint: ??? checkpoint: ???

View File

@@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs from dm_env import StepType, specs
import gym import gym
from envs.wrappers.episodic import EpisodicWrapper
class ExtendedTimeStep(NamedTuple): class ExtendedTimeStep(NamedTuple):
step_type: Any step_type: Any
@@ -197,4 +199,6 @@ def make_env(cfg):
env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env) env = ExtendedTimeStepWrapper(env)
env = TimeStepToGymWrapper(env, domain, task) env = TimeStepToGymWrapper(env, domain, task)
if cfg.episodic:
env = EpisodicWrapper(cfg, env)
return env return env

View File

@@ -0,0 +1,24 @@
from collections import deque
import gym
import numpy as np
import torch
class EpisodicWrapper(gym.Wrapper):
    """
    Experimental wrapper that turns a task into an episodic one.

    Each step is forwarded to the wrapped environment. When ``cfg.episodic``
    is truthy and the returned reward equals zero, the step is reported as
    done and ``info['terminated']`` is set to ``True``. Construction asserts
    that ``cfg.task`` is 'cartpole-balance-sparse', the only task this
    wrapper is compatible with at the moment.
    """

    def __init__(self, cfg, env):
        """Wrap ``env`` and keep ``cfg`` for the termination check in ``step``."""
        super().__init__(env)
        # Guard against use with any task other than the supported one.
        assert cfg.task == 'cartpole-balance-sparse'
        self.cfg = cfg
        self.env = env

    def step(self, action):
        """
        Step the wrapped environment and flag zero-reward steps as terminal.

        Returns the ``(obs, reward, done, info)`` tuple from the wrapped
        environment, with ``done`` and ``info['terminated']`` forced to
        ``True`` when episodic mode is on and the reward is zero.
        """
        transition = self.env.step(action)
        obs, reward, done, info = transition
        # Pass the transition through untouched unless episodic mode is
        # enabled and this step produced a zero reward.
        if not self.cfg.episodic or reward != 0:
            return obs, reward, done, info
        info['terminated'] = True
        return obs, reward, True, info

View File

@@ -230,7 +230,7 @@ class TDMPC2:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
obs, action, reward, terminated, task = buffer.sample() obs, action, reward, terminated, task = buffer.sample()
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
next_z = self.model.encode(obs[1:], task) next_z = self.model.encode(obs[1:], task)
@@ -254,15 +254,16 @@ class TDMPC2:
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
terminated_preds = self.model.terminated(_zs, task) terminated_pred = self.model.terminated(zs[-1], task)
# Compute losses # Compute losses
reward_loss, terminated_loss, value_loss = 0, 0, 0 reward_loss, terminated_loss, value_loss = 0, 0, 0
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t terminated_loss += F.binary_cross_entropy(terminated_pred[t], terminated[t]) * self.cfg.rho**t
for q in range(self.cfg.num_q): for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
consistency_loss *= (1/self.cfg.horizon) consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon)
terminated_loss *= (1/self.cfg.horizon) terminated_loss *= (1/self.cfg.horizon)

View File

@@ -46,6 +46,9 @@ def train(cfg: dict):
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls( trainer = trainer_cls(
cfg=cfg, cfg=cfg,

View File

@@ -47,10 +47,8 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
) )
def to_td(self, obs=None, action=None, reward=None, terminated=None): def to_td(self, obs, action=None, reward=None, terminated=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if obs is None:
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu') obs = TensorDict(obs, batch_size=(), device='cpu')
else: else:
@@ -93,7 +91,6 @@ class OnlineTrainer(Trainer):
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._tds.append(self.to_td()) # Separate episodes with NaNs
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() obs = self.env.reset()