Merge branch 'episodic-rl' of github.com:nicklashansen/tdmpc2 into episodic-rl
This commit is contained in:
@@ -78,7 +78,7 @@ class Buffer():
|
||||
obs = td['obs']
|
||||
action = td['action'][1:]
|
||||
reward = td['reward'][1:].unsqueeze(-1)
|
||||
terminated = td['terminated'][1:].unsqueeze(-1)
|
||||
terminated = td['terminated'][-1].unsqueeze(-1)
|
||||
task = td['task'][0] if 'task' in td.keys() else None
|
||||
return self._to_device(obs, action, reward, terminated, task)
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@ defaults:
|
||||
- override hydra/launcher: submitit_local
|
||||
|
||||
# environment
|
||||
task: dog-run
|
||||
task: cartpole-balance-sparse
|
||||
obs: state
|
||||
episodic: true
|
||||
|
||||
# evaluation
|
||||
checkpoint: ???
|
||||
|
||||
@@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale
|
||||
from dm_env import StepType, specs
|
||||
import gym
|
||||
|
||||
from envs.wrappers.episodic import EpisodicWrapper
|
||||
|
||||
|
||||
class ExtendedTimeStep(NamedTuple):
|
||||
step_type: Any
|
||||
@@ -197,4 +199,6 @@ def make_env(cfg):
|
||||
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
||||
env = ExtendedTimeStepWrapper(env)
|
||||
env = TimeStepToGymWrapper(env, domain, task)
|
||||
if cfg.episodic:
|
||||
env = EpisodicWrapper(cfg, env)
|
||||
return env
|
||||
|
||||
24
tdmpc2/envs/wrappers/episodic.py
Normal file
24
tdmpc2/envs/wrappers/episodic.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from collections import deque
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class EpisodicWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env):
|
||||
super().__init__(env)
|
||||
assert cfg.task == 'cartpole-balance-sparse'
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
if self.cfg.episodic and reward == 0:
|
||||
done = True
|
||||
info['terminated'] = True
|
||||
return obs, reward, done, info
|
||||
@@ -230,7 +230,7 @@ class TDMPC2:
|
||||
dict: Dictionary of training statistics.
|
||||
"""
|
||||
obs, action, reward, terminated, task = buffer.sample()
|
||||
|
||||
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
next_z = self.model.encode(obs[1:], task)
|
||||
@@ -254,15 +254,16 @@ class TDMPC2:
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
terminated_preds = self.model.terminated(_zs, task)
|
||||
|
||||
terminated_pred = self.model.terminated(zs[-1], task)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, terminated_loss, value_loss = 0, 0, 0
|
||||
for t in range(self.cfg.horizon):
|
||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
||||
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
|
||||
terminated_loss += F.binary_cross_entropy(terminated_pred[t], terminated[t]) * self.cfg.rho**t
|
||||
for q in range(self.cfg.num_q):
|
||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
||||
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
|
||||
consistency_loss *= (1/self.cfg.horizon)
|
||||
reward_loss *= (1/self.cfg.horizon)
|
||||
terminated_loss *= (1/self.cfg.horizon)
|
||||
|
||||
@@ -46,6 +46,9 @@ def train(cfg: dict):
|
||||
set_seed(cfg.seed)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
|
||||
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
|
||||
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
trainer = trainer_cls(
|
||||
cfg=cfg,
|
||||
|
||||
@@ -47,10 +47,8 @@ class OnlineTrainer(Trainer):
|
||||
episode_success=np.nanmean(ep_successes),
|
||||
)
|
||||
|
||||
def to_td(self, obs=None, action=None, reward=None, terminated=None):
|
||||
def to_td(self, obs, action=None, reward=None, terminated=None):
|
||||
"""Creates a TensorDict for a new episode."""
|
||||
if obs is None:
|
||||
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
|
||||
if isinstance(obs, dict):
|
||||
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||
else:
|
||||
@@ -93,7 +91,6 @@ class OnlineTrainer(Trainer):
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._tds.append(self.to_td()) # Separate episodes with NaNs
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
|
||||
obs = self.env.reset()
|
||||
|
||||
Reference in New Issue
Block a user