Initial Commit (tested training, testing, and TRT conversion)

This commit is contained in:
Lu Junjie
2024-10-20 17:01:07 +08:00
parent 86d2f311f8
commit 5738088bae
221 changed files with 59249 additions and 6 deletions

View File

@@ -0,0 +1,29 @@
#include "flightlib/common/command.hpp"
namespace flightlib {
Command::Command() {}
Command::Command(const Scalar t, const Scalar thrust, const Vector<3>& omega)
: t(t), collective_thrust(thrust), omega(omega) {}
Command::Command(const Scalar t, const Vector<4>& thrusts)
: t(t), thrusts(thrusts) {}
bool Command::valid() const {
return std::isfinite(t) &&
((std::isfinite(collective_thrust) && omega.allFinite()) ^
thrusts.allFinite());
}
bool Command::isSingleRotorThrusts() const {
return std::isfinite(t) && thrusts.allFinite();
}
bool Command::isRatesThrust() const {
return std::isfinite(t) && std::isfinite(collective_thrust) &&
omega.allFinite();
}
} // namespace flightlib

View File

@@ -0,0 +1,33 @@
#include "flightlib/common/integrator_base.hpp"
namespace flightlib {
IntegratorBase::IntegratorBase(IntegratorBase::DynamicsFunction function,
const Scalar dt_max)
: dynamics_(function), dt_max_(dt_max) {}
bool IntegratorBase::integrate(const QuadState& initial,
QuadState* const final) const {
if (std::isnan(initial.t) || std::isnan(final->t)) return false;
if (initial.t >= final->t) return false;
return integrate(initial.x, final->t - initial.t, final->x);
}
bool IntegratorBase::integrate(const Ref<const Vector<>> initial,
const Scalar dt, Ref<Vector<>> final) const {
Scalar dt_remaining = dt;
Vector<> state = initial;
do {
const Scalar dt_this = std::min(dt_remaining, dt_max_);
if (!step(state, dt_this, final)) return false;
state = final;
dt_remaining -= dt_this;
} while (dt_remaining > 0.0);
return true;
}
Scalar IntegratorBase::dtMax() const { return dt_max_; }
} // namespace flightlib

View File

@@ -0,0 +1,15 @@
#include "flightlib/common/integrator_euler.hpp"
namespace flightlib {
bool IntegratorEuler::step(const Ref<const Vector<>> initial, const Scalar dt,
Ref<Vector<>> final) const {
Vector<> derivative(initial.rows());
if (!this->dynamics_(initial, derivative)) return false;
final = initial + dt * derivative;
return true;
}
} // namespace flightlib

View File

@@ -0,0 +1,34 @@
#include "flightlib/common/integrator_rk4.hpp"
namespace flightlib {
bool IntegratorRK4::step(const Ref<const Vector<>> initial, const Scalar dt,
Ref<Vector<>> final) const {
static const Vector<4> rk4_sum_vec{1.0 / 6.0, 2.0 / 6.0, 2.0 / 6.0,
1.0 / 6.0};
Matrix<> k = Matrix<>::Zero(initial.rows(), 4);
final = initial;
// k_1
if (!this->dynamics_(final, k.col(0))) return false;
// k_2
final = initial + 0.5 * dt * k.col(0);
if (!this->dynamics_(final, k.col(1))) return false;
// k_3
final = initial + 0.5 * dt * k.col(1);
if (!this->dynamics_(final, k.col(2))) return false;
// k_4
final = initial + dt * k.col(2);
if (!this->dynamics_(final, k.col(3))) return false;
final = initial + dt * k * rk4_sum_vec;
return true;
}
} // namespace flightlib

View File

@@ -0,0 +1,63 @@
#include "flightlib/common/logger.hpp"
namespace flightlib {
Logger::Logger(const std::string& name, const bool color)
: sink_(std::cout.rdbuf()), colored_(color) {
name_ = "[" + name + "]";
if (name_.size() < NAME_PADDING)
name_ = name_ + std::string(NAME_PADDING - name_.size(), ' ');
else
name_ = name_ + " ";
sink_.precision(DEFAULT_PRECISION);
}
Logger::Logger(const std::string& name, const std::string& filename)
: Logger(name, false) {
if (!filename.empty()) {
std::filebuf* fbuf = new std::filebuf;
if (fbuf->open(filename, std::ios::out))
sink_.rdbuf(fbuf);
else
warn("Could not open file %s. Logging to console!", filename);
}
sink_.precision(DEFAULT_PRECISION);
}
Logger::~Logger() {}
inline std::streamsize Logger::precision(const std::streamsize n) {
return sink_.precision(n);
}
inline void Logger::scientific(const bool on) {
if (on)
sink_ << std::scientific;
else
sink_ << std::fixed;
}
void Logger::info(const std::string& message) const {
if (colored_)
sink_ << name_ << message << std::endl;
else
sink_ << name_ << INFO << message << std::endl;
}
void Logger::warn(const std::string& message) const {
if (colored_)
sink_ << YELLOW << name_ << message << RESET << std::endl;
else
sink_ << name_ << WARN << message << std::endl;
}
void Logger::error(const std::string& message) const {
if (colored_)
sink_ << RED << name_ << message << RESET << std::endl;
else
sink_ << name_ << ERROR << message << std::endl;
}
} // namespace flightlib

232
flightlib/src/common/math.cpp Executable file
View File

@@ -0,0 +1,232 @@
#include "flightlib/common/math.hpp"
#include "iostream"
namespace flightlib {
Matrix<3, 3> skew(const Vector<3>& v) { return (Matrix<3, 3>() << 0, -v.z(), v.y(), v.z(), 0, -v.x(), -v.y(), v.x(), 0).finished(); }
Matrix<4, 4> Q_left(const Quaternion& q) {
return (Matrix<4, 4>() << q.w(), -q.x(), -q.y(), -q.z(), q.x(), q.w(), -q.z(), q.y(), q.y(), q.z(), q.w(), -q.x(), q.z(), -q.y(), q.x(), q.w())
.finished();
}
Matrix<4, 4> Q_right(const Quaternion& q) {
return (Matrix<4, 4>() << q.w(), -q.x(), -q.y(), -q.z(), q.x(), q.w(), q.z(), -q.y(), q.y(), -q.z(), q.w(), q.x(), q.z(), q.y(), -q.x(), q.w())
.finished();
}
Matrix<4, 3> qFromQeJacobian(const Quaternion& q) {
return (Matrix<4, 3>() << -1.0 / q.w() * q.vec().transpose(), Matrix<3, 3>::Identity()).finished();
}
Matrix<4, 4> qConjugateJacobian() { return Matrix<4, 1>(1, -1, -1, -1).asDiagonal(); }
Matrix<3, 3> qeRotJacobian(const Quaternion& q, const Matrix<3, 1>& t) {
return 2.0 * (Matrix<3, 3>() << (q.y() + q.z() * q.x() / q.w()) * t.y() + (q.z() - q.y() * q.x() / q.w()) * t.z(), // entry 0,0
-2.0 * q.y() * t.x() + (q.x() + q.z() * q.y() / q.w()) * t.y() + (q.w() - q.y() * q.y() / q.w()) * t.z(), // entry 0,1
-2.0 * q.z() * t.x() + (-q.w() + q.z() * q.z() / q.w()) * t.y() + (q.x() - q.y() * q.z() / q.w()) * t.z(), // entry 0,2
(q.y() - q.z() * q.x() / q.w()) * t.x() + (-2.0 * q.x()) * t.y() + (-q.w() + q.x() * q.x() / q.w()) * t.z(), // entry 1,0
(q.x() - q.z() * q.y() / q.w()) * t.x() + (q.z() + q.x() * q.y() / q.w()) * t.z(), // entry 1,1
(q.w() - q.z() * q.z() / q.w()) * t.x() + (-2.0 * q.z()) * t.y() + (q.y() + q.x() * q.z() / q.w()) * t.z(), // entry 1,2
(q.z() + q.y() * q.x() / q.w()) * t.x() + (q.w() - q.x() * q.x() / q.w()) * t.y() + (-2.0 * q.x()) * t.z(), // entry 2,0
(-q.w() + q.y() * q.y() / q.w()) * t.x() + (q.z() - q.x() * q.y() / q.w()) * t.y() + (-2.0 * q.y()) * t.z(), // entry 2,1
(q.x() + q.y() * q.z() / q.w()) * t.x() + (q.y() - q.x() * q.z() / q.w()) * t.y() // entry 2,2
)
.finished();
}
Matrix<3, 3> qeInvRotJacobian(const Quaternion& q, const Matrix<3, 1>& t) {
return 2.0 * (Matrix<3, 3>() << (q.y() - q.z() * q.x() / q.w()) * t.y() + (q.z() + q.y() * q.x() / q.w()) * t.z(), // entry 0,0
-2.0 * q.y() * t.x() + (q.x() - q.z() * q.y() / q.w()) * t.y() - (q.w() - q.y() * q.y() / q.w()) * t.z(), // entry 0,1
-2.0 * q.z() * t.x() + (q.w() - q.z() * q.z() / q.w()) * t.y() + (q.x() + q.y() * q.z() / q.w()) * t.z(), // entry 0,2
(q.y() + q.z() * q.x() / q.w()) * t.x() - 2.0 * q.x() * t.y() + (q.w() - q.x() * q.x() / q.w()) * t.z(), // entry 1,0
(q.x() + q.z() * q.y() / q.w()) * t.x() + (q.z() - q.x() * q.y() / q.w()) * t.z(), // entry 1,1
-(q.w() - q.z() * q.z() / q.w()) * t.x() - 2.0 * q.z() * t.y() + (q.y() - q.x() * q.z() / q.w()) * t.z(), // entry 1,2
(q.z() - q.y() * q.x() / q.w()) * t.x() - (q.w() - q.x() * q.x() / q.w()) * t.y() - 2.0 * q.x() * t.z(), // entry 2,0
(q.w() - q.y() * q.y() / q.w()) * t.x() + (q.z() + q.x() * q.y() / q.w()) * t.y() - 2.0 * q.y() * t.z(), // entry 2,1
(q.x() - q.y() * q.z() / q.w()) * t.x() + (q.y() + q.x() * q.z() / q.w()) * t.y() // entry 2,2
)
.finished();
}
void matrixToTripletList(const SparseMatrix& matrix, std::vector<SparseTriplet>* const list, const int row_offset, const int col_offset) {
list->reserve((size_t)matrix.nonZeros() + list->size());
for (int i = 0; i < matrix.outerSize(); i++) {
for (typename SparseMatrix::InnerIterator it(matrix, i); it; ++it) {
list->emplace_back(it.row() + row_offset, it.col() + col_offset, it.value());
}
}
}
void matrixToTripletList(const Matrix<Dynamic, Dynamic>& matrix, std::vector<SparseTriplet>* const list, const int row_offset, const int col_offset) {
const SparseMatrix sparse = matrix.sparseView();
matrixToTripletList(sparse, list, row_offset, col_offset);
}
void insert(const SparseMatrix& from, SparseMatrix* const into, const int row_offset, const int col_offset) {
std::vector<SparseTriplet> v;
matrixToTripletList(*into, &v);
matrixToTripletList(from, &v, row_offset, col_offset);
into->setFromTriplets(v.begin(), v.end(), [](const Scalar& older, const Scalar& newer) { return newer; });
}
void insert(const Matrix<>& from, SparseMatrix* const into, const int row_offset, const int col_offset) {
const SparseMatrix from_sparse = from.sparseView();
insert(from_sparse, into, row_offset, col_offset);
}
inline void insert(const Matrix<>& from, Matrix<>* const into, const int row_offset, const int col_offset) {
into->block(row_offset, col_offset, from.rows(), from.cols()) = from;
}
void quaternionToEuler(const Quaternion& quat, Ref<Vector<3>> euler) {
euler.x() = std::atan2(2 * quat.w() * quat.x() + 2 * quat.y() * quat.z(),
quat.w() * quat.w() - quat.x() * quat.x() - quat.y() * quat.y() + quat.z() * quat.z());
euler.y() = -std::asin(2 * quat.x() * quat.z() - 2 * quat.w() * quat.y());
euler.z() = std::atan2(2 * quat.w() * quat.z() + 2 * quat.x() * quat.y(),
quat.w() * quat.w() + quat.x() * quat.x() - quat.y() * quat.y() - quat.z() * quat.z());
}
std::vector<Scalar> transformationRos2Unity(const Matrix<4, 4>& ros_tran_mat) {
/// [ Transformation Matrix ] from ROS coordinate system (right hand)
/// to Unity coordinate system (left hand)
Matrix<4, 4> tran_mat = Matrix<4, 4>::Zero();
tran_mat(0, 0) = 1.0;
tran_mat(1, 2) = 1.0;
tran_mat(2, 1) = 1.0;
tran_mat(3, 3) = 1.0;
//
Matrix<4, 4> unity_tran_mat = tran_mat * ros_tran_mat * tran_mat.transpose();
// std::vector<Scalar> unity_tran_mat_vec(
// unity_tran_mat.data(),
// unity_tran_mat.data() + unity_tran_mat.rows() * unity_tran_mat.cols());
std::vector<Scalar> tran_unity;
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
tran_unity.push_back(unity_tran_mat(i, j));
}
}
return tran_unity;
}
std::vector<Scalar> quaternionRos2Unity(const Quaternion& ros_quat) {
/// [ Quaternion ] from ROS coordinate system (right hand)
/// to Unity coordinate system (left hand)
Matrix<3, 3> rot_mat = Matrix<3, 3>::Zero();
rot_mat(0, 0) = 1.0;
rot_mat(1, 2) = 1.0;
rot_mat(2, 1) = 1.0;
//
Matrix<3, 3> unity_rot_mat = rot_mat * ros_quat.toRotationMatrix() * rot_mat.transpose();
Quaternion unity_quat(unity_rot_mat);
std::vector<Scalar> unity_quat_vec{unity_quat.x(), unity_quat.y(), unity_quat.z(), unity_quat.w()};
return unity_quat_vec;
}
std::vector<Scalar> positionRos2Unity(const Vector<3>& ros_pos_vec) {
/// [ Position Vector ] from ROS coordinate system (right hand)
/// to Unity coordinate system (left hand)
std::vector<Scalar> unity_position{ros_pos_vec(0), ros_pos_vec(2), ros_pos_vec(1)};
return unity_position;
}
std::vector<Scalar> scalarRos2Unity(const Vector<3>& ros_scalar) {
/// [ Object Scalar Vector ] from ROS coordinate system (right hand)
/// to Unity coordinate system (left hand)
std::vector<Scalar> unity_scalar{ros_scalar(0), ros_scalar(2), ros_scalar(1)};
return unity_scalar;
}
// rpy顺序
void get_euler_from_R(Vector<3>& e, const Matrix<3, 3>& R) {
float phi = atan2(R(2, 1), R(2, 2));
float theta = asin(-R(2, 0));
float psi = atan2(R(1, 0), R(0, 0));
float pi = M_PI;
if (fabs(theta - pi / 2.0f) < 1.0e-3) {
phi = 0.0f;
psi = atan2(R(1, 2), R(0, 2));
} else if (fabs(theta + pi / 2.0f) < 1.0e-3) {
phi = 0.0f;
psi = atan2(-R(1, 2), -R(0, 2));
}
e(0) = phi;
e(1) = theta;
e(2) = psi;
}
double wrapMinusPiToPi(const double angle) {
if (angle >= -M_PIl && angle <= M_PIl) {
return angle;
}
double wrapped_angle = angle + M_PIl;
wrapped_angle = wrapZeroToTwoPi(wrapped_angle);
wrapped_angle -= M_PIl;
return wrapped_angle;
}
double wrapZeroToTwoPi(const double angle) {
if (angle >= 0.0 && angle <= 2.0 * M_PIl) {
return angle;
}
double wrapped_angle = fmod(angle, 2.0 * M_PIl);
if (wrapped_angle < 0.0) {
wrapped_angle += 2.0 * M_PIl;
}
return wrapped_angle;
}
// calculate and constrain the yaw angle per sim_t
float calculate_yaw(float yaw_cur, float yaw_ref, float sim_t) // yaw [-pi,pi]
{
float PI = 3.1415926;
float YAW_DOT_MAX_PER_SEC = 0.3 * PI;
float max_yaw_change = YAW_DOT_MAX_PER_SEC * sim_t;
float yaw_temp = yaw_ref;
float last_yaw_ = yaw_cur;
float yaw = 0;
if (yaw_temp - last_yaw_ > PI) {
if (yaw_temp - last_yaw_ - 2 * PI < -max_yaw_change) {
yaw = last_yaw_ - max_yaw_change;
if (yaw < -PI)
yaw += 2 * PI;
} else {
yaw = yaw_temp;
}
} else if (yaw_temp - last_yaw_ < -PI) {
if (yaw_temp - last_yaw_ + 2 * PI > max_yaw_change) {
yaw = last_yaw_ + max_yaw_change;
if (yaw > PI)
yaw -= 2 * PI;
} else {
yaw = yaw_temp;
}
} else {
if (yaw_temp - last_yaw_ < -max_yaw_change) {
yaw = last_yaw_ - max_yaw_change;
if (yaw < -PI)
yaw += 2 * PI;
} else if (yaw_temp - last_yaw_ > max_yaw_change) {
yaw = last_yaw_ + max_yaw_change;
if (yaw > PI)
yaw -= 2 * PI;
} else {
yaw = yaw_temp;
}
}
return yaw;
}
} // namespace flightlib

View File

@@ -0,0 +1,15 @@
#include "flightlib/common/parameter_base.hpp"
namespace flightlib {
ParameterBase::ParameterBase() {}
ParameterBase::ParameterBase(const YAML::Node& cfg_node)
: cfg_node_(cfg_node) {}
ParameterBase::ParameterBase(const std::string& cfg_path)
: cfg_node_(YAML::Node(cfg_path)) {}
ParameterBase::~ParameterBase() {}
} // namespace flightlib

View File

@@ -0,0 +1,41 @@
#include "flightlib/common/pend_state.hpp"
namespace flightlib {
PendState::PendState() {}
PendState::PendState(const Vector<IDX::SIZE>& x, const Scalar t) : x(x), t(t) {}
PendState::PendState(const PendState& state) : x(state.x), t(state.t) {}
PendState::~PendState() {}
Quaternion PendState::q() const {
return Quaternion(x(ATTW), x(ATTX), x(ATTY), x(ATTZ));
}
void PendState::q(const Quaternion quaternion) {
x(IDX::ATTW) = quaternion.w();
x(IDX::ATTX) = quaternion.x();
x(IDX::ATTY) = quaternion.y();
x(IDX::ATTZ) = quaternion.z();
}
Matrix<3, 3> PendState::R() const {
return Quaternion(x(ATTW), x(ATTX), x(ATTY), x(ATTZ)).toRotationMatrix();
}
void PendState::setZero() {
t = 0.0;
x.setZero();
x(ATTW) = 1.0;
}
std::ostream& operator<<(std::ostream& os, const PendState& state) {
os.precision(3);
os << "State at " << state.t << "s: [" << state.x.transpose() << "]";
os.precision();
return os;
}
} // namespace flightlib

View File

@@ -0,0 +1,41 @@
#include "flightlib/common/quad_state.hpp"
namespace flightlib {
QuadState::QuadState() {}
QuadState::QuadState(const Vector<IDX::SIZE>& x, const Scalar t) : x(x), t(t) {}
QuadState::QuadState(const QuadState& state) : x(state.x), t(state.t) {}
QuadState::~QuadState() {}
Quaternion QuadState::q() const {
return Quaternion(x(ATTW), x(ATTX), x(ATTY), x(ATTZ));
}
void QuadState::q(const Quaternion quaternion) {
x(IDX::ATTW) = quaternion.w();
x(IDX::ATTX) = quaternion.x();
x(IDX::ATTY) = quaternion.y();
x(IDX::ATTZ) = quaternion.z();
}
Matrix<3, 3> QuadState::R() const {
return Quaternion(x(ATTW), x(ATTX), x(ATTY), x(ATTZ)).toRotationMatrix();
}
void QuadState::setZero() {
t = 0.0;
x.setZero();
x(ATTW) = 1.0;
}
std::ostream& operator<<(std::ostream& os, const QuadState& state) {
os.precision(3);
os << "State at " << state.t << "s: [" << state.x.transpose() << "]";
os.precision();
return os;
}
} // namespace flightlib

View File

@@ -0,0 +1,117 @@
#include "flightlib/common/timer.hpp"
#include <cmath>
#include <limits>
namespace flightlib {
Timer::Timer(const std::string name, const std::string module)
: name_(name),
module_(module),
timing_mean_(0.0),
timing_last_(0.0),
timing_S_(0.0),
timing_min_(std::numeric_limits<Scalar>::max()),
timing_max_(0.0),
n_samples_(0) {}
Timer::Timer(const Timer& other)
: name_(other.name_),
module_(other.module_),
t_start_(other.t_start_),
timing_mean_(other.timing_mean_),
timing_last_(other.timing_last_),
timing_S_(other.timing_S_),
timing_min_(other.timing_min_),
timing_max_(other.timing_max_),
n_samples_(other.n_samples_) {}
void Timer::tic() { t_start_ = std::chrono::high_resolution_clock::now(); }
Scalar Timer::toc() {
// Calculate timing.
const TimePoint t_end = std::chrono::high_resolution_clock::now();
timing_last_ = 1e-9 * std::chrono::duration_cast<std::chrono::nanoseconds>(
t_end - t_start_)
.count();
n_samples_++;
// Set timing, filter if already initialized.
if (timing_mean_ <= 0.0) {
timing_mean_ = timing_last_;
} else {
const Scalar timing_mean_prev = timing_mean_;
timing_mean_ =
timing_mean_prev + (timing_last_ - timing_mean_prev) / n_samples_;
timing_S_ = timing_S_ + (timing_last_ - timing_mean_prev) *
(timing_last_ - timing_mean_);
}
timing_min_ = (timing_last_ < timing_min_) ? timing_last_ : timing_min_;
timing_max_ = (timing_last_ > timing_max_) ? timing_last_ : timing_max_;
t_start_ = t_end;
return timing_mean_;
}
Scalar Timer::operator()() const { return timing_mean_; }
Scalar Timer::mean() const { return timing_mean_; }
Scalar Timer::last() const { return timing_last_; }
Scalar Timer::min() const { return timing_min_; }
Scalar Timer::max() const { return timing_max_; }
Scalar Timer::std() const { return std::sqrt(timing_S_ / n_samples_); }
int Timer::count() const { return n_samples_; }
void Timer::reset() {
n_samples_ = 0u;
t_start_ = TimePoint();
timing_mean_ = 0.0;
timing_last_ = 0.0;
timing_S_ = 0.0;
timing_min_ = std::numeric_limits<Scalar>::max();
timing_max_ = 0.0;
}
void Timer::print() const { std::cout << *this; }
std::ostream& operator<<(std::ostream& os, const Timer& timer) {
if (!timer.module_.empty()) os << "[" << timer.module_ << "] ";
if (timer.n_samples_ < 1) {
os << "Timing " << timer.name_ << " has no call yet." << std::endl;
return os;
}
const std::streamsize prec = os.precision();
os.precision(3);
os << "Timing " << timer.name_ << " in " << timer.n_samples_ << " calls"
<< std::endl;
if (!timer.module_.empty()) os << "[" << timer.module_ << "] ";
os << "mean|std: " << 1000 * timer.timing_mean_ << " | "
<< 1000 * timer.timing_S_ << " ms "
<< "[min|max: " << 1000 * timer.timing_min_ << " | "
<< 1000 * timer.timing_max_ << " ms]" << std::endl;
os.precision(prec);
return os;
}
ScopedTimer::ScopedTimer(const std::string name, const std::string module)
: Timer(name, module) {
this->tic();
}
ScopedTimer::~ScopedTimer() {
this->toc();
this->print();
}
} // namespace flightlib