From 1a7720764616e4e55621a970e87e86fee7ef60ed Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 4 Nov 2024 15:15:40 -0800 Subject: [PATCH] support newest version of myosuite --- docker/environment.yaml | 3 +-- tdmpc2/envs/myosuite.py | 9 ++++++--- tdmpc2/train.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docker/environment.yaml b/docker/environment.yaml index 9425459..9da54e0 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -39,7 +39,7 @@ dependencies: - gpustat==1.1.1 #################### # Gym: - # (unmaintained but required for maniskill2/meta-world/myosuite) + # (unmaintained but required for maniskill2/meta-world) - gym==0.21.0 #################### # ManiSkill2: @@ -51,6 +51,5 @@ dependencies: # - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb #################### # MyoSuite: - # (requires gym==0.13 which conflicts with meta-world / mani-skill2) # - myosuite #################### diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py index fa6876e..d15f11f 100644 --- a/tdmpc2/envs/myosuite.py +++ b/tdmpc2/envs/myosuite.py @@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper): self.cfg = cfg self.camera_id = 'hand_side_inter' + def reset(self): + return self.env.reset()[0] + def step(self, action): - obs, reward, _, info = self.env.step(action.copy()) - obs = obs.astype(np.float32) + obs, reward, _, _, info = self.env.step(action.copy()) info['success'] = info['solved'] return obs, reward, False, info @@ -48,7 +50,8 @@ def make_env(cfg): raise ValueError('Unknown task:', cfg.task) assert cfg.obs == 'state', 'This task only supports state observations.' import myosuite - env = gym.make(MYOSUITE_TASKS[cfg.task]) + from myosuite.utils import gym as gym_utils + env = gym_utils.make(MYOSUITE_TASKS[cfg.task]) env = MyoSuiteWrapper(env, cfg) env = TimeLimit(env, max_episode_steps=100) env.max_episode_steps = env._max_episode_steps diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 3dc37a6..1846145 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -18,6 +18,7 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger + torch.backends.cudnn.benchmark = True torch.set_float32_matmul_precision('high')