support newest version of myosuite
This commit is contained in:
@@ -39,7 +39,7 @@ dependencies:
|
||||
- gpustat==1.1.1
|
||||
####################
|
||||
# Gym:
|
||||
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
||||
# (unmaintained but required for maniskill2/meta-world)
|
||||
- gym==0.21.0
|
||||
####################
|
||||
# ManiSkill2:
|
||||
@@ -51,6 +51,5 @@ dependencies:
|
||||
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
####################
|
||||
# MyoSuite:
|
||||
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
|
||||
# - myosuite
|
||||
####################
|
||||
|
||||
@@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper):
|
||||
self.cfg = cfg
|
||||
self.camera_id = 'hand_side_inter'
|
||||
|
||||
def reset(self):
|
||||
return self.env.reset()[0]
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, _, info = self.env.step(action.copy())
|
||||
obs = obs.astype(np.float32)
|
||||
obs, reward, _, _, info = self.env.step(action.copy())
|
||||
info['success'] = info['solved']
|
||||
return obs, reward, False, info
|
||||
|
||||
@@ -48,7 +50,8 @@ def make_env(cfg):
|
||||
raise ValueError('Unknown task:', cfg.task)
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
import myosuite
|
||||
env = gym.make(MYOSUITE_TASKS[cfg.task])
|
||||
from myosuite.utils import gym as gym_utils
|
||||
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
||||
env = MyoSuiteWrapper(env, cfg)
|
||||
env = TimeLimit(env, max_episode_steps=100)
|
||||
env.max_episode_steps = env._max_episode_steps
|
||||
|
||||
@@ -18,6 +18,7 @@ from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
from trainer.online_trainer import OnlineTrainer
|
||||
from common.logger import Logger
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user