diff --git a/docker/environment.yaml b/docker/environment.yaml index 9f0e6f1..18a9914 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -2,7 +2,6 @@ name: tdmpc2 channels: - pytorch-nightly - nvidia - - anaconda - conda-forge - defaults dependencies: @@ -10,58 +9,44 @@ dependencies: - pytorch - torchvision - cudatoolkit=11.7 - - fluidsynth - - portaudio - glew - glib - - pillow - - pip + - pip==21 - pip: - absl-py - - click - - cloudpickle - - gpustat - glfw - kornia - termcolor - gym==0.21.0 - - pandas - moviepy - ffmpeg - imageio - imageio-ffmpeg - - lxml - - pyparsing - omegaconf - hydra-core - hydra-submitit-launcher - submitit - patchelf - protobuf - - scipy - tqdm - - xmltodict - transforms3d - joblib - - scikit-image - - einops - opencv-python - opencv-contrib-python - filelock - sapien==2.2.1 - mani-skill2==0.4.1 - - tabulate - - h5py - trimesh - open3d - - rtree - - seaborn + - setuptools==65.5.0 + - "cython<3" - mujoco==2.3.1 - mujoco-py==2.1.2.14 - dm-control - - plotly + - pillow - pyquaternion - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb # - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed - tensordict-nightly - torchrl-nightly + - wandb diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index baebf73..cb63997 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -17,9 +17,6 @@ class Ensemble(nn.Module): self.params = nn.ParameterList([nn.Parameter(p) for p in params]) self._repr = str(modules) - def modules(self): - return self.vmap.__wrapped__.stateless_model - def forward(self, *args, **kwargs): return self.vmap([p for p in self.params], (), *args, **kwargs) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 30fb1d4..8c9c5fd 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -59,17 +59,13 @@ class WorldModel(nn.Module): """ Enables/disables gradient tracking of Q-networks. Avoids unnecessary computation during policy optimization. - This method also enables/disables gradients for task embeddings, - and sets the dropout probability to 0 if `mode` is False. + This method also enables/disables gradients for task embeddings. """ for p in self._Qs.parameters(): p.requires_grad_(mode) if self.cfg.multitask: for p in self._task_emb.parameters(): p.requires_grad_(mode) - for m in self._Qs.modules(): - if isinstance(m, nn.Dropout): - m.p = self.cfg.dropout if mode else 0 def soft_update_target_Q(self): """