update dependencies

This commit is contained in:
Nicklas Hansen
2023-11-25 19:20:52 -08:00
parent 58a95e431b
commit f3139291e2
3 changed files with 6 additions and 28 deletions

View File

@@ -2,7 +2,6 @@ name: tdmpc2
channels: channels:
- pytorch-nightly - pytorch-nightly
- nvidia - nvidia
- anaconda
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
@@ -10,58 +9,44 @@ dependencies:
- pytorch - pytorch
- torchvision - torchvision
- cudatoolkit=11.7 - cudatoolkit=11.7
- fluidsynth
- portaudio
- glew - glew
- glib - glib
- pillow - pip==21
- pip
- pip: - pip:
- absl-py - absl-py
- click
- cloudpickle
- gpustat
- glfw - glfw
- kornia - kornia
- termcolor - termcolor
- gym==0.21.0 - gym==0.21.0
- pandas
- moviepy - moviepy
- ffmpeg - ffmpeg
- imageio - imageio
- imageio-ffmpeg - imageio-ffmpeg
- lxml
- pyparsing
- omegaconf - omegaconf
- hydra-core - hydra-core
- hydra-submitit-launcher - hydra-submitit-launcher
- submitit - submitit
- patchelf - patchelf
- protobuf - protobuf
- scipy
- tqdm - tqdm
- xmltodict
- transforms3d - transforms3d
- joblib - joblib
- scikit-image
- einops
- opencv-python - opencv-python
- opencv-contrib-python - opencv-contrib-python
- filelock - filelock
- sapien==2.2.1 - sapien==2.2.1
- mani-skill2==0.4.1 - mani-skill2==0.4.1
- tabulate
- h5py
- trimesh - trimesh
- open3d - open3d
- rtree - setuptools==65.5.0
- seaborn - "cython<3"
- mujoco==2.3.1 - mujoco==2.3.1
- mujoco-py==2.1.2.14 - mujoco-py==2.1.2.14
- dm-control - dm-control
- plotly - pillow
- pyquaternion - pyquaternion
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed # - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
- tensordict-nightly - tensordict-nightly
- torchrl-nightly - torchrl-nightly
- wandb

View File

@@ -17,9 +17,6 @@ class Ensemble(nn.Module):
self.params = nn.ParameterList([nn.Parameter(p) for p in params]) self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules) self._repr = str(modules)
def modules(self):
return self.vmap.__wrapped__.stateless_model
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs) return self.vmap([p for p in self.params], (), *args, **kwargs)

View File

@@ -59,17 +59,13 @@ class WorldModel(nn.Module):
""" """
Enables/disables gradient tracking of Q-networks. Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization. Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings, This method also enables/disables gradients for task embeddings.
and sets the dropout probability to 0 if `mode` is False.
""" """
for p in self._Qs.parameters(): for p in self._Qs.parameters():
p.requires_grad_(mode) p.requires_grad_(mode)
if self.cfg.multitask: if self.cfg.multitask:
for p in self._task_emb.parameters(): for p in self._task_emb.parameters():
p.requires_grad_(mode) p.requires_grad_(mode)
for m in self._Qs.modules():
if isinstance(m, nn.Dropout):
m.p = self.cfg.dropout if mode else 0
def soft_update_target_Q(self): def soft_update_target_Q(self):
""" """