update dependencies
This commit is contained in:
@@ -2,7 +2,6 @@ name: tdmpc2
|
|||||||
channels:
|
channels:
|
||||||
- pytorch-nightly
|
- pytorch-nightly
|
||||||
- nvidia
|
- nvidia
|
||||||
- anaconda
|
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
@@ -10,58 +9,44 @@ dependencies:
|
|||||||
- pytorch
|
- pytorch
|
||||||
- torchvision
|
- torchvision
|
||||||
- cudatoolkit=11.7
|
- cudatoolkit=11.7
|
||||||
- fluidsynth
|
|
||||||
- portaudio
|
|
||||||
- glew
|
- glew
|
||||||
- glib
|
- glib
|
||||||
- pillow
|
- pip==21
|
||||||
- pip
|
|
||||||
- pip:
|
- pip:
|
||||||
- absl-py
|
- absl-py
|
||||||
- click
|
|
||||||
- cloudpickle
|
|
||||||
- gpustat
|
|
||||||
- glfw
|
- glfw
|
||||||
- kornia
|
- kornia
|
||||||
- termcolor
|
- termcolor
|
||||||
- gym==0.21.0
|
- gym==0.21.0
|
||||||
- pandas
|
|
||||||
- moviepy
|
- moviepy
|
||||||
- ffmpeg
|
- ffmpeg
|
||||||
- imageio
|
- imageio
|
||||||
- imageio-ffmpeg
|
- imageio-ffmpeg
|
||||||
- lxml
|
|
||||||
- pyparsing
|
|
||||||
- omegaconf
|
- omegaconf
|
||||||
- hydra-core
|
- hydra-core
|
||||||
- hydra-submitit-launcher
|
- hydra-submitit-launcher
|
||||||
- submitit
|
- submitit
|
||||||
- patchelf
|
- patchelf
|
||||||
- protobuf
|
- protobuf
|
||||||
- scipy
|
|
||||||
- tqdm
|
- tqdm
|
||||||
- xmltodict
|
|
||||||
- transforms3d
|
- transforms3d
|
||||||
- joblib
|
- joblib
|
||||||
- scikit-image
|
|
||||||
- einops
|
|
||||||
- opencv-python
|
- opencv-python
|
||||||
- opencv-contrib-python
|
- opencv-contrib-python
|
||||||
- filelock
|
- filelock
|
||||||
- sapien==2.2.1
|
- sapien==2.2.1
|
||||||
- mani-skill2==0.4.1
|
- mani-skill2==0.4.1
|
||||||
- tabulate
|
|
||||||
- h5py
|
|
||||||
- trimesh
|
- trimesh
|
||||||
- open3d
|
- open3d
|
||||||
- rtree
|
- setuptools==65.5.0
|
||||||
- seaborn
|
- "cython<3"
|
||||||
- mujoco==2.3.1
|
- mujoco==2.3.1
|
||||||
- mujoco-py==2.1.2.14
|
- mujoco-py==2.1.2.14
|
||||||
- dm-control
|
- dm-control
|
||||||
- plotly
|
- pillow
|
||||||
- pyquaternion
|
- pyquaternion
|
||||||
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||||
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
|
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
|
||||||
- tensordict-nightly
|
- tensordict-nightly
|
||||||
- torchrl-nightly
|
- torchrl-nightly
|
||||||
|
- wandb
|
||||||
|
|||||||
@@ -17,9 +17,6 @@ class Ensemble(nn.Module):
|
|||||||
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
||||||
self._repr = str(modules)
|
self._repr = str(modules)
|
||||||
|
|
||||||
def modules(self):
|
|
||||||
return self.vmap.__wrapped__.stateless_model
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -59,17 +59,13 @@ class WorldModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Enables/disables gradient tracking of Q-networks.
|
Enables/disables gradient tracking of Q-networks.
|
||||||
Avoids unnecessary computation during policy optimization.
|
Avoids unnecessary computation during policy optimization.
|
||||||
This method also enables/disables gradients for task embeddings,
|
This method also enables/disables gradients for task embeddings.
|
||||||
and sets the dropout probability to 0 if `mode` is False.
|
|
||||||
"""
|
"""
|
||||||
for p in self._Qs.parameters():
|
for p in self._Qs.parameters():
|
||||||
p.requires_grad_(mode)
|
p.requires_grad_(mode)
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
for p in self._task_emb.parameters():
|
for p in self._task_emb.parameters():
|
||||||
p.requires_grad_(mode)
|
p.requires_grad_(mode)
|
||||||
for m in self._Qs.modules():
|
|
||||||
if isinstance(m, nn.Dropout):
|
|
||||||
m.p = self.cfg.dropout if mode else 0
|
|
||||||
|
|
||||||
def soft_update_target_Q(self):
|
def soft_update_target_Q(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user