Initial Commit (tested training, testing, and TRT conversion)
This commit is contained in:
0
flightpolicy/__init__.py
Normal file
0
flightpolicy/__init__.py
Normal file
0
flightpolicy/envs/__init__.py
Normal file
0
flightpolicy/envs/__init__.py
Normal file
225
flightpolicy/envs/vec_env_wrapper.py
Normal file
225
flightpolicy/envs/vec_env_wrapper.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
from ruamel.yaml import YAML
|
||||
from typing import Any, List, Type
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices
|
||||
|
||||
|
||||
class FlightEnvVec(VecEnv):
    """
    Vectorized environment facade around the simulator binding ``impl``.

    Every call is forwarded to ``self.wrapper``. Results are written by the
    wrapper into NumPy buffers pre-allocated in ``__init__`` and handed back
    to the caller as copies, so callers never alias the internal buffers.
    """

    def __init__(self, impl):
        """
        Args:
            impl: the simulator binding; it must expose the getters and
                commands used throughout this class (getActDim, step, ...).
        """
        # NOTE(review): VecEnv.__init__ is not called here; `num_envs` is
        # served by the read-only property at the bottom of the class.
        self.wrapper = impl
        # params
        self.action_dim = self.wrapper.getActDim()
        self.observation_dim = self.wrapper.getObsDim()
        self.reward_dim = self.wrapper.getRewDim()
        self.img_width = self.wrapper.getImgWidth()
        self.img_height = self.wrapper.getImgHeight()
        cfg_path = os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml"
        # Use a context manager so the config file handle is always released.
        with open(cfg_path, 'r') as cfg_file:
            cfg = YAML().load(cfg_file)
        scale = 32  # The downsampling factor of backbone
        self.network_height = scale * cfg["vertical_num"]
        self.network_width = scale * cfg["horizon_num"]
        self.world_box = np.zeros([6], dtype=np.float32)
        self.wrapper.getWorldBox(self.world_box)  # xyz_min, xyz_max
        self.reward_names = self.wrapper.getRewardNames()
        self.pretrained = False

        # observations (buffers filled in place by the wrapper)
        self._traj_cost = np.zeros([self.num_envs, 1], dtype=np.float32)  # cost of current pred
        self._traj_grad = np.zeros([self.num_envs, 9], dtype=np.float32)  # grad of current pred x_pva y_pva z_pva
        self._observation = np.zeros([self.num_envs, self.observation_dim], dtype=np.float32)
        self._rgb_img_obs = np.zeros([self.num_envs, self.img_width * self.img_height * 3], dtype=np.uint8)
        self._gray_img_obs = np.zeros([self.num_envs, self.img_width * self.img_height], dtype=np.uint8)
        self._depth_img_obs = np.zeros([self.num_envs, self.img_width * self.img_height], dtype=np.float32)
        self._reward = np.zeros([self.num_envs, self.reward_dim], dtype=np.float32)
        # `np.bool` was removed from NumPy; `np.bool_` yields the same dtype.
        self._done = np.zeros((self.num_envs), dtype=np.bool_)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # observation: [p_wb, v_b, a_b, q_wb] (in Body Frame); action: dp_pred; reward: cost
    def step(self, action):
        """Advance all environments and return copies of (observation, reward, done)."""
        if action.ndim <= 1:
            action = action.reshape((self.num_envs, -1))
        if action.dtype == np.dtype('int'):
            action = action.astype(np.float32)
        self.wrapper.step(
            action,
            self._observation,
            self._reward,
            self._done,
        )

        return (
            self._observation.copy(),
            self._reward.copy(),
            self._done.copy(),
        )

    # observation: [p_wb, v_b, a_b, q_wb] (in Body Frame)
    def reset(self, random=True):
        """Reset the wrapper and return a copy of the new observation.

        ``random`` is accepted for interface compatibility but is not used.
        """
        self._reward = np.zeros([self.num_envs, self.reward_dim], dtype=np.float32)
        self.wrapper.reset(self._observation)
        return self._observation.copy()

    # (in World Frame) goal_w
    def setGoal(self, goal):
        """Forward the goal to the wrapper, reshaped to one row per environment."""
        if goal.ndim <= 1:
            goal = goal.reshape((self.num_envs, -1))
        self.wrapper.setGoal(goal)

    # (in World Frame) pos_wb, vel_w, acc_w, quat_wb
    def setState(self, pos, vel, acc, quad):
        """Stack pos, vel, acc and quaternion column-wise and push them to the wrapper."""
        if pos.ndim <= 1:
            # All four inputs are reshaped together, keyed on the rank of `pos`.
            pos = pos.reshape((self.num_envs, -1))
            quad = quad.reshape((self.num_envs, -1))  # wxyz
            vel = vel.reshape((self.num_envs, -1))
            acc = acc.reshape((self.num_envs, -1))
        state = np.hstack((pos, vel, acc, quad))
        self.wrapper.setState(state)

    # map_id: The ID of the map used in the current training;
    # during data collection or DAgger, map_id=-1 indicates that the latest map is used.
    def setMapID(self, map_id):
        """Forward the map id(s) to the wrapper, reshaped to one row per environment."""
        if map_id.ndim <= 1:
            map_id = map_id.reshape((self.num_envs, -1))
        self.wrapper.setMapID(map_id)

    def getObs(self):
        """Refresh the observation buffer from the wrapper and return a copy."""
        self.wrapper.getObs(self._observation)
        return self._observation.copy()

    # pred_dp: x_pva, y_pva, z_pva (in Body Frame); _traj_grad: x_pva, y_pva, z_pva (in Body Frame)
    def getCostAndGradient(self, pred_dp_in, traj_id):
        """
        Args:
            pred_dp_in: the prediction of dp (x_pva, y_pva, z_pva); either a
                NumPy array or an object providing ``detach().cpu().numpy()``
            traj_id: the id of the trajectory in lattice

        Returns: the cost and gradient of the prediction dp (x_pva, y_pva, z_pva)

        """
        if not isinstance(pred_dp_in, np.ndarray):
            pred_dp = pred_dp_in.detach().cpu().numpy()
        else:
            pred_dp = pred_dp_in

        if pred_dp.ndim <= 1:
            pred_dp = pred_dp.reshape((self.num_envs, -1))
        if traj_id.ndim <= 1:
            traj_id = traj_id.reshape((self.num_envs, -1))
        self.wrapper.getCostAndGradient(pred_dp, traj_id, self._traj_cost, self._traj_grad)
        return self._traj_cost.copy(), self._traj_grad.copy()

    def getRGBImage(self, rgb=False):
        """Return the colour buffer (flat, when ``rgb``) or the gray image reshaped to [n_envs, H, W]."""
        if rgb:
            self.wrapper.getRGBImage(self._rgb_img_obs, True)
            return self._rgb_img_obs.copy()
        else:
            self.wrapper.getRGBImage(self._gray_img_obs, False)
            gray_img = self._gray_img_obs
            gray_img = np.reshape(gray_img, (gray_img.shape[0], self.img_height, self.img_width))
            return gray_img.copy()

    def getDepthImage(self, resize=True):
        """Return the depth image as [n_envs, 1, H, W], optionally resized to the network input size."""
        self.wrapper.getDepthImage(self._depth_img_obs)
        # normalize the depth values from 0-20m to 0-1
        # (the raw buffer is scaled by 1000 first; presumably a unit conversion - TODO confirm)
        depth = 1000 * self._depth_img_obs
        depth = np.minimum(depth, 20)
        depth = depth / 20.0
        depth[np.isnan(depth)] = 1.0
        depth = np.reshape(depth, (depth.shape[0], self.img_height, self.img_width))
        if resize:
            depth_ = np.zeros((depth.shape[0], self.network_height, self.network_width), dtype=np.float32)
            for i in range(depth.shape[0]):
                # cv2.resize takes (width, height)
                depth_[i] = cv2.resize(depth[i], (self.network_width, self.network_height))
            depth = np.expand_dims(depth_, axis=1)
        else:
            depth = np.expand_dims(depth, axis=1)
        return depth.copy()

    def getStereoImage(self):
        """Return the stereo depth as [n_envs, 1, network_height, network_width] with NaN pixels inpainted."""
        # [n_envs, HxW]
        self.wrapper.getStereoImage(self._depth_img_obs)
        depth = self._depth_img_obs
        depth = np.minimum(depth, 20) / 20

        depth_ = np.zeros((depth.shape[0], self.network_height, self.network_width), dtype=np.float32)
        for i in range(depth.shape[0]):
            # Inpaint each environment's own 2-D frame; the image and the mask
            # handed to cv2.inpaint must describe the same single picture.
            frame = np.reshape(depth[i], (self.img_height, self.img_width))
            nan_mask = np.isnan(frame)
            # Zero the NaN pixels before the uint8 cast (casting NaN is undefined);
            # they are overwritten by the inpainting anyway.
            frame_u8 = np.uint8(np.where(nan_mask, 0.0, frame) * 255)
            interpolated_image = cv2.inpaint(frame_u8, np.uint8(nan_mask), 1, cv2.INPAINT_NS)
            interpolated_image = interpolated_image.astype(np.float32) / 255.0
            depth_[i] = cv2.resize(interpolated_image, (self.network_width, self.network_height))
        depth_ = np.expand_dims(depth_, axis=1)

        return depth_.copy()

    def getQuadState(self):
        """Fill and return the quad-state buffer (returned without copying)."""
        # NOTE(review): `self._quadstate` is not allocated anywhere in this class;
        # unless it is assigned externally this raises AttributeError. Its expected
        # shape is defined by the wrapper and is not visible here.
        self.wrapper.getQuadState(self._quadstate)
        return self._quadstate

    def spawnTrees(self):
        self.wrapper.spawnTrees()  # avg_tree_spacing is defined in .cfg

    def savePointcloud(self, ply_idx):
        self.wrapper.savePointcloud(ply_idx)

    def spawnTreesAndSavePointcloud(self, ply_idx=-1, spacing=-1):
        self.wrapper.spawnTreesAndSavePointcloud(ply_idx, spacing)

    def seed(self, seed=0):
        self.wrapper.setSeed(seed)

    def render(self):
        return self.wrapper.render()

    def close(self):
        self.wrapper.close()

    def connectUnity(self):
        self.wrapper.connectUnity()

    def disconnectUnity(self):
        self.wrapper.disconnectUnity()

    def env_method(
        self,
        method_name: str,
        *method_args,
        indices: VecEnvIndices = None,
        **method_kwargs
    ) -> List[Any]:
        """Call instance methods of vectorized environments."""
        # NOTE(review): `_get_target_envs` is not defined in this class; verify
        # that the base class provides it before relying on this method.
        target_envs = self._get_target_envs(indices)
        return [
            getattr(env_i, method_name)(*method_args, **method_kwargs)
            for env_i in target_envs
        ]

    def env_is_wrapped(
        self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None
    ) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        target_envs = self._get_target_envs(indices)
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

    @property
    def num_envs(self):
        """Number of parallel environments, as reported by the wrapper."""
        return self.wrapper.getNumOfEnvs()

    def step_async(self):
        raise RuntimeError("This method is not implemented")

    def step_wait(self):
        raise RuntimeError("This method is not implemented")

    def get_attr(self, attr_name, indices=None):
        raise RuntimeError("This method is not implemented")

    def set_attr(self, attr_name, value, indices=None):
        raise RuntimeError("This method is not implemented")
|
||||
19
flightpolicy/setup.py
Normal file
19
flightpolicy/setup.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
from setuptools import setup, Extension, find_packages
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
# Distribution metadata for the flightpolicy package, passed to setuptools below.
_PACKAGE_METADATA = dict(
    name='flightpolicy',
    version='0.0.1',
    author='Junjie Lu',
    author_email='lqzx1998@tju.edu.cn',
    description='A Learning-based Planner for Autonomous Navigation',
    long_description='',
    packages=['flightpolicy'],
)

setup(**_PACKAGE_METADATA)
|
||||
0
flightpolicy/yopo/__init__.py
Normal file
0
flightpolicy/yopo/__init__.py
Normal file
246
flightpolicy/yopo/buffers.py
Normal file
246
flightpolicy/yopo/buffers.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
The code is from stable_baseline3.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from gym import spaces
|
||||
from typing import Any, Dict, Generator, List, Optional, Union, NamedTuple
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
import torch as th
|
||||
import numpy as np
|
||||
import warnings
|
||||
from stable_baselines3.common.type_aliases import (
|
||||
ReplayBufferSamples,
|
||||
RolloutBufferSamples,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check memory used by replay buffer when possible
|
||||
import psutil
|
||||
except ImportError:
|
||||
psutil = None
|
||||
|
||||
|
||||
class BaseBuffer(ABC):
    """
    Shared bookkeeping for fixed-capacity buffers (rollout or replay).

    :param buffer_size: Max number of element in the buffer
    :param observation_dim: Observation dimension (stored as given)
    :param device: PyTorch device
        to which the values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: int,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super().__init__()
        self.buffer_size, self.observation_dim = buffer_size, observation_dim

        # Write cursor and "wrapped around at least once" flag.
        self.pos, self.full = 0, False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap axes 0 (buffer_size) and 1 (n_envs), then merge them,
        turning [n_steps, n_envs, ...] into [n_steps * n_envs, ...]
        (the trailing feature dimensions keep their order).

        :param arr: array whose two leading axes are merged
        :return: the reshaped array
        """
        dims = arr.shape
        if len(dims) < 3:
            # Give 2-D inputs a trailing feature axis of length 1.
            dims = dims + (1,)
        merged_len = dims[0] * dims[1]
        return arr.swapaxes(0, 1).reshape(merged_len, *dims[2:])

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        return self.buffer_size if self.full else self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer (to be provided by subclasses).
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer
        """
        # Walk the positional arguments in lockstep along the batch axis.
        for transition in zip(*args):
            self.add(*transition)

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        self.pos, self.full = 0, False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: whatever ``_get_samples`` builds for the drawn indices
        """
        if self.full:
            upper_bound = self.buffer_size
        else:
            upper_bound = self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        :param batch_inds: indices along the buffer axis to gather
        :param env: optional normalizing environment
        :return: the gathered samples
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor.
        Note: it copies the data by default

        :param array: the array to convert
        :param copy: Whether to copy or not the data
            (may be useful to avoid changing things by reference)
        :return: the tensor, moved to ``self.device``
        """
        tensor = th.tensor(array) if copy else th.as_tensor(array)
        return tensor.to(self.device)
|
||||
|
||||
|
||||
class ReplayBufferSamples(NamedTuple):
    """One batch drawn from the replay buffer, in field order."""

    # Sampled observations.
    observations: th.Tensor
    # Sampled goals.
    goals: th.Tensor
    # Sampled depth images.
    depths: th.Tensor
    # Map identifier of each sample.
    map_id: th.Tensor
|
||||
|
||||
|
||||
class ReplayBuffer(BaseBuffer):
    """
    Ring buffer holding, per step and per environment:

    self.observations
    self.goals
    self.depths
    self.map_ids

    :param buffer_size: total capacity; divided by ``n_envs`` to get the number of stored steps
    :param observation_dim: shape of one observation, given as an int or a sequence of ints
    :param image_WxH: (width, height) of the stored depth images
    :param device: PyTorch device (forwarded to the base class)
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: selects the alternative index sampling in ``sample``
    :param handle_timeout_termination: when True, ``add`` records the
        ``"TimeLimit.truncated"`` flag of every info dict
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: spaces.Space,
        image_WxH: tuple,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs)

        # Adjust buffer size
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage

        # Accept a plain int as well as a shape sequence: the base class annotates
        # observation_dim as int, which cannot be concatenated to a tuple directly.
        if isinstance(observation_dim, (int, np.integer)):
            obs_shape = (int(observation_dim),)
        else:
            obs_shape = tuple(observation_dim)

        self.observations = np.zeros((self.buffer_size, self.n_envs) + obs_shape, dtype=np.float32)
        self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
        # Images are stored as (channel=1, height, width).
        self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
        self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        goal: np.ndarray,
        depth: np.ndarray,
        map_id: int,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store one step for all environments at the write cursor, then advance it.

        :param obs: observations for this step
        :param goal: goals for this step
        :param depth: depth images for this step
        :param map_id: map identifier(s) for this step
        :param infos: per-environment info dicts (only read when timeouts are handled)
        """
        # TODO: the obs reshaping was removed; check that observations are still stored correctly.

        # Assigning into the pre-allocated arrays copies the data, so the caller's
        # arrays are never referenced afterwards.
        self.observations[self.pos] = np.asarray(obs)
        self.goals[self.pos] = np.asarray(goal)
        self.depths[self.pos] = np.asarray(depth)
        self.map_ids[self.pos] = np.asarray(map_id)

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            # Wrap around: from now on the oldest entries are overwritten.
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: the sampled batch
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transitions is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Gather the stored arrays at ``batch_inds``, drawing a random environment
        index for every sample. ``env`` is not used here, and the gathered NumPy
        arrays are returned without conversion to tensors.
        """
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        data = (
            self.observations[batch_inds, env_indices, :],
            self.goals[batch_inds, env_indices, :],
            self.depths[batch_inds, env_indices, :],
            self.map_ids[batch_inds, env_indices, :],
        )
        return ReplayBufferSamples(*data)
|
||||
106
flightpolicy/yopo/dataloader.py
Normal file
106
flightpolicy/yopo/dataloader.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from ruamel.yaml import YAML
|
||||
import time
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
|
||||
class YopoDataset(Dataset):
    """
    Depth-image dataset read from the sub-folders of the configured dataset path.

    Each sub-folder supplies ``*.tif`` images plus a ``label.npz`` holding
    ``positions`` and ``quaternions``; the sub-folder's rank in the
    (case-insensitive) name ordering is used as the map index of its images.
    Every item additionally carries a randomly drawn velocity, acceleration
    and goal direction.
    """

    def __init__(self):
        super(YopoDataset, self).__init__()
        flightmare_root = os.environ["FLIGHTMARE_PATH"]
        # Context managers make sure the config file handles are released.
        with open(flightmare_root + "/flightlib/configs/traj_opt.yaml", 'r') as cfg_file:
            cfg = YAML().load(cfg_file)
        scale = 32  # downsampling factor of the neural network
        self.height = scale * cfg["vertical_num"]
        self.width = scale * cfg["horizon_num"]
        multiple_ = 0.5 * cfg["vel_max"]
        # The x-direction follows a log-normal distribution,
        # while the yz-direction follows a normal distribution with a mean of 0.
        self.v_max = cfg["vel_max"]
        v_des = multiple_ * cfg["vx_mean_unit"]
        self.vx_lognorm_mean = np.log(self.v_max - v_des)
        self.vx_logmorm_sigma = np.log(np.sqrt(v_des))
        self.v_mean = multiple_ * np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]])
        self.v_var = multiple_ * multiple_ * np.array([cfg["vx_var_unit"], cfg["vy_var_unit"], cfg["vz_var_unit"]])
        self.a_mean = multiple_ * multiple_ * np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]])
        self.a_var = multiple_ * multiple_ * multiple_ * multiple_ * np.array([cfg["ax_var_unit"], cfg["ay_var_unit"], cfg["az_var_unit"]])

        print("Loading dataset, it may take a while...")
        with open(flightmare_root + "/flightlib/configs/vec_env.yaml", 'r') as cfg_file:
            data_cfg = YAML().load(cfg_file)
        data_dir = flightmare_root + data_cfg["env"]["dataset_path"]

        self.img_list = []
        self.map_idx = []
        with os.scandir(data_dir) as entries:
            subfolders = [entry.path for entry in entries if entry.is_dir()]
        subfolders.sort(key=lambda x: os.path.basename(x).lower())

        # Collect the per-folder labels and stack them once at the end instead of
        # re-copying the growing arrays on every iteration.
        position_chunks = [np.empty((0, 3))]
        quaternion_chunks = [np.empty((0, 4))]
        for i, img_dir in enumerate(subfolders):
            file_names = [filename
                          for filename in os.listdir(img_dir)
                          if os.path.splitext(filename)[1] == '.tif']
            file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1]))  # sort by filename
            images = [self._read_depth_image(img_dir + "/" + filename) for filename in file_names]
            self.img_list.extend(images)
            self.map_idx.extend([i] * len(images))

            label_path = img_dir + "/label.npz"
            # np.load keeps the archive open until closed, hence the `with`.
            with np.load(label_path) as labels:
                position_chunks.append(labels["positions"])
                quaternion_chunks.append(labels["quaternions"])

        self.positions = np.vstack(position_chunks)
        self.quaternions = np.vstack(quaternion_chunks)

        print("Dataset loaded!")

    @staticmethod
    def _read_depth_image(path):
        """Read one image unchanged (flag -1) as float32, failing loudly if it cannot be decoded."""
        image = cv2.imread(path, -1)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Failed to read depth image: " + path)
        return image.astype(np.float32)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, item):
        """
        Returns:
            (image, position, quaternion, random_obs, map_idx) where the image is
            [1, height, width] and random_obs stacks velocity, acceleration and
            the unit goal direction (9 values).
        """
        # The resized / expanded image is written back so the work is done once per item.
        if self.img_list[item].shape[-2] != self.height or self.img_list[item].shape[-1] != self.width:
            # cv2.resize takes (width, height) while NumPy shapes are (height, width)
            self.img_list[item] = cv2.resize(self.img_list[item], (self.width, self.height))

        if len(self.img_list[item].shape) == 2:
            self.img_list[item] = np.expand_dims(self.img_list[item], axis=0)

        vel, acc = self._get_random_state()

        # generate random goal in front of the quadrotor.
        q_wxyz = self.quaternions[item, :]  # q: wxyz
        R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])  # scipy expects xyzw
        euler_angles = R_WB.as_euler('ZYX', degrees=False)  # [yaw(z) pitch(y) roll(x)]
        # Same attitude with the yaw set to zero.
        R_wB = R.from_euler('ZYX', [0, euler_angles[1], euler_angles[2]], degrees=False)
        goal_w = np.random.randn(3) + np.array([2, 0, 0])
        goal_b = R_wB.inv().apply(goal_w)

        goal_dist = np.linalg.norm(goal_b)
        goal_dir = goal_b / goal_dist
        random_obs = np.hstack((vel, acc, goal_dir))

        return (self.img_list[item], self.positions[item, :], self.quaternions[item, :], random_obs,
                self.map_idx[item])  # in body frame, vel_acc no-normalization

    def _get_random_state(self):
        """Draw a random (velocity, acceleration) pair; vx is re-drawn from a mirrored log-normal."""
        vel = self.v_mean + np.sqrt(self.v_var) * np.random.randn(3)
        acc = self.a_mean + np.sqrt(self.a_var) * np.random.randn(3)

        right_skewed_vx = -1
        # Re-draw until the mirrored sample is non-negative.
        while right_skewed_vx < 0:
            right_skewed_vx = np.random.lognormal(mean=self.vx_lognorm_mean, sigma=self.vx_logmorm_sigma, size=None)
            right_skewed_vx = -right_skewed_vx + self.v_max + 0.2  # +0.2 to ensure v_max can be sampled
        vel[0] = right_skewed_vx
        # distribution of vx is visualized in docs/distribution_of_sampled_velocity.png (v_max=6)
        return vel, acc
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke benchmark: time one full pass over the dataset through a DataLoader.
    data_loader = DataLoader(YopoDataset(), batch_size=32, shuffle=True, num_workers=4)

    start = time.time()
    for _epoch in range(1):
        # Unpacking checks that every batch carries the five expected fields;
        # `map_id` avoids shadowing the builtin `id`.
        for depth, pos, quat, obs, map_id in data_loader:
            pass
    end = time.time()

    # The printed label means "total elapsed time".
    print("总耗时:", end - start)
|
||||
137
flightpolicy/yopo/primitive_utils.py
Normal file
137
flightpolicy/yopo/primitive_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
|
||||
class LatticeParam():
    """Lattice settings read from the ``cfg`` mapping, plus derived limits."""

    def __init__(self, cfg):
        """Copy the lattice entries of ``cfg`` onto the instance and print a short summary."""
        self.vel_max = cfg["vel_max"]
        # Duration of one segment: twice the range, flown at maximum speed.
        segment_time = 2 * cfg["radio_range"] / self.vel_max

        # Sampling counts along each lattice dimension.
        self.horizon_num = cfg["horizon_num"]
        self.vertical_num = cfg["vertical_num"]
        self.radio_num = cfg["radio_num"]
        self.vel_num = cfg["vel_num"]

        # The camera fov is scaled by (n - 1) / n in both directions.
        n_h, n_v = self.horizon_num, self.vertical_num
        self.horizon_fov = cfg["horizon_camera_fov"] * (n_h - 1) / n_h
        self.vertical_fov = cfg["vertical_camera_fov"] * (n_v - 1) / n_v

        self.horizon_anchor_fov = cfg["horizon_anchor_fov"]
        self.vertical_anchor_fov = cfg["vertical_anchor_fov"]
        self.radio_range = cfg["radio_range"]
        self.vel_fov = cfg["vel_fov"]
        self.vel_prefile = cfg["vel_prefile"]
        self.acc_max = self.vel_max / segment_time

        border = "---------------------"
        summary_rows = (
            ("| max speed = ", self.vel_max),
            ("| traj time = ", segment_time),
            ("| max radio = ", 2 * self.radio_range),
        )
        print(border)
        for label, value in summary_rows:
            print(label, round(value, 1), " |")
        print(border)
|
||||
|
||||
|
||||
# ID in images:
# [8, 7, 6,
#  5, 4, 3,
#  2, 1, 0]
class LatticePrimitive():
    """Pre-computed end position, end velocity, angles and rotation of every lattice primitive."""

    def __init__(self, LatticeParam):
        """Enumerate all primitives described by the given lattice parameter object."""
        self.lattice_param = LatticeParam
        param = self.lattice_param

        def angular_step(fov_deg, count):
            # Spacing in radians between neighbouring samples; zero for a single sample.
            if count == 1:
                return 0
            return (fov_deg / 180.0 * np.pi) / (count - 1)

        direction_diff = angular_step(param.horizon_fov, param.horizon_num)
        altitude_diff = angular_step(param.vertical_fov, param.vertical_num)
        radio_diff = param.radio_range / param.radio_num
        vel_dir_diff = angular_step(param.vel_fov, param.vel_num)

        positions = []
        velocities = []
        angles = []
        self.lattice_Rbp_list = []

        # Primitives: Bottom to Top, Right to Left
        # We retain the code of sampling primitives with different velocity directions and length,
        # hope to predict multiple outputs in each grid like YOLO, but it does not work well.
        for h in range(0, param.radio_num):
            search_radio = (h + 1) * radio_diff
            for i in range(0, param.vertical_num):
                beta = -altitude_diff * (param.vertical_num - 1) / 2 + i * altitude_diff
                for j in range(0, param.horizon_num):
                    alpha = -direction_diff * (param.horizon_num - 1) / 2 + j * direction_diff
                    for k in range(0, param.vel_num):
                        gamma = -vel_dir_diff * (param.vel_num - 1) / 2 + k * vel_dir_diff

                        # End point on the sphere of radius search_radio.
                        positions.append([np.cos(beta) * np.cos(alpha) * search_radio,
                                          np.cos(beta) * np.sin(alpha) * search_radio,
                                          np.sin(beta) * search_radio])
                        # End velocity lies in the horizontal plane.
                        heading = alpha + gamma
                        velocities.append([np.cos(heading) * param.vel_prefile,
                                           np.sin(heading) * param.vel_prefile,
                                           0.0])
                        angles.append([alpha, beta])
                        # inner rotation: yaw-pitch-roll
                        primitive_rot = R.from_euler('ZYX', [alpha, -beta, 0.0], degrees=False)
                        self.lattice_Rbp_list.append(primitive_rot.as_matrix().astype(np.float32))

        self.lattice_pos_node = np.array(positions)
        self.lattice_vel_node = np.array(velocities)
        self.lattice_angle_node = np.array(angles)

        # Half of the anchor fov, in radians.
        self.yaw_diff = 0.5 * param.horizon_anchor_fov / 180.0 * np.pi
        self.pitch_diff = 0.5 * param.vertical_anchor_fov / 180.0 * np.pi

    def getStateLattice(self, id):
        """Return (position, velocity) of primitive ``id``."""
        pos_row = self.lattice_pos_node[id, :]
        vel_row = self.lattice_vel_node[id, :]
        return pos_row, vel_row

    # yaw, pitch
    def getAngleLattice(self, id):
        """Return (yaw, pitch) of primitive ``id``."""
        yaw = self.lattice_angle_node[id, 0]
        pitch = self.lattice_angle_node[id, 1]
        return yaw, pitch

    def getRotation(self, id):
        """Return the float32 rotation matrix stored for primitive ``id``."""
        return self.lattice_Rbp_list[id]
|
||||
|
||||
|
||||
"""
|
||||
From body to world
|
||||
p_w = Rwb * p_b + t_w
|
||||
"""
|
||||
|
||||
def rotate(q_wb, pos_b):  # quat: wxyz
    """
    Rotate body-frame vectors into the world frame: p_w = Rwb * p_b.

    Args:
        q_wb: quaternion(s) in wxyz order, shape (4,) or (N, 4).
        pos_b: vector(s) to rotate. With a single quaternion the rotation matrix
            is multiplied from the left (``R @ pos_b``); with N quaternions
            row i of ``pos_b`` is rotated by quaternion i.

    Returns:
        The rotated vector(s). Floating-point inputs keep their dtype; any other
        dtype is returned as float64 so the result is never truncated to integers.
    """
    q_wb = np.asarray(q_wb)
    pos_b = np.asarray(pos_b)
    out_dtype = pos_b.dtype if np.issubdtype(pos_b.dtype, np.floating) else np.float64
    # scipy expects xyzw: move the leading w component to the end.
    q_xyzw = np.roll(q_wb, -1, axis=-1)
    if q_wb.ndim == 1:
        pos_w = np.dot(R.from_quat(q_xyzw).as_matrix(), pos_b)
    elif q_wb.shape[0] == 0:
        # Nothing to rotate; avoid building an empty Rotation.
        pos_w = np.zeros(pos_b.shape)
    else:
        # One vectorized call instead of a Python loop over the rows.
        pos_w = R.from_quat(q_xyzw).apply(pos_b)
    return pos_w.astype(out_dtype, copy=False)


def transform(q_wb, tw, pos_b):
    """Map body-frame points to the world frame: p_w = Rwb * p_b + t_w."""
    pos_w = rotate(q_wb, pos_b)
    return pos_w + tw
|
||||
|
||||
|
||||
"""
|
||||
From world to body
|
||||
p_b = Rbw * (p_w - t_w)
|
||||
"""
|
||||
|
||||
def rotate_inv(q_wb, pos_w):  # quat: wxyz
    """
    Rotate world-frame vectors into the body frame: p_b = Rbw * p_w.

    Args:
        q_wb: body-to-world quaternion(s) in wxyz order, shape (4,) or (N, 4);
            the conjugate is applied.
        pos_w: vector(s) to rotate. With a single quaternion the rotation matrix
            is multiplied from the left (``R @ pos_w``); with N quaternions
            row i of ``pos_w`` is rotated by the conjugate of quaternion i.

    Returns:
        The rotated vector(s). Floating-point inputs keep their dtype; any other
        dtype is returned as float64 so the result is never truncated to integers.
    """
    q_wb = np.asarray(q_wb)
    pos_w = np.asarray(pos_w)
    out_dtype = pos_w.dtype if np.issubdtype(pos_w.dtype, np.floating) else np.float64
    # Conjugate quaternion in scipy's xyzw order: (-x, -y, -z, w).
    q_bw_xyzw = np.concatenate((-q_wb[..., 1:], q_wb[..., :1]), axis=-1)
    if q_wb.ndim == 1:
        pos_b = np.dot(R.from_quat(q_bw_xyzw).as_matrix(), pos_w)
    elif q_wb.shape[0] == 0:
        # Nothing to rotate; avoid building an empty Rotation.
        pos_b = np.zeros(pos_w.shape)
    else:
        # One vectorized call instead of a Python loop over the rows.
        pos_b = R.from_quat(q_bw_xyzw).apply(pos_w)
    return pos_b.astype(out_dtype, copy=False)


def transform_inv(q_wb, tw, pos_w):
    """Map world-frame points to the body frame: p_b = Rbw * (p_w - t_w)."""
    pos_b = rotate_inv(q_wb, pos_w - tw)
    return pos_b
|
||||
392
flightpolicy/yopo/resnet.py
Normal file
392
flightpolicy/yopo/resnet.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
this code is from torchvision.
|
||||
"""
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
|
||||
|
||||
# Names exported by `from <module> import *`.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]


# Download location of the weights file for each architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-f37072fd.pth',
    resnet34='https://download.pytorch.org/models/resnet34-b627a593.pth',
    resnet50='https://download.pytorch.org/models/resnet50-0676ba61.pth',
    resnet101='https://download.pytorch.org/models/resnet101-63fe2227.pth',
    resnet152='https://download.pytorch.org/models/resnet152-394f9c45.pth',
    resnext50_32x4d='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    resnext101_32x8d='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    wide_resnet50_2='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    wide_resnet101_2='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
)
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and an optional shortcut projection."""

    # Ratio between the block's output channels and `planes`.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        :param inplanes: input channels
        :param planes: output channels of both convolutions
        :param stride: stride of the first convolution
        :param downsample: optional module applied to the input before the addition
        :param groups: must be 1
        :param base_width: must be 64
        :param dilation: must not exceed 1
        :param norm_layer: normalization layer factory; defaults to nn.BatchNorm2d
        :raises ValueError: if groups != 1 or base_width != 64
        :raises NotImplementedError: if dilation > 1
        """
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run conv-norm-relu-conv-norm, add the (optionally projected) input, then apply ReLU."""
        # Main branch first, shortcut second.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block.

    The stride is applied at the 3x3 convolution (``conv2``), the variant
    known as ResNet V1.5, whereas the original paper
    (https://arxiv.org/abs/1512.03385) strides the first 1x1 convolution.
    See https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    # Output channels = planes * expansion.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        # Channel count of the inner (bottleneck) 3x3 convolution.
        inner = int(planes * (base_width / 64.)) * groups
        outer = planes * self.expansion

        # When stride != 1 both conv2 and the optional ``downsample`` module
        # reduce the spatial resolution.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = make_norm(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = make_norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the three conv stages, add the shortcut, then a final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the input only if a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """ResNet trunk: stem convolution, four residual stages, pooling and ``fc``.

    Compared with a stock ResNet forward pass, this one (see
    ``_forward_impl``) does not apply the max-pool, builds ``layer1`` with
    stride 2, and does not flatten before ``fc``.
    NOTE(review): presumably callers replace ``avgpool``/``fc`` to keep a
    spatial feature grid -- confirm against the code that builds the backbone.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """Create all layers and initialise their weights.

        Args:
            block: Residual block class used in every stage.
            layers: Number of blocks in each of the four stages.
            num_classes: Output features of the final linear layer.
            zero_init_residual: Zero the last norm weight of every block.
            groups: Convolution groups forwarded to the blocks.
            width_per_group: Base width forwarded to the blocks.
            replace_stride_with_dilation: Three flags (stages 2-4); True
                replaces that stage's stride with dilation.
            norm_layer: Normalisation factory, ``nn.BatchNorm2d`` by default.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` is not length 3.
        """
        super().__init__()
        norm_factory = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_factory

        self.inplanes = 64
        self.dilation = 1
        dilate_flags = replace_stride_with_dilation
        if dilate_flags is None:
            # One flag per strided stage; False keeps the stride-2 convolution.
            dilate_flags = [False, False, False]
        if len(dilate_flags) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(dilate_flags))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_factory(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Registered but not applied in _forward_impl.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage, including the first, uses stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=dilate_flags[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=dilate_flags[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=dilate_flags[2])

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero the last norm layer of each residual branch so every block
        # starts out as an identity mapping; reported to improve accuracy by
        # 0.2~0.3% in https://arxiv.org/abs/1706.02677. This must run after
        # the loop above, which would otherwise reset the weights to 1.
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_norm = module.bn3
                elif isinstance(module, BasicBlock):
                    last_norm = module.bn2
                else:
                    continue
                nn.init.constant_(last_norm.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Assemble one stage of ``blocks`` residual blocks.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        True) so the next stage starts from this stage's output.
        """
        norm_factory = self._norm_layer
        first_dilation = self.dilation
        if dilate:
            # Trade the spatial stride for a larger dilation.
            self.dilation *= stride
            stride = 1

        out_channels = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            # The skip path needs a projection to match shape and channels.
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                norm_factory(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, self.groups,
                       self.base_width, first_dilation, norm_factory)]
        self.inplanes = out_channels
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_factory)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Stem -> four stages -> avgpool -> fc, with no max-pool or flatten."""
        # See note [TorchScript super()]
        x = self.relu(self.bn1(self.conv1(x)))
        # self.maxpool is not applied here.

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        # No torch.flatten here: fc receives the pooled tensor as-is.
        return self.fc(x)

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Construct a ``ResNet`` and optionally load downloaded weights.

    Args:
        arch: Key into ``model_urls`` naming the checkpoint to fetch.
        block: Residual block class.
        layers: Blocks per stage.
        pretrained: When True, download and load the ``arch`` checkpoint.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra ``ResNet`` constructor arguments.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNet-18 (``BasicBlock``, stages of 2-2-2-2 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNet-34 (``BasicBlock``, stages of 3-4-6-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNet-50 (``Bottleneck``, stages of 3-4-6-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNet-101 (``Bottleneck``, stages of 3-4-23-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNet-152 (``Bottleneck``, stages of 3-8-36-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNeXt-50 32x4d (``Bottleneck``, 32 groups of width 4).

    Architecture from `"Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>`_. Any ``groups`` or
    ``width_per_group`` passed by the caller is overwritten.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build ResNeXt-101 32x8d (``Bottleneck``, 32 groups of width 8).

    Architecture from `"Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>`_. Any ``groups`` or
    ``width_per_group`` passed by the caller is overwritten.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build Wide ResNet-50-2 (``Bottleneck``, ``width_per_group`` of 128).

    Architecture from `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same layout as ResNet-50, but the inner bottleneck channel count is
    doubled in every block while the outer 1x1 convolutions keep their
    width: e.g. the last block is 2048-1024-2048 instead of 2048-512-2048.
    Any ``width_per_group`` passed by the caller is overwritten.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build Wide ResNet-101-2 (``Bottleneck``, ``width_per_group`` of 128).

    Architecture from `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same layout as ResNet-101, but the inner bottleneck channel count is
    doubled in every block while the outer 1x1 convolutions keep their
    width (for the 50-layer pair: 2048-1024-2048 versus 2048-512-2048).
    Any ``width_per_group`` passed by the caller is overwritten.

    Args:
        pretrained (bool): If True, load the weights listed for this
            architecture in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
||||
610
flightpolicy/yopo/yopo_algorithm.py
Normal file
610
flightpolicy/yopo/yopo_algorithm.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Training Strategy
|
||||
supervised learning, imitation learning, testing, rollout
|
||||
"""
|
||||
import time
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch.nn import functional as F
|
||||
from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq, TrainFrequencyUnit
|
||||
from stable_baselines3.common.utils import should_collect_more_steps, get_schedule_fn, configure_logger
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.utils import get_device
|
||||
|
||||
# -----------
|
||||
from flightpolicy.yopo.yopo_policy import YopoPolicy
|
||||
from flightpolicy.yopo.dataloader import YopoDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from flightpolicy.yopo.primitive_utils import transform, rotate, transform_inv, rotate_inv
|
||||
from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive
|
||||
from flightpolicy.yopo.buffers import ReplayBuffer
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
class YopoAlgorithm:
|
||||
def __init__(
    self,
    env=None,
    learning_rate=0.001,
    is_imitation=False,
    buffer_size=1_000_000,
    learning_starts=100,
    batch_size=256,
    unselect=0.0,
    loss_weight=None,
    train_freq=(1, "step"),
    change_env_freq=-1,
    gradient_steps=1,
    policy_kwargs=None,
    tensorboard_log=None,
    verbose=0,
    max_grad_norm=10,
):
    """Store the hyper-parameters, load the lattice config and build the model.

    Args:
        env: Vectorised flight environment; must expose ``observation_dim``,
            ``action_dim`` and ``num_envs`` (``None`` is not usable).
        learning_rate: Value handed to ``get_schedule_fn`` in ``_setup_model``.
        is_imitation: Enables the replay buffer / imitation-learning path.
        buffer_size: Replay buffer capacity (imitation only).
        learning_starts: Timesteps collected before imitation training starts.
        batch_size: Batch size for supervised and imitation updates.
        unselect: Fraction of highest-cost entries per row zeroed by
            ``cost_filter``; values outside (0, 1) disable the filter.
        loss_weight: Sequence ``[cost_weight, score_weight]``. ``None``
            (default) is stored as an empty list, matching the old ``[]``
            default without sharing one list between instances.
        train_freq: Rollout length specification (imitation only).
        change_env_freq: Iterations between map regenerations; <= 0 disables.
        gradient_steps: Updates per rollout; negative means "as many as
            rollout steps".
        policy_kwargs: Extra keyword arguments for ``YopoPolicy``.
        tensorboard_log: Log directory passed to ``configure_logger``.
        verbose: Verbosity passed to ``configure_logger``.
        max_grad_norm: Gradient-norm clipping threshold.

    Raises:
        KeyError: If the ``FLIGHTMARE_PATH`` environment variable is unset.
    """
    # env
    self.observation_dim = env.observation_dim
    self.action_dim = env.action_dim
    self.n_envs = env.num_envs
    self.env = env
    # training
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.max_grad_norm = max_grad_norm
    self.unselect = unselect
    # A fresh list per instance avoids the shared-mutable-default pitfall.
    self.loss_weight = [] if loss_weight is None else loss_weight
    self.device = get_device('auto')
    self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
    # imitation learning
    self.is_imitation = is_imitation
    self.buffer_size = buffer_size
    self.train_freq = train_freq
    self.change_env_freq = change_env_freq
    self.learning_starts = learning_starts
    self.gradient_steps = gradient_steps
    self.freq_reset = False
    self.replay_buffer = None
    # logger
    self.verbose = verbose
    self.tensorboard_log = tensorboard_log
    self.logger = configure_logger(self.verbose, self.tensorboard_log, "YOPO")
    # trajectory: close the config file deterministically instead of
    # leaking the handle returned by a bare open().
    with open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r') as cfg_file:
        cfg = YAML().load(cfg_file)
    self.lattice_space = LatticeParam(cfg)
    self.lattice_primitive = LatticePrimitive(self.lattice_space)

    self._setup_model()
|
||||
|
||||
def _setup_model(self):
    """Create the LR schedule, the optional replay buffer and the policy network."""
    self.lr_schedule = get_schedule_fn(self.learning_rate)

    # The buffer (pos, quat, vel, acc, depth) is only needed for imitation
    # learning, and only if one has not been assigned already.
    if self.is_imitation and self.replay_buffer is None:
        depth_size = (self.env.network_width, self.env.network_height)
        self.replay_buffer = ReplayBuffer(
            self.buffer_size,
            self.observation_dim,
            depth_size,
            device=self.device,
            n_envs=self.n_envs,
        )

    print("Loading Network...")

    policy_args = dict(
        observation_dim=self.observation_dim,
        action_dim=self.action_dim,
        lattice_space=self.lattice_space,
        lattice_primitive=self.lattice_primitive,
        lr_schedule=self.lr_schedule,
        train_env=self.env,
        device=self.device,
    )
    self.policy = YopoPolicy(**policy_args, **self.policy_kwargs)
    self.policy = self.policy.to(self.device)

    print("Network Loaded!")

    if self.is_imitation:
        self._convert_train_freq()
|
||||
|
||||
def supervised_learning(self, epoch, log_interval):
    """Train the policy on ``YopoDataset`` for ``epoch`` passes.

    Args:
        epoch: Number of passes over the dataset.
        log_interval: ``None`` to disable logging and checkpoints, otherwise
            an indexable pair ``(log_every, save_every)`` counted in updates.
    """
    self.policy.set_training_mode(True)
    loader = DataLoader(YopoDataset(), batch_size=self.batch_size, shuffle=True, num_workers=0)

    n_updates = 0
    tic = time.time()
    for epoch_idx in range(epoch):
        cost_log = []   # trajectory cost (quality of the prediction)
        score_log = []  # error of the predicted score
        for batch_idx, (depth, pos, quat, obs_b, map_id) in enumerate(loader):  # obs: body frame
            # Skip an incomplete batch (the original notes that the batch
            # size equals the number of envs).
            if depth.shape[0] != self.batch_size:
                continue
            n_updates += 1
            depth = depth.to(self.device)
            obs_b = obs_b.numpy()
            pos_np = pos.numpy()
            quat_np = quat.numpy()

            # Rebuild the world-frame state in the simulator.
            goal_w = transform(quat_np, pos_np, 10 * obs_b[:, 6:9])  # Rwb * g_b + t_wb
            vel_w = rotate(quat_np, obs_b[:, 0:3])
            acc_w = rotate(quat_np, obs_b[:, 3:6])
            self.env.setState(pos_np, vel_w, acc_w, quat_np)
            self.env.setGoal(goal_w)
            self.env.setMapID(map_id.numpy())

            obs_b[:, 0:6] = self.normalize_obs(obs_b[:, 0:6])
            net_obs = self.prapare_input_observation(obs_b).to(self.device)
            endstate_score_predictions, cost_labels = self.policy.inference(depth, net_obs)
            # Snapshot the labels and their mean before cost_filter, which
            # may zero entries of cost_labels in place.
            score_labels = cost_labels.clone().detach()
            mean_cost = th.mean(cost_labels)
            filtered_cost = self.cost_filter(cost_labels)

            cost_loss = th.mean(filtered_cost)
            score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
            loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
            cost_log.append(self.loss_weight[0] * mean_cost.item())
            score_log.append(self.loss_weight[1] * score_loss.item())

            # Optimisation step with gradient-norm clipping.
            self.policy.optimizer.zero_grad()
            loss.backward()
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

            if log_interval is None:
                continue

            if n_updates % log_interval[0] == 0:
                self.logger.record("time/epoch", epoch_idx, exclude="tensorboard")
                self.logger.record("time/steps", n_updates, exclude="tensorboard")
                self.logger.record("time/batch_fps", log_interval[0] / (time.time() - tic),
                                   exclude="tensorboard")
                self.logger.record("train/trajectory_cost", np.mean(cost_log))
                self.logger.record("train/score_loss", np.mean(score_log))
                self.logger.dump(step=n_updates)
                cost_log = []
                score_log = []
                tic = time.time()

            if n_updates % log_interval[1] == 0:
                policy_path = self.logger.get_dir() + "/Policy"
                os.makedirs(policy_path, exist_ok=True)
                path = policy_path + "/epoch{}_iter{}.pth".format(epoch_idx, batch_idx)
                checkpoint = {"state_dict": self.policy.state_dict(),
                              "data": self.policy.get_constructor_parameters()}
                th.save(checkpoint, path)
|
||||
|
||||
# Imitation learning: deprecated (kept rather than deleted, for possible later use)
# 0. reset_state, get_depth, reset_goal
# 1. roll out a number of steps (env_num * 200)
# 2. train for a number of steps (batch_size = env_num; 200 updates = 1 episode)
# 3. reset_state, get_depth, reset_goal
def imitation_learning(
    self,
    total_timesteps,
    callback=None,
    log_interval=4,
    eval_env=None,
    eval_freq=-1,
    n_eval_episodes=5,
    tb_log_name="YOPO",
    eval_log_path=None,
    reset_num_timesteps=True,
):
    """Alternate rollout collection and training until ``total_timesteps``.

    Args:
        total_timesteps: Environment steps to collect in total.
        callback: Callback forwarded to ``_setup_learn``.
        log_interval: ``None`` to disable logging and checkpoints, a pair
            ``(log_every, save_every)`` counted in iterations, or a single
            number used for both. The scalar form is accepted because the
            default (4) is a plain int, which previously raised
            ``TypeError`` when subscripted.
        eval_env, eval_freq, n_eval_episodes, tb_log_name, eval_log_path,
        reset_num_timesteps: Forwarded unchanged to ``_setup_learn``.
    """
    # Normalise a scalar interval to the (log_every, save_every) pair that
    # the logging code below indexes.
    if log_interval is not None and not hasattr(log_interval, "__getitem__"):
        log_interval = (log_interval, log_interval)

    # 0. Initialise the first observation.
    total_timesteps, callback = self._setup_learn(
        total_timesteps,
        eval_env,
        callback,
        eval_freq,
        n_eval_episodes,
        eval_log_path,
        reset_num_timesteps,
        tb_log_name,
    )
    self.pretrained = self.env.pretrained
    callback.on_training_start(locals(), globals())

    while self.num_timesteps < total_timesteps:
        # 1. Collect data.
        rollout = self.collect_rollouts(
            self.env,
            train_freq=self.train_freq,
            action_noise=self.action_noise,
            callback=callback,
            replay_buffer=self.replay_buffer,
            log_interval=log_interval,
        )

        if rollout.continue_training is False:
            break

        # 2. Train the model.
        if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
            # If no `gradient_steps` is specified,
            # do as many gradients steps as steps performed during the rollout
            gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
            # Special case when the user passes `gradient_steps=0`
            if gradient_steps > 0:
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
                self.reset_state()

        iteration = int(self.num_timesteps / (self.train_freq.frequency * self.env.num_envs))

        # 3. Regenerate the environment.
        if self.change_env_freq > 0 and iteration % self.change_env_freq == 0:
            self.env.spawnTreesAndSavePointcloud()
            self._map_id = self._map_id + 1
            self.reset_state()

        if log_interval is None:
            continue

        # 4. Print the log to the terminal.
        if iteration % log_interval[0] == 0:
            self._dump_logs()

        # 5. Save a checkpoint.
        if iteration % log_interval[1] == 0:
            policy_path = self.logger.get_dir() + "/Policy"
            os.makedirs(policy_path, exist_ok=True)
            path = policy_path + "/epoch0_iter{}.pth".format(iteration)
            th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path)

    callback.on_training_end()
|
||||
|
||||
def test_policy(self, num_rollouts: int = 10):
    """Fly ``num_rollouts`` episodes with the current policy and print the mean cost.

    Each episode starts at x = -20 and targets x = 20, both with a random y
    and with z = 2, and lasts at most 400 steps. The depth image is shown
    in an OpenCV window while flying.
    """
    max_ep_length = 400
    self.policy.set_training_mode(False)

    for n_roll in range(num_rollouts):
        # The reset observation is discarded; the state is set explicitly below.
        self.env.reset()
        done, ep_len = False, 0
        costs = []
        # Randomly initialize the position and goal on the map.
        random_y_goal = 20 * random.uniform(-1, 1) + 20
        random_y = 20 * random.uniform(-1, 1) + 20
        goal_w = np.array([[20, random_y_goal, 2]])
        # Layout: position (3), velocity (3), acceleration (3), quaternion wxyz (4).
        obs = np.array([[-20, random_y, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
        self.env.setGoal(goal_w)
        self.env.setState(np.array([[-20, random_y, 2]]), np.array([[0, 0, 0]]),
                          np.array([[0, 0, 0]]), np.array([[1, 0, 0, 0]]))
        self.env.render()

        while not (done or (ep_len >= max_ep_length)):
            depth = self.env.getDepthImage()
            cv2.imshow("depth", cv2.resize(depth[0][0], (320, 180)))
            cv2.waitKey(10)
            depth = th.from_numpy(depth).to(self.device)

            # Express the observation in the body frame. The inverse of a
            # quaternion [w, x, y, z] is its conjugate [w, -x, -y, -z].
            quat_bw = -obs[:, 9:13]
            quat_bw[:, 0] = -quat_bw[:, 0]
            to_goal = goal_w - obs[:, 0:3]
            goal_dir_b = rotate(quat_bw, to_goal / np.linalg.norm(to_goal))
            vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])
            obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))

            net_obs = self.prapare_input_observation(obs_norm_b).to(self.device)

            endstate_pred, _score_pred = self.policy.predict(depth, net_obs)
            # obs: p_wb, v_b, a_b, q_wb; endstate_pred: pva in body frame
            obs, rew, done = self.env.step(endstate_pred.cpu().numpy())

            costs.append(rew)
            ep_len += 1
        print("round ", n_roll, ", total steps:", len(costs), ", avg cost:", sum(costs) / len(costs))
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int) -> None:
    """Run ``gradient_steps`` updates on batches sampled from the replay buffer.

    Each step restores the sampled states in the simulator, evaluates the
    policy, and minimises the weighted sum of the filtered trajectory cost
    and the score-regression loss. The mean losses are written to the logger.
    """
    # Train mode affects batch norm / dropout.
    self.policy.set_training_mode(True)
    # Update learning rate according to schedule (TODO in supervised learning)
    self._update_learning_rate(self.policy.optimizer)

    cost_log, score_log = [], []
    for _ in range(gradient_steps):
        batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
        depth = th.from_numpy(batch.depths).to(self.device)
        observations = batch.observations
        pos = observations[:, 0:3]
        vel_acc_b = observations[:, 3:9]
        quat_wb = observations[:, 9:13]
        goal_w = batch.goals
        map_id = batch.map_id

        # Unit vector towards the goal, world frame then body frame.
        to_goal = goal_w - pos
        goal_dir_w = to_goal / np.linalg.norm(goal_w - pos, axis=1)[:, np.newaxis]
        goal_dir_b = rotate_inv(quat_wb, goal_dir_w)
        # Restore the sampled states in the simulator.
        vel_w = rotate(quat_wb, vel_acc_b[:, 0:3])
        acc_w = rotate(quat_wb, vel_acc_b[:, 3:6])
        self.env.setState(pos, vel_w, acc_w, quat_wb)
        self.env.setGoal(goal_w)
        self.env.setMapID(map_id)

        obs_norm_b = np.hstack((self.normalize_obs(vel_acc_b), goal_dir_b))
        net_obs = self.prapare_input_observation(obs_norm_b).to(self.device)
        endstate_score_predictions, cost_labels = self.policy.inference(depth, net_obs)
        # Snapshot the labels and their mean before cost_filter, which may
        # zero entries of cost_labels in place.
        score_labels = cost_labels.clone().detach()
        mean_cost = th.mean(cost_labels)
        filtered_cost = self.cost_filter(cost_labels)

        cost_loss = th.mean(filtered_cost)
        score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
        loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
        cost_log.append(self.loss_weight[0] * mean_cost.item())
        score_log.append(self.loss_weight[1] * score_loss.item())

        # Optimisation step with gradient-norm clipping.
        self.policy.optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy.optimizer.step()

    # Book-keeping.
    self._n_updates += gradient_steps
    self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    self.logger.record("train/trajectory_cost", np.mean(cost_log))
    self.logger.record("train/score_loss", np.mean(score_log))
|
||||
|
||||
def collect_rollouts(
    self,
    env,
    callback,
    train_freq,
    replay_buffer,
    action_noise=None,
    log_interval=None,
) -> RolloutReturn:
    """Step the environments with the current policy and store the transitions.

    Args:
        env: Vectorised environment to step (must be a ``VecEnv``).
        callback: Callback notified at rollout start, each step and rollout end.
        train_freq: Object with ``frequency``/``unit`` giving the rollout length.
        replay_buffer: Buffer passed to ``_store_transition``.
        action_noise: Accepted for interface compatibility; not used.
        log_interval: Accepted for interface compatibility; not used.

    Returns:
        ``RolloutReturn`` with the collected timesteps (steps * num_envs),
        the finished-episode count, and whether training should continue.
    """
    self.policy.set_training_mode(False)

    num_collected_steps, num_collected_episodes = 0, 0

    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert train_freq.frequency > 0, "Should at least collect one step or episode."

    if env.num_envs > 1:
        assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

    callback.on_rollout_start()
    continue_training = True

    # Loop:
    #   1. predict the end state
    #   2. get obs:   self._last_obs = env.step(endstate)
    #   3. get depth: self._last_depth = env.getDepthImage()
    #   4. record to the buffer and go back to 1
    # Local names are left unchanged because locals() is handed to the callback.
    while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):

        # 1. pred endstate used latest policy or pre-trained policy.
        # _sample_action() takes no arguments; it reads self._last_obs,
        # self._last_goal and self._last_depth. Passing action_noise and
        # num_envs here raised TypeError.
        sampled_endstate = self._sample_action()

        # 2. perform action
        new_obs, rewards, dones = env.step(sampled_endstate)

        self.num_timesteps += env.num_envs
        num_collected_steps += 1

        # Give access to local variables
        callback.update_locals(locals())
        # Only stop training if return value is False, not when it is None.
        if callback.on_step() is False:
            return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes,
                                 continue_training=False)

        # 3. store the last obs, depth, and goal
        # self._update_info_buffer(infos, dones)
        self._store_transition(replay_buffer)
        self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

        # 4. update the obs, depth, goal, and reset the goal for the done-env
        self._last_obs = new_obs
        self._last_depth = env.getDepthImage()

        for idx, done in enumerate(dones):
            if done:
                # Update stats
                num_collected_episodes += 1
                self._episode_num += 1
                # reset goal for the 'done' env
                self._last_goal[idx] = self.get_random_goal(self._last_obs[idx])

    callback.on_rollout_end()

    return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
|
||||
|
||||
def prapare_input_observation(self, obs):
    """Rotate a body-frame observation into every primitive frame.

    ``obs`` holds velocity, acceleration and goal direction in columns
    0:3, 3:6 and 6:9. The result is a tensor laid out as
    (batch, channel, vertical, horizontal) so it can be concatenated with
    the depth features. Grid cells are visited from the last row/column
    down to the first while the primitive id counts up from 0. Any channel
    beyond the first nine keeps the fill value 1.
    """
    n_rows = self.lattice_space.vertical_num
    n_cols = self.lattice_space.horizon_num
    grid = np.ones((obs.shape[0], n_rows, n_cols, obs.shape[1]), dtype=np.float32)

    vel_b, acc_b, goal_b = obs[:, 0:3], obs[:, 3:6], obs[:, 6:9]
    cells = ((row, col)
             for row in range(n_rows - 1, -1, -1)
             for col in range(n_cols - 1, -1, -1))
    for primitive_id, (row, col) in enumerate(cells):
        rot_bp = self.lattice_primitive.getRotation(primitive_id)
        # x_p = Rbp^T * x_b, applied row-wise to the whole batch.
        grid[:, row, col, 0:3] = np.dot(rot_bp.T, vel_b.T).T
        grid[:, row, col, 3:6] = np.dot(rot_bp.T, acc_b.T).T
        grid[:, row, col, 6:9] = np.dot(rot_bp.T, goal_b.T).T

    # Channels-last -> channels-first.
    return th.from_numpy(np.transpose(grid, [0, 3, 1, 2]))
|
||||
|
||||
def unnormalize_obs(self, vel_acc_norm):
    """Scale normalised velocity (cols 0:3) and acceleration (cols 3:6) back up.

    Velocity is multiplied by ``lattice_space.vel_max`` and acceleration by
    ``lattice_space.acc_max``; the two halves are returned side by side.
    """
    limits = self.lattice_space
    scaled_parts = (
        vel_acc_norm[:, 0:3] * limits.vel_max,
        vel_acc_norm[:, 3:6] * limits.acc_max,
    )
    return np.hstack(scaled_parts)
|
||||
|
||||
def normalize_obs(self, vel_acc):
    """Normalise velocity (cols 0:3) and acceleration (cols 3:6).

    Velocity is divided by ``lattice_space.vel_max`` and acceleration by
    ``lattice_space.acc_max``; the two halves are returned side by side.
    """
    limits = self.lattice_space
    scaled_parts = (
        vel_acc[:, 0:3] / limits.vel_max,
        vel_acc[:, 3:6] / limits.acc_max,
    )
    return np.hstack(scaled_parts)
|
||||
|
||||
def cost_filter(self, costs_):
    """Zero the highest-cost entries of every row, in place.

    A fraction ``self.unselect`` of the columns (rounded down) is zeroed
    per row, chosen as the largest values. If ``self.unselect`` is <= 0 or
    >= 1 the tensor is returned untouched.

    NOTE: the input tensor itself is modified and returned; no copy is made.
    """
    ratio = self.unselect
    if ratio >= 1 or ratio <= 0:
        return costs_

    n_rows, n_cols = costs_.size()
    n_drop = int(n_cols * ratio)
    for row_idx in range(n_rows):
        # Discard the worst (largest-cost) samples of this row.
        worst = th.topk(costs_[row_idx], n_drop).indices
        costs_[row_idx][worst] = 0.0
    return costs_
|
||||
|
||||
def _setup_learn(
    self,
    total_timesteps,
    eval_env=None,
    callback=None,
    eval_freq=10000,
    n_eval_episodes=5,
    log_path=None,
    reset_num_timesteps=True,
    tb_log_name="run",
):
    """
    Run the parent setup, then initialise the extra rollout state kept here.

    After the parent call this stores the current depth image, draws one
    random goal per environment from its last observation, and zeroes the
    per-environment map ids.

    :return: the ``(total_timesteps, callback)`` pair produced by the parent.
    """
    # ----------------- Init the First Observation -----------------
    # Inside super()._setup_learn(): self._last_obs = self.env.reset()
    total_timesteps_, callback_ = super()._setup_learn(
        total_timesteps,
        eval_env,
        callback,
        eval_freq,
        n_eval_episodes,
        log_path,
        reset_num_timesteps,
        tb_log_name,
    )

    self._last_depth = self.env.getDepthImage()

    # One 3-D goal per environment, derived from that environment's state.
    self._last_goal = np.zeros([self.env.num_envs, 3], dtype=np.float32)
    for env_idx in range(self.env.num_envs):
        self._last_goal[env_idx] = self.get_random_goal(self._last_obs[env_idx])

    self._map_id = np.zeros((self.env.num_envs, 1), dtype=np.float32)

    return total_timesteps_, callback_
|
||||
|
||||
|
||||
def _sample_action(self) -> np.ndarray:
    """
    Use the pretrained model or the current model to sample actions (end states).

    Reads:
        self._last_obs: last state observation [p, v, a, q]
        self._last_goal: last goal position per environment
        self._last_depth: last depth image

    :return: the end state predicted by ``self.policy.predict`` as a numpy array.
    """
    obs = self._last_obs.copy()
    goal_w = self._last_goal.copy()
    depth = th.from_numpy(self._last_depth).to(self.device)

    # Inverse of a wxyz quaternion: [w, -x, -y, -z].
    # Negate everything, then flip the scalar part back.
    quat_bw = -obs[:, 9:13]
    quat_bw[:, 0] = -quat_bw[:, 0]

    vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])

    # Unit vector from the current position towards the goal (world frame).
    goal_offset_w = goal_w - obs[:, 0:3]
    goal_dist = np.linalg.norm(goal_offset_w, axis=1, keepdims=True)
    # A goal that coincides with the current position has zero distance; clamp
    # the divisor so the direction becomes zero instead of NaN. The smallest
    # positive normal float leaves every non-degenerate result unchanged.
    goal_dist = np.maximum(goal_dist, np.finfo(goal_dist.dtype).tiny)
    goal_dir_w = goal_offset_w / goal_dist
    goal_dir_b = rotate(quat_bw, goal_dir_w)

    obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))

    obs_norm_input = self.prapare_input_observation(obs_norm_b).to(self.device)

    # The predicted score is not needed for sampling.
    endstate_pred, _ = self.policy.predict(depth, obs_norm_input)
    return endstate_pred.cpu().numpy()
|
||||
|
||||
def _dump_logs(self) -> None:
|
||||
"""
|
||||
Write log.
|
||||
"""
|
||||
time_elapsed = time.time() - self.start_time
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
|
||||
self.logger.record("time/fps", fps, exclude="tensorboard")
|
||||
self.logger.record("time/minute_elapsed", int(time_elapsed / 60), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.record("train/map_id", self._map_id[0][0], exclude="tensorboard")
|
||||
|
||||
# Pass the number of timesteps for tensorboard
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
def _store_transition(self, replay_buffer):
|
||||
|
||||
# Avoid modification by reference
|
||||
obs = deepcopy(self._last_obs)
|
||||
goal = deepcopy(self._last_goal)
|
||||
depth = deepcopy(self._last_depth)
|
||||
map_id = deepcopy(self._map_id)
|
||||
|
||||
replay_buffer.add(
|
||||
obs,
|
||||
goal,
|
||||
depth,
|
||||
map_id
|
||||
)
|
||||
|
||||
def get_random_goal(self, uav_state=None):
    """
    Draw a random goal position in the world frame.

    ``self.env.world_box`` is read as [x_min, y_min, z_min, x_max, y_max, z_max].

    :param uav_state: optional state vector; entries 0:3 are used as the
        position and entries 9: as the orientation quaternion. When omitted
        the goal is drawn around the centre of the map.
    :return: array with the 3 goal coordinates.
    """
    world = self.env.world_box

    # 1. Use random goal in map
    if uav_state is None:
        center = np.array([world[3] + world[0], world[4] + world[1], world[5] + world[2]]) / 2
        extent = np.array([world[3] - world[0], world[4] - world[1], 1.0])
        # The goal can be out of the world, if strictly in world: np.random.uniform(-0.5, 0.5, 3)
        offsets = np.random.uniform(-1, 1, 3)
        return offsets * extent + center

    # 2. Use goal in front of the UAV (for better imitation learning)
    orientation_wb = uav_state[9:]
    position_wb = uav_state[0:3]
    # Gaussian direction biased towards +x of the body frame, pushed out to 50 units.
    heading_b = np.random.randn(3) + np.array([2, 0, 0])
    heading_b = heading_b / np.linalg.norm(heading_b)
    goal_w = transform(orientation_wb, position_wb, 50 * heading_b)
    # The height is re-drawn within +/-1 of the vertical centre of the world box.
    goal_w[2] = np.random.uniform(-1, 1) * 1 + (world[5] + world[2]) / 2
    return goal_w
|
||||
|
||||
def reset_state(self):
    """
    Restore the rollout state after a training step.

    During training the state and map id are set manually, which affects the
    cost, the controller, the image rendering and other parts of the next
    rollout. This sets every map id to -1, resets the environments, refreshes
    the cached observation and depth image, and draws a new goal for each
    environment.
    """
    env_count = self.env.num_envs
    self.env.setMapID(np.full((env_count, 1), -1.0))

    self._last_obs = self.env.reset()
    self._last_depth = self.env.getDepthImage()

    for env_idx in range(self.env.num_envs):
        self._last_goal[env_idx] = self.get_random_goal(self._last_obs[env_idx])
|
||||
|
||||
def _convert_train_freq(self) -> None:
    """
    Convert the `train_freq` parameter (int or tuple) to a TrainFreq object.

    Does nothing when ``self.train_freq`` already is a ``TrainFreq``.

    :raises ValueError: if the unit is not accepted by ``TrainFrequencyUnit``
        or the frequency is not an integer.
    """
    if isinstance(self.train_freq, TrainFreq):
        return

    spec = self.train_freq
    # A bare value means "every N steps".
    # The value of the train frequency will be checked later
    if not isinstance(spec, tuple):
        spec = (spec, "step")
    frequency, unit_name = spec[0], spec[1]

    try:
        unit = TrainFrequencyUnit(unit_name)
    except ValueError:
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{unit_name}'!")

    if not isinstance(frequency, int):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {frequency}")

    self.train_freq = TrainFreq(frequency, unit)
|
||||
71
flightpolicy/yopo/yopo_network.py
Normal file
71
flightpolicy/yopo/yopo_network.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# The backbone and the custom gradient layer.
|
||||
import time
|
||||
import torch as th
|
||||
import torch.nn
|
||||
import numpy as np
|
||||
from torchvision.models import mobilenet_v3_small
|
||||
from flightpolicy.yopo.resnet import resnet18
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
# 18ms, Fast and effective.
|
||||
class ResNet18(torch.nn.Module):
    """
    ResNet-18 backbone adapted to single-channel (depth) input.

    The first convolution takes one input channel and the final ``fc`` layer
    is replaced by a bias-free 1x1 convolution with ``output_dim`` channels.
    When ``primitive_shape`` differs from 1 the average pooling is replaced by
    an empty ``Sequential`` (presumably so the spatial grid reaches the 1x1
    convolution — verify against the project's ``resnet18``).
    """

    def __init__(self, output_dim: int, primitive_shape: int):
        super().__init__()
        backbone = resnet18(pretrained=False)
        # Depth images carry a single channel.
        backbone.conv1 = th.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if primitive_shape != 1:
            backbone.avgpool = th.nn.Sequential()
        backbone.fc = th.nn.Conv2d(512, output_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.cnn = backbone
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Return the backbone features for ``depth``."""
        return self.cnn(depth)
|
||||
|
||||
|
||||
# 20ms, Performs worse than ResNet and is slower than ResNet-18.
|
||||
class MobileNet(th.nn.Module):
    """
    MobileNetV3-small backbone adapted to single-channel (depth) input.

    The first convolution takes one input channel and the classifier is
    replaced by a single linear layer mapping 576 features to ``output_dim``.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        backbone = mobilenet_v3_small(pretrained=False)
        # Depth images carry a single channel.
        backbone.features[0][0] = th.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.classifier = th.nn.Linear(576, output_dim)
        self.cnn = backbone
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Return the backbone features for ``depth``."""
        return self.cnn(depth)
|
||||
|
||||
|
||||
def YopoBackbone(output_dim, primitive_shape):
    """
    Build the image backbone used by the policy.

    Currently always a ``ResNet18``; this factory is the single place to swap
    in another backbone.

    :param output_dim: number of feature channels produced by the backbone.
    :param primitive_shape: forwarded to ``ResNet18`` unchanged.
    """
    backbone = ResNet18(output_dim, primitive_shape)
    return backbone
|
||||
|
||||
|
||||
class CostAndGradLayer(Function):
    """
    Autograd bridge to the environment's analytic cost and gradient.

    The forward pass asks ``train_env.getCostAndGradient`` for the cost of the
    predicted state together with its gradient; the backward pass multiplies
    that stored gradient by the incoming gradient of the cost.
    """

    @staticmethod
    def forward(ctx, input_dp, train_env, primitive_id):
        """
        :param input_dp: predicted state tensor handed to the environment.
        :param train_env: object exposing ``getCostAndGradient``.
        :param primitive_id: ids of the primitives the prediction belongs to.
        :return: cost tensor on the device of ``input_dp``.
        """
        target_device = input_dp.device
        cost_np, grad_np = train_env.getCostAndGradient(input_dp, primitive_id)
        # Gradient clipping: Prevent excessively large values.
        # NOTE(review): only the upper side is bounded (values above 1.0);
        # large negative gradients pass through — confirm this is intended.
        grad_np = np.minimum(grad_np, 1.0)
        cost = torch.tensor(cost_np).to(target_device)
        ctx.save_for_backward(torch.tensor(grad_np).to(target_device))
        cost.requires_grad = True
        return cost

    @staticmethod
    def backward(ctx, cost_grad_input):
        """
        Chain rule: scale the stored gradient by the upstream cost gradient.

        Only the first forward input receives a gradient; the environment and
        the primitive ids get ``None``.
        """
        (stored_grad,) = ctx.saved_tensors
        # Per-sample product of a column vector with the upstream gradient.
        scaled = th.bmm(stored_grad.unsqueeze(-1), cost_grad_input.unsqueeze(-1))
        return scaled.squeeze(dim=2), None, None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: time one forward pass of the backbone on an all-zero
    # single-channel 96x96 image and print the elapsed seconds.
    backbone = YopoBackbone(64, 3)
    dummy_depth = torch.zeros((1, 1, 96, 96))
    tic = time.time()
    features = backbone(dummy_depth)
    print(time.time() - tic)
|
||||
213
flightpolicy/yopo/yopo_policy.py
Normal file
213
flightpolicy/yopo/yopo_policy.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
YOPO Network
|
||||
forward, prediction, pre-processing, post-processing
|
||||
"""
|
||||
|
||||
import torch as th
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Type
|
||||
from flightpolicy.yopo.yopo_network import YopoBackbone, CostAndGradLayer
|
||||
|
||||
|
||||
class YopoPolicy(nn.Module):
    """
    YOPO policy network: image backbone, state branch and a 1x1-conv header.

    For every lattice primitive the header predicts an end state (9 values)
    plus a score. ``forward`` returns the raw header output, ``inference``
    additionally queries the training environment for costs/gradients, and
    ``predict`` is the gradient-free path used for testing.
    """

    def __init__(
            self,
            observation_dim,
            action_dim,  # x_pva, y_pva, z_pva, score
            hidden_state,
            lattice_space,
            lattice_primitive,
            lr_schedule=None,
            train_env=None,
            net_arch=None,
            activation_fn=nn.ReLU,
            normalize_images=True,
            optimizer_class=th.optim.Adam,
            optimizer_kwargs=None,
            device=None
    ):
        super(YopoPolicy, self).__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.lattice_space = lattice_space
        self.hidden_state = hidden_state
        self.lattice_primitive = lattice_primitive
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs
        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.normalize_images = normalize_images
        self.yaw_diff = lattice_primitive.yaw_diff
        self.pitch_diff = lattice_primitive.pitch_diff
        self.train_env = train_env
        self.device = device

        self._build(lr_schedule)

    def _build(self, lr_schedule=None) -> None:
        """Create the sub-networks and the optimizer."""
        # output state dim = action dim + score
        output_dim = (self.action_dim + 1) * self.lattice_space.vel_num * self.lattice_space.radio_num
        # input state dim = hidden_state + vel + acc + goal
        input_dim = self.hidden_state + 9
        self.image_backbone = YopoBackbone(self.hidden_state,
                                           self.lattice_space.horizon_num * self.lattice_space.vertical_num)
        # The state is fed through unchanged (empty Sequential).
        self.state_backbone = nn.Sequential()
        self.yopo_header = self.create_header(input_dim, output_dim, self.net_arch, self.activation_fn, True)
        self.grad_layer = CostAndGradLayer.apply
        # Setup optimizer with initial learning rate
        # NOTE(review): self.optimizer_kwargs is stored but never forwarded here.
        learning_rate = lr_schedule(1) if lr_schedule is not None else 1e-3
        self.optimizer = self.optimizer_class(self.parameters(), lr=learning_rate)

    # TensorRT Transfer
    def forward(self, depth: th.Tensor, obs: th.Tensor) -> th.Tensor:
        """
        Forward propagation of the neural network, only used for TensorRT conversion.

        :return: raw header output, [batch, endstate+score, lattice_row, lattice_col]
        """
        depth_feature = self.image_backbone(depth)
        obs_feature = self.state_backbone(obs)
        input_tensor = th.cat((obs_feature, depth_feature), 1)
        output = self.yopo_header(input_tensor)
        # [batch, endstate+score, lattice_row, lattice_col]
        return output

    # Training Policy
    def inference(self, depth: th.Tensor, obs: th.Tensor) -> th.Tensor:
        """
        For network training:
        (1) predict the end state and score of every primitive
        (2) record the gradients and costs of the prediction

        :return: tuple ``(endstate_score_predictions, cost_labels)`` shaped
            [batch, 10, lattice_num] and [batch, lattice_num].
        """
        depth_feature = self.image_backbone(depth)
        obs_feature = self.state_backbone(obs)
        input_tensor = th.cat((obs_feature, depth_feature), 1)
        output = self.yopo_header(input_tensor)

        # [batch, endstate+score, lattice_num]
        batch_size = obs.shape[0]
        lattice_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num
        output = output.view(batch_size, 10, lattice_num)
        endstate_pred = output[:, 0:9, :]
        score_pred = output[:, 9, :]

        endstate_score_predictions = th.zeros_like(output).to(self.device)
        cost_labels = th.zeros((batch_size, lattice_num)).to(self.device)
        for i in range(lattice_num):
            # Output column i belongs to the primitive counted from the end.
            lattice_id = lattice_num - 1 - i
            ids = lattice_id * np.ones((batch_size, 1))
            endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id)
            cost_label = self.grad_layer(endstate, self.train_env, ids)
            endstate_score_predictions[:, 0:9, i] = endstate
            endstate_score_predictions[:, 9, i] = score_pred[:, i]
            cost_labels[:, i] = cost_label.squeeze()

        return endstate_score_predictions, cost_labels

    # Testing Policy
    def predict(self, depth: th.Tensor, obs: th.Tensor, return_all_preds=False) -> th.Tensor:
        """
        For network testing: predict the end state and score without gradients.

        :param return_all_preds: when False only the primitive with the lowest
            predicted score is returned per sample; when True the end states
            and scores of all primitives are returned.
        :return: tuple ``(endstate_prediction, score_prediction)``.
        """
        with th.no_grad():
            depth_feature = self.image_backbone(depth)
            obs_feature = self.state_backbone(obs.float())
            input_tensor = th.cat((obs_feature, depth_feature), 1)
            output = self.yopo_header(input_tensor)
            batch_size = obs.shape[0]
            lattice_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num
            output = output.view(batch_size, 10, lattice_num)
            endstate_pred = output[:, 0:9, :]
            score_pred = output[:, 9, :]

            if not return_all_preds:
                endstate_prediction = th.zeros(batch_size, self.action_dim)
                score_prediction = th.zeros(batch_size, 1)
                for i in range(batch_size):
                    # The best primitive is the one with the lowest score.
                    action_id = th.argmin(score_pred[i]).item()
                    lattice_id = lattice_num - 1 - action_id
                    endstate_prediction[i] = self.pred_to_endstate(
                        th.unsqueeze(endstate_pred[i, :, action_id], 0), lattice_id)
                    score_prediction[i] = score_pred[i, action_id]
            else:
                endstate_prediction = th.zeros_like(endstate_pred)
                score_prediction = score_pred
                for i in range(lattice_num):
                    lattice_id = lattice_num - 1 - i
                    endstate_prediction[:, :, i] = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id)

        return endstate_prediction, score_prediction

    def pred_to_endstate(self, endstate_pred: th.Tensor, id: int):
        """
        Transform the predicted state to the body frame.

        :param endstate_pred: [batch, 9] raw prediction: yaw offset, pitch
            offset, radius, then velocity (3) and acceleration (3) expressed in
            the primitive frame.
        :param id: index of the lattice primitive the prediction belongs to.
        :return: [batch, 9] end state ordered as x_pva, y_pva, z_pva.
        """
        delta_yaw = endstate_pred[:, 0] * self.yaw_diff
        delta_pitch = endstate_pred[:, 1] * self.pitch_diff
        radio = endstate_pred[:, 2] * self.lattice_space.radio_range + self.lattice_space.radio_range
        yaw, pitch = self.lattice_primitive.getAngleLattice(id)
        # Spherical to Cartesian coordinates around the primitive's direction.
        endstate_x = th.cos(pitch + delta_pitch) * th.cos(yaw + delta_yaw) * radio
        endstate_y = th.cos(pitch + delta_pitch) * th.sin(yaw + delta_yaw) * radio
        endstate_z = th.sin(pitch + delta_pitch) * radio
        endstate_p = th.stack((endstate_x, endstate_y, endstate_z), dim=1)

        endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max
        endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max
        # Convert the rotation once and match the prediction's dtype and device,
        # so a float64 (numpy default) rotation does not break the matmul.
        Rbp = th.as_tensor(self.lattice_primitive.getRotation(id),
                           dtype=endstate_pred.dtype, device=endstate_pred.device)
        endstate_vb = th.matmul(Rbp, endstate_vp.t()).t()
        endstate_ab = th.matmul(Rbp, endstate_ap.t()).t()
        endstate = th.cat((endstate_p, endstate_vb, endstate_ab), dim=1)
        # Reorder [p_xyz, v_xyz, a_xyz] into [x_pva, y_pva, z_pva].
        return endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]]

    def create_header(self,
                      input_dim: int,
                      output_dim: int,
                      net_arch: List[int],
                      activation_fn: Type[nn.Module] = nn.ReLU,
                      squash_output: bool = False,
                      ) -> nn.Sequential:
        """
        Build the header as a stack of 1x1 convolutions.

        :param input_dim: number of input channels.
        :param output_dim: number of output channels; no output layer is added
            when it is not positive.
        :param net_arch: channel count of each hidden layer; ``None`` (the
            constructor default) or an empty list means no hidden layers.
        :param activation_fn: activation placed after every hidden layer.
        :param squash_output: append a ``Tanh`` when True.
        """
        hidden_sizes = list(net_arch) if net_arch is not None else []

        modules = []
        in_channels = input_dim
        for out_channels in hidden_sizes:
            modules.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=1, stride=1, padding=0))
            modules.append(activation_fn())
            in_channels = out_channels

        if output_dim > 0:
            modules.append(nn.Conv2d(in_channels=in_channels, out_channels=output_dim,
                                     kernel_size=1, stride=1, padding=0))
        if squash_output:
            modules.append(nn.Tanh())
        return nn.Sequential(*modules)

    def get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the constructor arguments kept on the instance."""
        data = {"net_arch": self.net_arch,
                "hidden_state": self.hidden_state,
                "observation_dim": self.observation_dim,
                "action_dim": self.action_dim,
                "activation_fn": self.activation_fn,
                "lattice_space": self.lattice_space,
                "lattice_primitive": self.lattice_primitive
                }
        return data

    def print_grad(self, grad):
        """Debug hook: print a gradient (usable with ``Tensor.register_hook``)."""
        print("grad of hook: ", grad)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.train(mode)
|
||||
Reference in New Issue
Block a user