Skip to content
This repository has been archived by the owner on Nov 17, 2021. It is now read-only.

Fixing misuse of the standard math library (performance audit) #41

Merged
merged 9 commits into from
Mar 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions matrix/AxisAngle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class AxisAngle : public Vector<Type, 3>
Vector<Type, 3>()
{
AxisAngle &v = *this;
Type ang = Type(2.0f)*acosf(q(0));
Type mag = sinf(ang/2.0f);
if (fabsf(mag) > 0) {
Type ang = Type(2.0f)*acos(q(0));
Type mag = sin(ang/2.0f);
if (fabs(mag) > 0) {
v(0) = ang*q(1)/mag;
v(1) = ang*q(2)/mag;
v(2) = ang*q(3)/mag;
Expand Down
1 change: 0 additions & 1 deletion matrix/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#pragma once

#include <cmath>
#include <cstdio>
#include <cstring>

Expand Down
18 changes: 9 additions & 9 deletions matrix/Quaternion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,28 @@ class Quaternion : public Vector<Type, 4>
Quaternion &q = *this;
Type t = R.trace();
if (t > Type(0)) {
t = sqrtf(Type(1) + t);
t = sqrt(Type(1) + t);
q(0) = Type(0.5) * t;
t = Type(0.5) / t;
q(1) = (R(2,1) - R(1,2)) * t;
q(2) = (R(0,2) - R(2,0)) * t;
q(3) = (R(1,0) - R(0,1)) * t;
} else if (R(0,0) > R(1,1) && R(0,0) > R(2,2)) {
t = sqrtf(Type(1) + R(0,0) - R(1,1) - R(2,2));
t = sqrt(Type(1) + R(0,0) - R(1,1) - R(2,2));
q(1) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(2,1) - R(1,2)) * t;
q(2) = (R(1,0) + R(0,1)) * t;
q(3) = (R(0,2) + R(2,0)) * t;
} else if (R(1,1) > R(2,2)) {
t = sqrtf(Type(1) - R(0,0) + R(1,1) - R(2,2));
t = sqrt(Type(1) - R(0,0) + R(1,1) - R(2,2));
q(2) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(0,2) - R(2,0)) * t;
q(1) = (R(1,0) + R(0,1)) * t;
q(3) = (R(2,1) + R(1,2)) * t;
} else {
t = sqrtf(Type(1) - R(0,0) - R(1,1) + R(2,2));
t = sqrt(Type(1) - R(0,0) - R(1,1) + R(2,2));
q(3) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(1,0) - R(0,1)) * t;
Expand Down Expand Up @@ -171,8 +171,8 @@ class Quaternion : public Vector<Type, 4>
q(0) = Type(1.0);
q(1) = q(2) = q(3) = 0;
} else {
Type magnitude = sinf(angle / 2.0f);
q(0) = cosf(angle / 2.0f);
Type magnitude = sin(angle / 2.0f);
q(0) = cos(angle / 2.0f);
q(1) = axis(0) * magnitude;
q(2) = axis(1) * magnitude;
q(3) = axis(2) * magnitude;
Expand Down Expand Up @@ -389,9 +389,9 @@ class Quaternion : public Vector<Type, 4>
q(1) = q(2) = q(3) = 0;
}

Type magnitude = sinf(theta / 2.0f);
Type magnitude = sin(theta / 2.0f);

q(0) = cosf(theta / 2.0f);
q(0) = cos(theta / 2.0f);
q(1) = axis(0) * magnitude;
q(2) = axis(1) * magnitude;
q(3) = axis(2) * magnitude;
Expand All @@ -418,7 +418,7 @@ class Quaternion : public Vector<Type, 4>

if (axis_magnitude >= Type(1e-10)) {
vec = vec / axis_magnitude;
vec = vec * wrap_pi(Type(2.0) * atan2f(axis_magnitude, q(0)));
vec = vec * wrap_pi(Type(2.0) * atan2(axis_magnitude, q(0)));
}

return vec;
Expand Down
10 changes: 5 additions & 5 deletions matrix/SquareMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ bool inv(const SquareMatrix<Type, M> & A, SquareMatrix<Type, M> & inv)
for (size_t n = 0; n < M; n++) {

// if diagonal is zero, swap with row below
if (fabsf(static_cast<float>(U(n, n))) < 1e-8f) {
if (fabs(static_cast<float>(U(n, n))) < 1e-8f) {
//printf("trying pivot for row %d\n",n);
for (size_t i = n + 1; i < M; i++) {

//printf("\ttrying row %d\n",i);
if (fabsf(static_cast<float>(U(i, n))) > 1e-8f) {
if (fabs(static_cast<float>(U(i, n))) > 1e-8f) {
//printf("swapped %d\n",i);
U.swapRows(i, n);
P.swapRows(i, n);
Expand All @@ -157,11 +157,11 @@ bool inv(const SquareMatrix<Type, M> & A, SquareMatrix<Type, M> & inv)
//printf("U:\n"); U.print();
//printf("P:\n"); P.print();
//fflush(stdout);
//ASSERT(fabsf(U(n, n)) > 1e-8f);
//ASSERT(fabs(U(n, n)) > 1e-8f);
#endif

// failsafe, return zero matrix
if (fabsf(static_cast<float>(U(n, n))) < 1e-8f) {
if (fabs(static_cast<float>(U(n, n))) < 1e-8f) {
return false;
}

Expand Down Expand Up @@ -280,7 +280,7 @@ SquareMatrix <Type, M> cholesky(const SquareMatrix<Type, M> & A)
if (res <= 0) {
L(j, j) = 0;
} else {
L(j, j) = sqrtf(res);
L(j, j) = sqrt(res);
}
} else {
float sum = 0;
Expand Down
2 changes: 0 additions & 2 deletions matrix/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

#pragma once

#include <cmath>

#include "math.hpp"

namespace matrix
Expand Down
1 change: 0 additions & 1 deletion matrix/helper_functions.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include "math.hpp"
#include <cmath>

namespace matrix
{
Expand Down
1 change: 1 addition & 0 deletions matrix/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "stdlib_imports.hpp"
#ifdef __PX4_QURT
#include "dspal_math.h"
#endif
Expand Down
130 changes: 130 additions & 0 deletions matrix/stdlib_imports.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* @file stdlib_imports.hpp
*
* This file is needed to shadow the C standard library math functions with ones provided by the C++ standard library.
* This way we can guarantee that unwanted functions from the C library will never creep back in unexpectedly.
*
* @author Pavel Kirienko <[email protected]>
*/

#pragma once

#include <cmath>
#include <cstdlib>
#include <cinttypes>

namespace matrix {

#if defined(__PX4_NUTTX)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to detect NuttX in general instead of the PX4 flavor of it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pavel-kirienko

For the record no there is none. So far, a change to add NUTTX to the generated config.h will not be accepted upstream. So this has to be set in the build by the flags, and we do that with __PX4_NUTTX.

/*
* NuttX has no usable C++ math library, so we need to provide the needed definitions here manually.
*/
#define MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(name) \
inline float name(float x) { return ::name##f(x); } \
inline double name(double x) { return ::name(x); } \
inline long double name(long double x) { return ::name##l(x); }

#define MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(name) \
inline float name(float x, float y) { return ::name##f(x, y); } \
inline double name(double x, double y) { return ::name(x, y); } \
inline long double name(long double x, long double y) { return ::name##l(x, y); }

MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(fabs)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(log)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(log10)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(exp)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sqrt)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sin)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(cos)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(tan)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(asin)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(acos)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(atan)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sinh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(cosh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(tanh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(ceil)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(floor)

MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(pow)
MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(atan2)

#else // Not NuttX, using the C++ standard library

using std::abs;
using std::div;
using std::fabs;
using std::fmod;
using std::exp;
using std::log;
using std::log10;
using std::pow;
using std::sqrt;
using std::sin;
using std::cos;
using std::tan;
using std::asin;
using std::acos;
using std::atan;
using std::atan2;
using std::sinh;
using std::cosh;
using std::tanh;
using std::ceil;
using std::floor;
using std::frexp;
using std::ldexp;
using std::modf;

# if (__cplusplus >= 201103L)

using std::imaxabs;
using std::imaxdiv;
using std::remainder;
using std::remquo;
using std::fma;
using std::fmax;
using std::fmin;
using std::fdim;
using std::nan;
using std::nanf;
using std::nanl;
using std::exp2;
using std::expm1;
using std::log2;
using std::log1p;
using std::cbrt;
using std::hypot;
using std::asinh;
using std::acosh;
using std::atanh;
using std::erf;
using std::erfc;
using std::tgamma;
using std::lgamma;
using std::trunc;
using std::round;
using std::nearbyint;
using std::rint;
using std::scalbn;
using std::ilogb;
using std::logb;
using std::nextafter;
using std::copysign;
using std::fpclassify;
using std::isfinite;
using std::isinf;
using std::isnan;
using std::isnormal;
using std::signbit;
using std::isgreater;
using std::isgreaterequal;
using std::isless;
using std::islessequal;
using std::islessgreater;
using std::isunordered;

# endif
#endif

}
Loading