Skip to content

Commit

Permalink
Add Trunc, Modulus, SignBit functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Auburn committed Dec 2, 2024
1 parent 868e772 commit c665d1f
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 23 deletions.
58 changes: 53 additions & 5 deletions include/FastSIMD/ToolSet/Generic/Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ namespace FS
}


// Round value
// Round value, banker's rounding
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Round( const Register<T, N, SIMD>& a )
{
Expand All @@ -145,7 +145,7 @@ namespace FS
}


// Floor value
// Floor value, round towards -infinity
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Floor( const Register<T, N, SIMD>& a )
{
Expand All @@ -154,14 +154,22 @@ namespace FS
}


// Ceil value
// Ceil value, round towards +infinity
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Ceil( const Register<T, N, SIMD>& a )
{
static_assert( !IsNativeV<Register<T, N, SIMD>>, "FastSIMD: FS::Ceil not supported with provided types" );
return Register<T, N, SIMD>{ Ceil( a.v0 ), Ceil( a.v1 ) };
}

// Truncate value, round towards 0
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Trunc( const Register<T, N, SIMD>& a )
{
static_assert( !IsNativeV<Register<T, N, SIMD>>, "FastSIMD: FS::Trunc not supported with provided types" );
return Register<T, N, SIMD>{ Trunc( a.v0 ), Trunc( a.v1 ) };
}


// Min of 2 elements
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
Expand Down Expand Up @@ -432,7 +440,7 @@ namespace FS
return Register<T, N, SIMD>{ MaskedMul( mask.v0, a.v0, b.v0 ), MaskedMul( mask.v1, a.v1, b.v1 ) };
}
}

template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> InvMaskedMul( const typename Register<T, N, SIMD>::MaskTypeArg& mask, const Register<T, N, SIMD>& a, const Register<T, N, SIMD>& b )
{
Expand All @@ -446,6 +454,46 @@ namespace FS
}
}

// Extract sign bit (high bit)
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> SignBit( const Register<T, N, SIMD>& a )
{
if constexpr( IsNativeV<Register<T, N, SIMD>> )
{
Register<T, N, SIMD> signBit;
if constexpr( std::is_floating_point_v<T> )
{
signBit = Register<T, N, SIMD>( (T)-0.0 );
}
else
{
static_assert( std::is_signed_v<T>, "No signed bit in unsigned type" );
signBit = Register<T, N, SIMD>( std::numeric_limits<T>::min() );
}

return signBit & a;
}
else
{
return Register<T, N, SIMD>{ SignBit( a.v0 ), SignBit( a.v1 ) };
}
}

// Modulus: a % b
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Modulus( const Register<T, N, SIMD>& a, const Register<T, N, SIMD>& b )
{
if constexpr( IsNativeV<Register<T, N, SIMD>> )
{
auto ab = a / b;
return (ab - FS::Trunc( ab )) * b;
}
else
{
return Register<T, N, SIMD>{ Modulus( a.v0, b.v0 ), Modulus( a.v1, b.v1 ) };
}
}

// Reciprocal: 1 / a
template<typename T, std::size_t N, FastSIMD::FeatureSet SIMD>
FS_FORCEINLINE Register<T, N, SIMD> Reciprocal( const Register<T, N, SIMD>& a )
Expand All @@ -455,7 +503,7 @@ namespace FS
return Register<T, N, SIMD>( 1 ) / a;
}
else
{
{
return Register<T, N, SIMD>{ Reciprocal( a.v0 ), Reciprocal( a.v1 ) };
}
}
Expand Down
7 changes: 5 additions & 2 deletions include/FastSIMD/ToolSet/Generic/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,11 @@ namespace FS
template<typename T>
constexpr bool IsNativeV = IsNative<T>::value;

template<auto T = 0>
using EnableIfRelaxed = std::enable_if_t<FastSIMD::IsRelaxed<T>()>;
template<auto SIMD = 0>
using EnableIfRelaxed = std::enable_if_t<FastSIMD::IsRelaxed<SIMD>()>;

template<auto SIMD = 0>
using EnableIfNotRelaxed = std::enable_if_t<!FastSIMD::IsRelaxed<SIMD>()>;


template<std::size_t N, FastSIMD::FeatureSet SIMD = FastSIMD::FeatureSetDefault()>
Expand Down
12 changes: 12 additions & 0 deletions include/FastSIMD/ToolSet/Generic/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ namespace FS
return std::floor( a.GetNative() );
}

template<typename T, FastSIMD::FeatureSet SIMD, typename = EnableIfNative<Register<T, 1, SIMD>>>
FS_FORCEINLINE Register<T, 1, SIMD> Trunc( const Register<T, 1, SIMD>& a )
{
return std::trunc( a.GetNative() );
}

template<typename T, FastSIMD::FeatureSet SIMD, typename = EnableIfNative<Register<T, 1, SIMD>>, typename = EnableIfNotRelaxed<SIMD>>
FS_FORCEINLINE f32<1, SIMD> Modulus( const Register<T, 1, SIMD>& a, const Register<T, 1, SIMD>& b )
{
return std::fmod( a.GetNative(), b.GetNative() );
}

template<typename T, FastSIMD::FeatureSet SIMD, typename = EnableIfNative<Register<T, 1, SIMD>>>
FS_FORCEINLINE Register<T, 1, SIMD> Min( const Register<T, 1, SIMD>& a, const Register<T, 1, SIMD>& b )
{
Expand Down
17 changes: 17 additions & 0 deletions include/FastSIMD/ToolSet/x86/128/f32x4.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,23 @@ namespace FS
return MaskedIncrement( aRound < a, aRound );
}
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<4, SIMD>>>
FS_FORCEINLINE f32<4, SIMD> Trunc( const f32<4, SIMD>& a )
{
if constexpr( SIMD & FastSIMD::FeatureFlag::SSE41 )
{
return _mm_round_ps( a.native, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC );
}
else
{
__m128i aInt = _mm_cvttps_epi32( a.native );
__m128 aIntF = _mm_cvtepi32_ps( aInt );

return _mm_xor_ps( aIntF, _mm_and_ps( _mm_castsi128_ps( _mm_cmpeq_epi32( aInt, _mm_set1_epi32( (-2147483647 - 1) ) ) ), _mm_xor_ps( a.native, aIntF ) ) );

}
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<4, SIMD>>>
FS_FORCEINLINE f32<4, SIMD> Min( const f32<4, SIMD>& a, const f32<4, SIMD>& b )
Expand Down
6 changes: 6 additions & 0 deletions include/FastSIMD/ToolSet/x86/256/f32x8.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ namespace FS
{
return _mm256_round_ps( a.native, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC );
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<8, SIMD>>>
FS_FORCEINLINE f32<8, SIMD> Trunc( const f32<8, SIMD>& a )
{
return _mm256_round_ps( a.native, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC );
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<8, SIMD>>>
FS_FORCEINLINE f32<8, SIMD> Min( const f32<8, SIMD>& a, const f32<8, SIMD>& b )
Expand Down
6 changes: 6 additions & 0 deletions include/FastSIMD/ToolSet/x86/512/f32x16.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ namespace FS
{
return _mm512_roundscale_ps( a.native, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC );
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<16, SIMD>>>
FS_FORCEINLINE f32<16, SIMD> Trunc( const f32<16, SIMD>& a )
{
return _mm512_roundscale_ps( a.native, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC );
}

template<FastSIMD::FeatureSet SIMD, typename = EnableIfNative<f32<16, SIMD>>>
FS_FORCEINLINE f32<16, SIMD> Min( const f32<16, SIMD>& a, const f32<16, SIMD>& b )
Expand Down
35 changes: 19 additions & 16 deletions tests/test.inl
Original file line number Diff line number Diff line change
Expand Up @@ -303,33 +303,33 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "i32 load store", []( TestRegi32 a ) { return a; } );
RegisterTest( tests, "i32 load scalar", []( int32_t a ) { return TestRegi32( a ); } );
RegisterTest( tests, "i32 splat", []( int32_t a ) { return FS::Splat<TestRegi32::ElementCount>( a ); } );
RegisterTest( tests, "i32 extract 0", []( TestRegi32 a ) { return FS::Extract0( a ); } );
RegisterTest( tests, "i32 extract 0", []( TestRegi32 a ) { return FS::Extract0( a ); } );
RegisterTest( tests, "i32 load incremented", []() { return FS::LoadIncremented<TestRegi32>(); } );

RegisterTest( tests, "i32 plus operator", std::plus<TestRegi32>() );
RegisterTest( tests, "i32 minus operator", std::minus<TestRegi32>() );
RegisterTest( tests, "i32 multiply operator", std::multiplies<TestRegi32>() );

RegisterTest( tests, "i32 bit and operator", std::bit_and<TestRegi32>() );
RegisterTest( tests, "i32 bit or operator", std::bit_or<TestRegi32>() );
RegisterTest( tests, "i32 bit xor operator", std::bit_xor<TestRegi32>() );
RegisterTest( tests, "i32 bit not operator", std::bit_not<TestRegi32>() );
RegisterTest( tests, "i32 negate operator", std::negate<TestRegi32>() );
RegisterTest( tests, "i32 bit and not", []( TestRegi32 a, TestRegi32 b ) { return FS::BitwiseAndNot( a, b ); } );

RegisterTest( tests, "i32 increment", []( TestRegi32 a ) { return FS::Increment( a ); } );
RegisterTest( tests, "i32 decrement", []( TestRegi32 a ) { return FS::Decrement( a ); } );
RegisterTest( tests, "i32 abs", []( TestRegi32 a ) { return FS::Abs( a ); } );
RegisterTest( tests, "i32 min", []( TestRegi32 a, TestRegi32 b ) { return FS::Min( a, b ); } );
RegisterTest( tests, "i32 max", []( TestRegi32 a, TestRegi32 b ) { return FS::Max( a, b ); } );

RegisterTest( tests, "i32 bit shift left scalar", []( TestRegi32 a, int b ) { return a << ( b & 31 ); } );
RegisterTest( tests, "i32 bit shift right scalar", []( TestRegi32 a, int b ) { return a >> ( b & 31 ); } );
RegisterTest( tests, "i32 bit shift right zero extend scalar", []( TestRegi32 a, int b ) { return FS::BitShiftRightZeroExtend( a, b & 31 ); } );

//RegisterTest( tests, "i32 bit shift left", []( TestRegi32 a, TestRegi32 b ) { return a << FS::Min( TestRegi32( 31 ), FS::Abs( b ) ); } );
//RegisterTest( tests, "i32 bit shift right", []( TestRegi32 a, TestRegi32 b ) { return a >> FS::Min( TestRegi32( 31 ), FS::Abs( b ) ); } );

RegisterTest( tests, "i32 equals operator", []( TestRegi32 a, TestRegi32 b ) { return a == b; } );
RegisterTest( tests, "i32 equals operator alt", []( TestRegi32 a ) { return a == a; } );
RegisterTest( tests, "i32 not equals operator", []( TestRegi32 a, TestRegi32 b ) { return a != b; } );
Expand All @@ -338,7 +338,7 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "i32 greater than operator", []( TestRegi32 a, TestRegi32 b ) { return a > b; } );
RegisterTest( tests, "i32 less equal than operator", []( TestRegi32 a, TestRegi32 b ) { return a <= b; } );
RegisterTest( tests, "i32 greater equal than operator", []( TestRegi32 a, TestRegi32 b ) { return a >= b; } );

RegisterTest( tests, "i32 select", []( TestRegm32 m, TestRegi32 a, TestRegi32 b ) { return FS::Select( m, a, b ); } );
RegisterTest( tests, "i32 select high bit", []( TestRegf32 m, TestRegi32 a, TestRegi32 b ) { return FS::SelectHighBit( m, a, b ); } );
RegisterTest( tests, "i32 masked", []( TestRegm32 m, TestRegi32 a ) { return FS::Masked( m, a ); } );
Expand All @@ -351,13 +351,13 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "i32 inv masked add", []( TestRegm32 m, TestRegi32 a, TestRegi32 b ) { return FS::InvMaskedAdd( m, a, b ); } );
RegisterTest( tests, "i32 inv masked sub", []( TestRegm32 m, TestRegi32 a, TestRegi32 b ) { return FS::InvMaskedSub( m, a, b ); } );
RegisterTest( tests, "i32 inv masked mul", []( TestRegm32 m, TestRegi32 a, TestRegi32 b ) { return FS::InvMaskedMul( m, a, b ); } );

RegisterTest( tests, "f32 load store", []( TestRegf32 a ) { return a; } );
RegisterTest( tests, "f32 load scalar", []( float a ) { return TestRegf32( a ); } );
RegisterTest( tests, "f32 splat", []( float a ) { return FS::Splat<TestRegf32::ElementCount>( a ); } );
RegisterTest( tests, "f32 extract 0", []( TestRegf32 a ) { return FS::Extract0( a ); } );
RegisterTest( tests, "f32 load incremented", []() { return FS::LoadIncremented<TestRegf32>(); } );

RegisterTest( tests, "f32 plus operator", std::plus<TestRegf32>() );
RegisterTest( tests, "f32 minus operator", std::minus<TestRegf32>() );
RegisterTest( tests, "f32 multiply operator", std::multiplies<TestRegf32>() );
Expand All @@ -367,14 +367,14 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "f32 fused multiply sub", []( TestRegf32 a, TestRegf32 b ) { return FS::FMulSub( a, TestRegf32( -1 ), b ); } );
RegisterTest( tests, "f32 fused negative multiply add", []( TestRegf32 a, TestRegf32 b ) { return FS::FNMulAdd( a, TestRegf32( -1 ), b ); } );
RegisterTest( tests, "f32 fused negative multiply sub", []( TestRegf32 a, TestRegf32 b ) { return FS::FNMulSub( a, TestRegf32( -1 ), b ); } );

RegisterTest( tests, "f32 bit and operator", std::bit_and<TestRegf32>() );
RegisterTest( tests, "f32 bit or operator", std::bit_or<TestRegf32>() );
RegisterTest( tests, "f32 bit xor operator", std::bit_xor<TestRegf32>() );
RegisterTest( tests, "f32 bit not operator", std::bit_not<TestRegf32>() );
RegisterTest( tests, "f32 negate operator", std::negate<TestRegf32>() );
RegisterTest( tests, "f32 bit and not", []( TestRegf32 a, TestRegf32 b ) { return FS::BitwiseAndNot( a, b ); } );

RegisterTest( tests, "f32 equals operator", []( TestRegf32 a, TestRegf32 b ) { return a == b; } );
RegisterTest( tests, "f32 greater equal than operator", []( TestRegf32 a, TestRegf32 b ) { return a >= b; } );
RegisterTest( tests, "f32 not equals operator", []( TestRegf32 a, TestRegf32 b ) { return a != b; } );
Expand All @@ -400,16 +400,19 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "f32 inv masked add", []( TestRegm32 m, TestRegf32 a, TestRegf32 b ) { return FS::InvMaskedAdd( m, a, b ); } );
RegisterTest( tests, "f32 inv masked sub", []( TestRegm32 m, TestRegf32 a, TestRegf32 b ) { return FS::InvMaskedSub( m, a, b ); } );
RegisterTest( tests, "f32 inv masked mul", []( TestRegm32 m, TestRegf32 a, TestRegf32 b ) { return FS::InvMaskedMul( m, a, b ); } );

RegisterTest( tests, "f32 round", []( TestRegf32 a ) { return FS::Round( a ); } );
RegisterTest( tests, "f32 ceil", []( TestRegf32 a ) { return FS::Ceil( a ); } );
RegisterTest( tests, "f32 floor", []( TestRegf32 a ) { return FS::Floor( a ); } );
RegisterTest( tests, "f32 trunc", []( TestRegf32 a ) { return FS::Trunc( a ); } );
RegisterTest( tests, "f32 signbit", []( TestRegf32 a, TestRegf32 b ) { return FS::SignBit( a ) ^ b; } );
//RegisterTest( tests, "f32 modulus", []( TestRegf32 a, TestRegf32 b ) { return FS::Modulus( a, b ); } );

RegisterTest( tests, "f32 sqrt", []( TestRegf32 a ) { return FS::Sqrt( FS::Min( FS::Max( FS::Abs( a ), TestRegf32( 1.e-16f ) ), TestRegf32( 1.e+16f ) ) ); } );
RegisterTest( tests, "f32 inv sqrt", []( TestRegf32 a ) { return FS::InvSqrt( FS::Min( FS::Max( FS::Abs( a ), TestRegf32( 1.e-16f ) ), TestRegf32( 1.e+16f ) ) ); } ).relaxedAccuracy = 8192;
RegisterTest( tests, "f32 reciprocal", []( TestRegf32 a ) {
TestRegf32 clamped = FS::Min( FS::Max( FS::Abs( a ), TestRegf32( 1.e-16f ) ), TestRegf32( 1.e+16f ) );
return FS::Reciprocal( FS::Select( a > TestRegf32( 0 ), clamped, -clamped ) );
a = FS::Min( FS::Max( FS::Abs( a ), TestRegf32( 1.e-16f ) ), TestRegf32( 1.e+16f ) ) | FS::SignBit( a );
return FS::Reciprocal( a );
} ).relaxedAccuracy = 8192;

RegisterTest( tests, "f32 cos", []( TestRegf32 a ) { return FS::Cos( FS::Min( FS::Max( a, TestRegf32( -1.e+16f ) ), TestRegf32( 1.e+16f ) ) ); } ).relaxedAccuracy = 8192;
Expand All @@ -419,7 +422,7 @@ class FastSIMD::DispatchClass<TestFastSIMD<RegisterBytes, Relaxed>, SIMD> : publ
RegisterTest( tests, "f32 pow", []( TestRegf32 a, TestRegf32 b ) { return FS::Pow( a, b ); } ).relaxedAccuracy = 8192;

RegisterTest( tests, "i32 convert to f32", []( TestRegi32 a ) { return FS::Convert<float>( a ); } );
RegisterTest( tests, "f32 convert to i32", []( TestRegf32 a ) { return FS::Convert<int32_t>( FS::Min( FS::Max( a, TestRegf32( 2147483648 ) ), TestRegf32( 2147483520 ) ) ); } );
RegisterTest( tests, "f32 convert to i32", []( TestRegf32 a ) { return FS::Convert<int32_t>( FS::Min( FS::Max( a, TestRegf32( -2147483648 ) ), TestRegf32( 2147483520 ) ) ); } );

RegisterTest( tests, "f32 cast to i32", []( TestRegf32 a ) { return FS::Cast<int32_t>( a ); } );
RegisterTest( tests, "i32 cast to f32", []( TestRegi32 a ) { return FS::Cast<float>( a ); } );
Expand Down

0 comments on commit c665d1f

Please sign in to comment.