solves episodic variant of cartpole-balance-sparse

This commit is contained in:
Nicklas Hansen
2024-01-07 19:28:41 -08:00
parent 26c72119cd
commit fabf01a5ec
7 changed files with 41 additions and 12 deletions

View File

@@ -78,7 +78,7 @@ class Buffer():
obs = td['obs'] obs = td['obs']
action = td['action'][1:] action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1) reward = td['reward'][1:].unsqueeze(-1)
terminated = td['terminated'][1:].unsqueeze(-1) terminated = td['terminated'][-1].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, terminated, task) return self._to_device(obs, action, reward, terminated, task)

View File

@@ -2,8 +2,9 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: dog-run task: cartpole-balance-sparse
obs: state obs: state
episodic: true
# evaluation # evaluation
checkpoint: ??? checkpoint: ???

View File

@@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs from dm_env import StepType, specs
import gym import gym
from envs.wrappers.episodic import EpisodicWrapper
class ExtendedTimeStep(NamedTuple): class ExtendedTimeStep(NamedTuple):
step_type: Any step_type: Any
@@ -197,4 +199,6 @@ def make_env(cfg):
env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env) env = ExtendedTimeStepWrapper(env)
env = TimeStepToGymWrapper(env, domain, task) env = TimeStepToGymWrapper(env, domain, task)
if cfg.episodic:
env = EpisodicWrapper(cfg, env)
return env return env

View File

@@ -47,12 +47,9 @@ class ManiSkillWrapper(gym.Wrapper):
def step(self, action): def step(self, action):
reward = 0 reward = 0
for _ in range(2): for _ in range(2):
obs, r, done, info = self.env.step(action) obs, r, _, info = self.env.step(action)
reward += r reward += r
info['terminated'] = done return obs, reward, False, info
if done:
break
return obs, reward, done, info
@property @property
def unwrapped(self): def unwrapped(self):

View File

@@ -0,0 +1,24 @@
from collections import deque
import gym
import numpy as np
import torch
class EpisodicWrapper(gym.Wrapper):
    """
    Experimental wrapper that ends an episode as soon as a step yields zero reward.

    Construction asserts that ``cfg.task`` is 'cartpole-balance-sparse'; no
    other task is accepted at the moment.
    """

    def __init__(self, cfg, env):
        """Wrap ``env`` and keep ``cfg`` so ``step`` can read ``cfg.episodic``."""
        super().__init__(env)
        assert cfg.task == 'cartpole-balance-sparse'
        self.env = env
        self.cfg = cfg

    def step(self, action):
        """
        Step the wrapped env and force termination on a zero reward.

        Returns the wrapped env's ``(obs, reward, done, info)`` unchanged,
        except that when ``cfg.episodic`` is truthy and the reward equals 0,
        ``done`` is returned as True and ``info['terminated']`` is set to True.
        The 'terminated' key is only written in that case.
        """
        obs, reward, done, info = self.env.step(action)
        # Pass the transition through untouched unless episodic mode is on
        # and this step produced a zero reward.
        if not (self.cfg.episodic and reward == 0):
            return obs, reward, done, info
        info['terminated'] = True
        return obs, reward, True, info

View File

@@ -254,15 +254,15 @@ class TDMPC2:
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
terminated_preds = self.model.terminated(_zs, task) terminated_pred = self.model.terminated(zs[-1], task)
# Compute losses # Compute losses
reward_loss, terminated_loss, value_loss = 0, 0, 0 reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
for q in range(self.cfg.num_q): for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
consistency_loss *= (1/self.cfg.horizon) consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon)
terminated_loss *= (1/self.cfg.horizon) terminated_loss *= (1/self.cfg.horizon)

View File

@@ -46,6 +46,9 @@ def train(cfg: dict):
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls( trainer = trainer_cls(
cfg=cfg, cfg=cfg,