Clean up unused code
This commit is contained in:
@@ -152,7 +152,6 @@ class ReplayBuffer(BaseBuffer):
|
||||
device: Union[th.device, str] = "cpu",
|
||||
n_envs: int = 1,
|
||||
optimize_memory_usage: bool = False,
|
||||
handle_timeout_termination: bool = True,
|
||||
):
|
||||
super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs)
|
||||
|
||||
@@ -165,15 +164,10 @@ class ReplayBuffer(BaseBuffer):
|
||||
|
||||
self.optimize_memory_usage = optimize_memory_usage
|
||||
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs) + observation_dim, dtype=np.float32)
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs, observation_dim), dtype=np.float32)
|
||||
self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
|
||||
self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
|
||||
self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32)
|
||||
|
||||
# Handle timeouts termination properly if needed
|
||||
# see https://github.com/DLR-RM/stable-baselines3/issues/284
|
||||
self.handle_timeout_termination = handle_timeout_termination
|
||||
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.int16)
|
||||
|
||||
if psutil is not None:
|
||||
total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes
|
||||
@@ -187,16 +181,11 @@ class ReplayBuffer(BaseBuffer):
|
||||
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
def add(self,
|
||||
obs: np.ndarray,
|
||||
goal: np.ndarray,
|
||||
depth: np.ndarray,
|
||||
map_id: int,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
|
||||
# TODO: the obs format adjustment was removed; check that observations can still be stored correctly
|
||||
map_id: int) -> None:
|
||||
|
||||
# Copy to avoid modification by reference
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
@@ -204,9 +193,6 @@ class ReplayBuffer(BaseBuffer):
|
||||
self.depths[self.pos] = np.array(depth).copy()
|
||||
self.map_ids[self.pos] = np.array(map_id).copy()
|
||||
|
||||
if self.handle_timeout_termination:
|
||||
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
|
||||
|
||||
self.pos += 1
|
||||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
Reference in New Issue
Block a user