support newest version of myosuite
This commit is contained in:
@@ -39,7 +39,7 @@ dependencies:
|
|||||||
- gpustat==1.1.1
|
- gpustat==1.1.1
|
||||||
####################
|
####################
|
||||||
# Gym:
|
# Gym:
|
||||||
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
# (unmaintained but required for maniskill2/meta-world)
|
||||||
- gym==0.21.0
|
- gym==0.21.0
|
||||||
####################
|
####################
|
||||||
# ManiSkill2:
|
# ManiSkill2:
|
||||||
@@ -51,6 +51,5 @@ dependencies:
|
|||||||
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||||
####################
|
####################
|
||||||
# MyoSuite:
|
# MyoSuite:
|
||||||
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
|
|
||||||
# - myosuite
|
# - myosuite
|
||||||
####################
|
####################
|
||||||
|
|||||||
@@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper):
|
|||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.camera_id = 'hand_side_inter'
|
self.camera_id = 'hand_side_inter'
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.env.reset()[0]
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, _, info = self.env.step(action.copy())
|
obs, reward, _, _, info = self.env.step(action.copy())
|
||||||
obs = obs.astype(np.float32)
|
|
||||||
info['success'] = info['solved']
|
info['success'] = info['solved']
|
||||||
return obs, reward, False, info
|
return obs, reward, False, info
|
||||||
|
|
||||||
@@ -48,7 +50,8 @@ def make_env(cfg):
|
|||||||
raise ValueError('Unknown task:', cfg.task)
|
raise ValueError('Unknown task:', cfg.task)
|
||||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||||
import myosuite
|
import myosuite
|
||||||
env = gym.make(MYOSUITE_TASKS[cfg.task])
|
from myosuite.utils import gym as gym_utils
|
||||||
|
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
||||||
env = MyoSuiteWrapper(env, cfg)
|
env = MyoSuiteWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = TimeLimit(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from tdmpc2 import TDMPC2
|
|||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
from trainer.online_trainer import OnlineTrainer
|
from trainer.online_trainer import OnlineTrainer
|
||||||
from common.logger import Logger
|
from common.logger import Logger
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.set_float32_matmul_precision('high')
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user