From a9b5ad0ff804c1e4d334ff3303327af29d2934eb Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 11 Nov 2024 19:09:09 -0800 Subject: [PATCH] cleanup --- tdmpc2/envs/wrappers/discrete.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tdmpc2/envs/wrappers/discrete.py b/tdmpc2/envs/wrappers/discrete.py index caf5e26..da590b4 100644 --- a/tdmpc2/envs/wrappers/discrete.py +++ b/tdmpc2/envs/wrappers/discrete.py @@ -1,5 +1,4 @@ import gym -import numpy as np import torch from common import math @@ -10,12 +9,13 @@ class DiscreteWrapper(gym.Wrapper): Wrapper for converting continuous action spaces to discrete via binning. """ - def __init__(self, env): + def __init__(self, env, bins_per_dim=5): super().__init__(env) + self.bins_per_dim = bins_per_dim self.continuous_dims = self.env.action_space.shape[0] # Bins at [-1, 0, 1] for each dimension # Discrete actions include all possible combinations of these bins - self.action_space = gym.spaces.Discrete(3 ** self.continuous_dims) + self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims) def rand_act(self): action = torch.tensor(self.action_space.sample(), dtype=torch.int64) @@ -26,7 +26,7 @@ class DiscreteWrapper(gym.Wrapper): # action is a one-hot encoded tensor action = torch.argmax(action) action = action.item() - action = [action // 3 ** i % 3 for i in range(self.continuous_dims)] + action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)] action = torch.tensor(action, dtype=torch.float32) return (action - 1) / 1