Clean up unused code

This commit is contained in:
Lu Junjie
2024-10-20 23:40:36 +08:00
parent 68c3baa7fe
commit 6a7aba86dc
3 changed files with 39 additions and 134 deletions

View File

@@ -25,7 +25,6 @@ class FlightEnvVec(VecEnv):
self.world_box = np.zeros([6], dtype=np.float32) self.world_box = np.zeros([6], dtype=np.float32)
self.wrapper.getWorldBox(self.world_box) # xyz_min, xyz_max self.wrapper.getWorldBox(self.world_box) # xyz_min, xyz_max
self.reward_names = self.wrapper.getRewardNames() self.reward_names = self.wrapper.getRewardNames()
self.pretrained = False
# observations # observations
self._traj_cost = np.zeros([self.num_envs, 1], dtype=np.float32) # cost of current pred self._traj_cost = np.zeros([self.num_envs, 1], dtype=np.float32) # cost of current pred

View File

@@ -152,7 +152,6 @@ class ReplayBuffer(BaseBuffer):
device: Union[th.device, str] = "cpu", device: Union[th.device, str] = "cpu",
n_envs: int = 1, n_envs: int = 1,
optimize_memory_usage: bool = False, optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
): ):
super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs) super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs)
@@ -165,15 +164,10 @@ class ReplayBuffer(BaseBuffer):
self.optimize_memory_usage = optimize_memory_usage self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs) + observation_dim, dtype=np.float32) self.observations = np.zeros((self.buffer_size, self.n_envs, observation_dim), dtype=np.float32)
self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32) self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32) self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32) self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.int16)
# Handle timeouts termination properly if needed
# see https://github.com/DLR-RM/stable-baselines3/issues/284
self.handle_timeout_termination = handle_timeout_termination
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if psutil is not None: if psutil is not None:
total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes
@@ -187,16 +181,11 @@ class ReplayBuffer(BaseBuffer):
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
) )
def add( def add(self,
self,
obs: np.ndarray, obs: np.ndarray,
goal: np.ndarray, goal: np.ndarray,
depth: np.ndarray, depth: np.ndarray,
map_id: int, map_id: int) -> None:
infos: List[Dict[str, Any]],
) -> None:
# TODO: the obs format adjustment was removed; check that observations can still be stored correctly
# Copy to avoid modification by reference # Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy() self.observations[self.pos] = np.array(obs).copy()
@@ -204,9 +193,6 @@ class ReplayBuffer(BaseBuffer):
self.depths[self.pos] = np.array(depth).copy() self.depths[self.pos] = np.array(depth).copy()
self.map_ids[self.pos] = np.array(map_id).copy() self.map_ids[self.pos] = np.array(map_id).copy()
if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.pos += 1 self.pos += 1
if self.pos == self.buffer_size: if self.pos == self.buffer_size:
self.full = True self.full = True

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch as th import torch as th
from torch.nn import functional as F from torch.nn import functional as F
from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq, TrainFrequencyUnit from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import should_collect_more_steps, get_schedule_fn, configure_logger from stable_baselines3.common.utils import should_collect_more_steps, get_schedule_fn, configure_logger, update_learning_rate
from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.utils import get_device from stable_baselines3.common.utils import get_device
@@ -172,82 +172,51 @@ class YopoAlgorithm:
path = policy_path + "/epoch{}_iter{}.pth".format(epoch_, step) path = policy_path + "/epoch{}_iter{}.pth".format(epoch_, step)
th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path) th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path)
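# Imitation learning: deprecated (not deleted yet, kept for possible later use)
# 0. reset_state, get_depth, reset_goal
# 1. Run a number of steps: env_num * 200
# 2. Train a number of steps: batch_size = env_num, 200 training steps = 1 episode
# 3. reset_state, get_depth, reset_goal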
def imitation_learning( def imitation_learning(
self, self,
total_timesteps, total_timesteps,
callback=None, log_interval,
log_interval=4,
eval_env=None,
eval_freq=-1,
n_eval_episodes=5,
tb_log_name="YOPO",
eval_log_path=None,
reset_num_timesteps=True, reset_num_timesteps=True,
): ):
# 0. 初始化第一次观测 # 0. setup learn and init the first observation
total_timesteps, callback = self._setup_learn( self._setup_learn(total_timesteps, reset_num_timesteps)
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
)
self.pretrained = self.env.pretrained
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps: while self.num_timesteps < total_timesteps:
# 1. 数据收集 # 1. Rollout and Collect Data into Buffer
rollout = self.collect_rollouts( rollout = self.collect_rollouts(
self.env, self.env,
train_freq=self.train_freq, train_freq=self.train_freq,
action_noise=self.action_noise, replay_buffer=self.replay_buffer
callback=callback,
replay_buffer=self.replay_buffer,
log_interval=log_interval,
) )
if rollout.continue_training is False: if rollout.continue_training is False:
break break
# 2. 训练模型 # 2. Train the Policy
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
# If no `gradient_steps` is specified, # If no `gradient_steps` is specified, do as many gradients steps as steps performed during the rollout
# do as many gradients steps as steps performed during the rollout
gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
# Special case when the user passes `gradient_steps=0` if gradient_steps > 0: # Special case when the user passes `gradient_steps=0`
if gradient_steps > 0:
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps) self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
self.reset_state() self.reset_state()
iteration = int(self.num_timesteps / (self.train_freq.frequency * self.env.num_envs)) iteration = int(self.num_timesteps / (self.train_freq.frequency * self.env.num_envs))
# 3. 重置环境 # 3. reset the environment
if self.change_env_freq > 0 and iteration % self.change_env_freq == 0: if self.change_env_freq > 0 and iteration % self.change_env_freq == 0:
self.env.spawnTreesAndSavePointcloud() self.env.spawnTreesAndSavePointcloud()
self._map_id = self._map_id + 1 self._map_id = self._map_id + 1
self.reset_state() self.reset_state()
# 4. 终端打印log # 4. print the log and save weight
if log_interval is not None and iteration % log_interval[0] == 0: if log_interval is not None and iteration % log_interval[0] == 0:
self._dump_logs() self._dump_logs()
if log_interval is not None and iteration % log_interval[1] == 0: if log_interval is not None and iteration % log_interval[1] == 0:
policy_path = self.logger.get_dir() + "/Policy" policy_path = self.logger.get_dir() + "/Policy"
os.makedirs(policy_path, exist_ok=True) os.makedirs(policy_path, exist_ok=True)
path = policy_path + "/epoch0_iter{}.pth".format(iteration) path = policy_path + "/epoch0_iter{}.pth".format(iteration)
th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path) th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path)
callback.on_training_end()
def test_policy(self, num_rollouts: int = 10): def test_policy(self, num_rollouts: int = 10):
max_ep_length = 400 max_ep_length = 400
self.policy.set_training_mode(False) self.policy.set_training_mode(False)
@@ -294,19 +263,16 @@ class YopoAlgorithm:
def train(self, gradient_steps: int, batch_size: int) -> None: def train(self, gradient_steps: int, batch_size: int) -> None:
""" """
Sample the replay buffer and do the updates Imitation learning: sample data from the replay buffer and train the Policy
(gradient descent and update target networks)
""" """
# Switch to train mode (this affects batch norm / dropout) self.policy.set_training_mode(True) # Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True) update_learning_rate(self.policy.optimizer, self.lr_schedule(self._current_progress_remaining))
# Update learning rate according to schedule (TODO in supervised learning)
self._update_learning_rate(self.policy.optimizer)
cost_losses = [] cost_losses = []
score_losses = [] # dy, dz, r, p, vx, vy, vz score_losses = [] # dy, dz, r, p, vx, vy, vz
for _ in range(gradient_steps): for _ in range(gradient_steps):
# Sample replay buffer # Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) replay_data = self.replay_buffer.sample(batch_size)
depth = th.from_numpy(replay_data.depths).to(self.device) depth = th.from_numpy(replay_data.depths).to(self.device)
pos = replay_data.observations[:, 0:3] pos = replay_data.observations[:, 0:3]
vel_acc_b = replay_data.observations[:, 3:9] vel_acc_b = replay_data.observations[:, 3:9]
@@ -354,69 +320,43 @@ class YopoAlgorithm:
def collect_rollouts( def collect_rollouts(
self, self,
env, env,
callback,
train_freq, train_freq,
replay_buffer, replay_buffer,
action_noise=None,
log_interval=None,
) -> RolloutReturn: ) -> RolloutReturn:
self.policy.set_training_mode(False) self.policy.set_training_mode(False)
num_collected_steps, num_collected_episodes = 0, 0 num_collected_steps, num_collected_episodes = 0, 0
assert isinstance(env, VecEnv), "You must pass a VecEnv" assert isinstance(env, VecEnv), "You must pass a VecEnv"
assert train_freq.frequency > 0, "Should at least collect one step or episode." assert train_freq.frequency > 0, "Should at least collect one step or episode."
if env.num_envs > 1: if env.num_envs > 1:
assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training." assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
callback.on_rollout_start()
continue_training = True continue_training = True
"""
1、pred endstate
2、get obs: self._last_obs = env.step(endstate)
3、get depth: self._last_depth = env.getDepthImage()
4、record to buffer and back to 1
"""
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes): while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
# 1. predict the endstate using the latest policy or the pre-trained policy # 1. predict the endstate using the latest policy
sampled_endstate = self._sample_action(action_noise, env.num_envs) sampled_endstate = self._sample_action()
# 2. perform action # 2. perform action and get new observation
new_obs, rewards, dones = env.step(sampled_endstate) new_obs, rewards, dones = env.step(sampled_endstate)
self.num_timesteps += env.num_envs self.num_timesteps += env.num_envs
num_collected_steps += 1 num_collected_steps += 1
# Give access to local variables
callback.update_locals(locals())
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes,
continue_training=False)
# 3. store the last obs, depth, and goal # 3. store the last obs, depth, and goal
# self._update_info_buffer(infos, dones)
self._store_transition(replay_buffer) self._store_transition(replay_buffer)
self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps) self._current_progress_remaining = 1.0 - float(self.num_timesteps) / float(self._total_timesteps)
# 4. update the obs, depth, goal, and reset the goal for the done-env # 4. update the obs, depth, and reset the goal for the done-env
self._last_obs = new_obs self._last_obs = new_obs
self._last_depth = env.getDepthImage() self._last_depth = env.getDepthImage()
for idx, done in enumerate(dones): for idx, done in enumerate(dones):
if done: if done:
# Update stats
num_collected_episodes += 1 num_collected_episodes += 1
self._episode_num += 1
# reset goal for the 'done' env # reset goal for the 'done' env
self._last_goal[idx] = self.get_random_goal(self._last_obs[idx]) self._last_goal[idx] = self.get_random_goal(self._last_obs[idx])
callback.on_rollout_end()
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
def prapare_input_observation(self, obs): def prapare_input_observation(self, obs):
@@ -468,38 +408,23 @@ class YopoAlgorithm:
costs_[i][indices] = 0.0 costs_[i][indices] = 0.0
return costs_ return costs_
def _setup_learn( def _setup_learn(self, total_timesteps, reset_num_timesteps=True):
self, # reset the time info
total_timesteps, self.start_time = time.time()
eval_env=None, if reset_num_timesteps:
callback=None, self.num_timesteps = 0 # steps of sampling
eval_freq=10000, self._n_updates = 0 # steps of policy updating
n_eval_episodes=5, self._total_timesteps = total_timesteps
log_path=None, self._num_timesteps_at_start = self.num_timesteps
reset_num_timesteps=True,
tb_log_name="run",
):
# ----------------- Init the First Observation ----------------- # ----------------- Init the First Observation -----------------
# super()._setup_learn() 中: self._last_obs = self.env.reset() self._last_obs = self.env.reset()
total_timesteps_, callback_ = super()._setup_learn(
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
log_path,
reset_num_timesteps,
tb_log_name,
)
self._last_depth = self.env.getDepthImage() self._last_depth = self.env.getDepthImage()
self._last_goal = np.zeros([self.env.num_envs, 3], dtype=np.float32) self._last_goal = np.zeros([self.env.num_envs, 3], dtype=np.float32)
for i in range(0, self.env.num_envs): for i in range(0, self.env.num_envs):
self._last_goal[i] = self.get_random_goal(self._last_obs[i]) self._last_goal[i] = self.get_random_goal(self._last_obs[i])
self._map_id = np.zeros((self.env.num_envs, 1), dtype=np.float32) self._map_id = np.zeros((self.env.num_envs, 1), dtype=np.float32)
return total_timesteps_, callback_
def _sample_action(self) -> np.ndarray: def _sample_action(self) -> np.ndarray:
""" """
use pretrained model or current model to sample the actions (endstate) use pretrained model or current model to sample the actions (endstate)
@@ -509,7 +434,7 @@ class YopoAlgorithm:
obs = self._last_obs.copy() obs = self._last_obs.copy()
goal_w = self._last_goal.copy() goal_w = self._last_goal.copy()
depth = th.from_numpy(self._last_depth).to(self.device) depth = th.from_numpy(self._last_depth).to(self.device)
# wxyz 四元数的逆[w, -x, -y, -z] # [w, x, y, z] inv() of quat: [w, -x, -y, -z]
quat_bw = -obs[:, 9:13] quat_bw = -obs[:, 9:13]
quat_bw[:, 0] = -quat_bw[:, 0] quat_bw[:, 0] = -quat_bw[:, 0]
vel_acc_norm_b = self.normalize_obs(obs[:, 3:9]) vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])
@@ -546,12 +471,7 @@ class YopoAlgorithm:
depth = deepcopy(self._last_depth) depth = deepcopy(self._last_depth)
map_id = deepcopy(self._map_id) map_id = deepcopy(self._map_id)
replay_buffer.add( replay_buffer.add(obs, goal, depth, map_id)
obs,
goal,
depth,
map_id
)
def get_random_goal(self, uav_state=None): def get_random_goal(self, uav_state=None):
world = self.env.world_box world = self.env.world_box
@@ -564,8 +484,8 @@ class YopoAlgorithm:
random_goal = random_numbers * world_scale + world_center random_goal = random_numbers * world_scale + world_center
# 2. Use goal in front of the UAV (for better imitation learning) # 2. Use goal in front of the UAV (for better imitation learning)
else: else:
q_wb = uav_state[9:] q_wb = uav_state[9:].copy()
p_wb = uav_state[0:3] p_wb = uav_state[0:3].copy()
goal = np.random.randn(3) + np.array([2, 0, 0]) goal = np.random.randn(3) + np.array([2, 0, 0])
goal_dir = goal / np.linalg.norm(goal) goal_dir = goal / np.linalg.norm(goal)
random_goal_b = 50 * goal_dir random_goal_b = 50 * goal_dir