support newest version of myosuite

This commit is contained in:
Nicklas Hansen
2024-11-04 15:15:40 -08:00
parent b7725e74a5
commit 1a77207646
3 changed files with 8 additions and 5 deletions

View File

@@ -39,7 +39,7 @@ dependencies:
- gpustat==1.1.1
####################
# Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite)
# (unmaintained but required for maniskill2/meta-world)
- gym==0.21.0
####################
# ManiSkill2:
@@ -51,6 +51,5 @@ dependencies:
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
####################
# MyoSuite:
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
# - myosuite
####################

View File

@@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper):
self.cfg = cfg
self.camera_id = 'hand_side_inter'
def reset(self):
return self.env.reset()[0]
def step(self, action):
obs, reward, _, info = self.env.step(action.copy())
obs = obs.astype(np.float32)
obs, reward, _, _, info = self.env.step(action.copy())
info['success'] = info['solved']
return obs, reward, False, info
@@ -48,7 +50,8 @@ def make_env(cfg):
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
from myosuite.utils import gym as gym_utils
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps

View File

@@ -18,6 +18,7 @@ from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')