Organize unused code

This commit is contained in:
Lu Junjie
2024-10-20 23:40:36 +08:00
parent 68c3baa7fe
commit 6a7aba86dc
3 changed files with 39 additions and 134 deletions

View File

@@ -152,7 +152,6 @@ class ReplayBuffer(BaseBuffer):
device: Union[th.device, str] = "cpu",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs)
@@ -165,15 +164,10 @@ class ReplayBuffer(BaseBuffer):
self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs) + observation_dim, dtype=np.float32)
self.observations = np.zeros((self.buffer_size, self.n_envs, observation_dim), dtype=np.float32)
self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32)
# Handle timeouts termination properly if needed
# see https://github.com/DLR-RM/stable-baselines3/issues/284
self.handle_timeout_termination = handle_timeout_termination
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.int16)
if psutil is not None:
total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes
@@ -187,16 +181,11 @@ class ReplayBuffer(BaseBuffer):
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)
def add(
self,
def add(self,
obs: np.ndarray,
goal: np.ndarray,
depth: np.ndarray,
map_id: int,
infos: List[Dict[str, Any]],
) -> None:
# TODO: 删了obs的格式调整检查下还能不能正常放
map_id: int) -> None:
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
@@ -204,9 +193,6 @@ class ReplayBuffer(BaseBuffer):
self.depths[self.pos] = np.array(depth).copy()
self.map_ids[self.pos] = np.array(map_id).copy()
if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.pos += 1
if self.pos == self.buffer_size:
self.full = True