From fabf01a5ec92e7cc5793ee9cbbced881a9c13798 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 19:28:41 -0800 Subject: [PATCH] solves episodic variant of cartpole-balance-sparse --- tdmpc2/common/buffer.py | 2 +- tdmpc2/config.yaml | 3 ++- tdmpc2/envs/dmcontrol.py | 4 ++++ tdmpc2/envs/maniskill.py | 7 ++----- tdmpc2/envs/wrappers/episodic.py | 24 ++++++++++++++++++++++++ tdmpc2/tdmpc2.py | 10 +++++----- tdmpc2/train.py | 3 +++ 7 files changed, 41 insertions(+), 12 deletions(-) create mode 100644 tdmpc2/envs/wrappers/episodic.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index a060a3d..139348b 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -78,7 +78,7 @@ class Buffer(): obs = td['obs'] action = td['action'][1:] reward = td['reward'][1:].unsqueeze(-1) - terminated = td['terminated'][1:].unsqueeze(-1) + terminated = td['terminated'][-1].unsqueeze(-1) task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, terminated, task) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index cd3ac11..2288dfb 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -2,8 +2,9 @@ defaults: - override hydra/launcher: submitit_local # environment -task: dog-run +task: cartpole-balance-sparse obs: state +episodic: true # evaluation checkpoint: ??? diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..bda699c 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale from dm_env import StepType, specs import gym +from envs.wrappers.episodic import EpisodicWrapper + class ExtendedTimeStep(NamedTuple): step_type: Any @@ -197,4 +199,6 @@ def make_env(cfg): env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = ExtendedTimeStepWrapper(env) env = TimeStepToGymWrapper(env, domain, task) + if cfg.episodic: + env = EpisodicWrapper(cfg, env) return env diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 4a37b6d..7b0b6ed 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -47,12 +47,9 @@ class ManiSkillWrapper(gym.Wrapper): def step(self, action): reward = 0 for _ in range(2): - obs, r, done, info = self.env.step(action) + obs, r, _, info = self.env.step(action) reward += r - info['terminated'] = done - if done: - break - return obs, reward, done, info + return obs, reward, False, info @property def unwrapped(self): diff --git a/tdmpc2/envs/wrappers/episodic.py b/tdmpc2/envs/wrappers/episodic.py new file mode 100644 index 0000000..f27c8fc --- /dev/null +++ b/tdmpc2/envs/wrappers/episodic.py @@ -0,0 +1,24 @@ +from collections import deque + +import gym +import numpy as np +import torch + + +class EpisodicWrapper(gym.Wrapper): + """ + Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment. + """ + + def __init__(self, cfg, env): + super().__init__(env) + assert cfg.task == 'cartpole-balance-sparse' + self.cfg = cfg + self.env = env + + def step(self, action): + obs, reward, done, info = self.env.step(action) + if self.cfg.episodic and reward == 0: + done = True + info['terminated'] = True + return obs, reward, done, info diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 0f09b6c..3abdd1f 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -230,7 +230,7 @@ class TDMPC2: dict: Dictionary of training statistics. """ obs, action, reward, terminated, task = buffer.sample() - + # Compute targets with torch.no_grad(): next_z = self.model.encode(obs[1:], task) @@ -254,15 +254,15 @@ class TDMPC2: _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - terminated_preds = self.model.terminated(_zs, task) - + terminated_pred = self.model.terminated(zs[-1], task) + # Compute losses - reward_loss, terminated_loss, value_loss = 0, 0, 0 + reward_loss, value_loss = 0, 0 for t in range(self.cfg.horizon): reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t - terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t for q in range(self.cfg.num_q): value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t + terminated_loss = F.binary_cross_entropy(terminated_pred, terminated) consistency_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon) terminated_loss *= (1/self.cfg.horizon) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5953bb2..1680871 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -46,6 +46,9 @@ def train(cfg: dict): set_seed(cfg.seed) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) + assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \ + f'This branch is experimental and only supports cartpole-balance-sparse at this time.' + trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = trainer_cls( cfg=cfg,