solves episodic variant of cartpole-balance-sparse
This commit is contained in:
@@ -78,7 +78,7 @@ class Buffer():
|
|||||||
obs = td['obs']
|
obs = td['obs']
|
||||||
action = td['action'][1:]
|
action = td['action'][1:]
|
||||||
reward = td['reward'][1:].unsqueeze(-1)
|
reward = td['reward'][1:].unsqueeze(-1)
|
||||||
terminated = td['terminated'][1:].unsqueeze(-1)
|
terminated = td['terminated'][-1].unsqueeze(-1)
|
||||||
task = td['task'][0] if 'task' in td.keys() else None
|
task = td['task'][0] if 'task' in td.keys() else None
|
||||||
return self._to_device(obs, action, reward, terminated, task)
|
return self._to_device(obs, action, reward, terminated, task)
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ defaults:
|
|||||||
- override hydra/launcher: submitit_local
|
- override hydra/launcher: submitit_local
|
||||||
|
|
||||||
# environment
|
# environment
|
||||||
task: dog-run
|
task: cartpole-balance-sparse
|
||||||
obs: state
|
obs: state
|
||||||
|
episodic: true
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale
|
|||||||
from dm_env import StepType, specs
|
from dm_env import StepType, specs
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
from envs.wrappers.episodic import EpisodicWrapper
|
||||||
|
|
||||||
|
|
||||||
class ExtendedTimeStep(NamedTuple):
|
class ExtendedTimeStep(NamedTuple):
|
||||||
step_type: Any
|
step_type: Any
|
||||||
@@ -197,4 +199,6 @@ def make_env(cfg):
|
|||||||
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
||||||
env = ExtendedTimeStepWrapper(env)
|
env = ExtendedTimeStepWrapper(env)
|
||||||
env = TimeStepToGymWrapper(env, domain, task)
|
env = TimeStepToGymWrapper(env, domain, task)
|
||||||
|
if cfg.episodic:
|
||||||
|
env = EpisodicWrapper(cfg, env)
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -47,12 +47,9 @@ class ManiSkillWrapper(gym.Wrapper):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0
|
reward = 0
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
obs, r, done, info = self.env.step(action)
|
obs, r, _, info = self.env.step(action)
|
||||||
reward += r
|
reward += r
|
||||||
info['terminated'] = done
|
return obs, reward, False, info
|
||||||
if done:
|
|
||||||
break
|
|
||||||
return obs, reward, done, info
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
|
|||||||
24
tdmpc2/envs/wrappers/episodic.py
Normal file
24
tdmpc2/envs/wrappers/episodic.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, env):
|
||||||
|
super().__init__(env)
|
||||||
|
assert cfg.task == 'cartpole-balance-sparse'
|
||||||
|
self.cfg = cfg
|
||||||
|
self.env = env
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, done, info = self.env.step(action)
|
||||||
|
if self.cfg.episodic and reward == 0:
|
||||||
|
done = True
|
||||||
|
info['terminated'] = True
|
||||||
|
return obs, reward, done, info
|
||||||
@@ -254,15 +254,15 @@ class TDMPC2:
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
terminated_preds = self.model.terminated(_zs, task)
|
terminated_pred = self.model.terminated(zs[-1], task)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, terminated_loss, value_loss = 0, 0, 0
|
reward_loss, value_loss = 0, 0
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
||||||
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
|
|
||||||
for q in range(self.cfg.num_q):
|
for q in range(self.cfg.num_q):
|
||||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
||||||
|
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
|
||||||
consistency_loss *= (1/self.cfg.horizon)
|
consistency_loss *= (1/self.cfg.horizon)
|
||||||
reward_loss *= (1/self.cfg.horizon)
|
reward_loss *= (1/self.cfg.horizon)
|
||||||
terminated_loss *= (1/self.cfg.horizon)
|
terminated_loss *= (1/self.cfg.horizon)
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ def train(cfg: dict):
|
|||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
|
|
||||||
|
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
|
||||||
|
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
|
|||||||
Reference in New Issue
Block a user