Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support rv64 fp16 ops #1297

Open
wants to merge 11 commits into
base: dev/3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.15)
 cmake_minimum_required(VERSION 3.15)

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake/Modules)

Expand Down
1 change: 1 addition & 0 deletions Testing/Temporary/CTestCostData.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
---
3 changes: 3 additions & 0 deletions cmake/run_test.cmake
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
message(STATUS "Executing command: $ENV{TESTS_EXECUTABLE_LOADER} $ENV{TESTS_EXECUTABLE_LOADER_ARGUMENTS} ${TEST_EXECUTABLE} $ENV{TESTS_ARGUMENTS}")


execute_process(COMMAND $ENV{TESTS_EXECUTABLE_LOADER} $ENV{TESTS_EXECUTABLE_LOADER_ARGUMENTS} ${TEST_EXECUTABLE} $ENV{TESTS_ARGUMENTS} RESULT_VARIABLE result)
if(NOT "${result}" STREQUAL "0")
message(FATAL_ERROR "Test failed with return value '${result}'")
Expand Down
4 changes: 3 additions & 1 deletion ntt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ cmake_minimum_required(VERSION 3.15)

include(cmake/compile_flags.cmake)


if(BUILD_TESTING)
add_subdirectory(test/ctest)
add_subdirectory(test/ctest)
endif()


if(BUILD_BENCHMARK)
add_subdirectory(test/benchmark_test)
endif()
4 changes: 2 additions & 2 deletions ntt/cmake/compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES
"riscv64")
if(ENABLE_K230_RUNTIME)
add_compile_options(-march=rv64gv_zvl128b_zfh -mrvv-vector-bits=zvl)
add_compile_options(-march=rv64gv_zvl128b_zvfh -mrvv-vector-bits=zvl)
elseif(ENABLE_K80_RUNTIME)
add_compile_options(-march=rv64gcv_zvl1024b_zfh -mrvv-vector-bits=zvl)
add_compile_options(-march=rv64gcv_zvl1024b_zvfh -mrvv-vector-bits=zvl)
else()
message(FATAL_ERROR "Unsupported riscv64 target")
endif()
Expand Down
31 changes: 30 additions & 1 deletion ntt/include/nncase/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <codecvt>
#include <cstdint>
#include <float.h>
#include "bfloat16.h"
#include <functional>
#include <limits>

Expand Down Expand Up @@ -64,7 +65,14 @@ struct half {

constexpr half(fp16_from_raw_t, uint16_t value) noexcept : value_(value) {}

operator float() const noexcept {
operator _Float16() const noexcept{
return static_cast<_Float16>(float(*this));
}

operator bfloat16() const noexcept {
return bfloat16::round_to_bfloat16(float(*this));
}
explicit operator float() const noexcept {
const fp32 magic = {113 << 23};
const unsigned int shifted_exp = 0x7c00
<< 13; // exponent mask after shift
Expand Down Expand Up @@ -141,6 +149,7 @@ struct half {
return o;
}


static constexpr half epsilon() noexcept { return from_raw(0x0800); }

static constexpr half highest() noexcept { return from_raw(0x7bff); }
Expand All @@ -167,6 +176,9 @@ struct half {
uint16_t value_;
};




#define DEFINE_FP16_BINARY_FP16RET(x) \
inline half operator x(half a, half b) noexcept { \
return half::round_to_half(float(a) x float(b)); \
Expand All @@ -177,6 +189,14 @@ struct half {
return float(a) x float(b); \
}

#define DEFINE_FP16_FP32_BINARY_BOOLRET(x) \
inline bool operator x(half a, float b) noexcept { \
return float(a) x b; \
}

DEFINE_FP16_FP32_BINARY_BOOLRET(<)


DEFINE_FP16_BINARY_FP16RET(+)
DEFINE_FP16_BINARY_FP16RET(-)
DEFINE_FP16_BINARY_FP16RET(*)
Expand Down Expand Up @@ -208,6 +228,11 @@ inline bool operator==(const half &lhs, const half &rhs) noexcept {
inline bool operator!=(const half &lhs, const half &rhs) noexcept {
return lhs.raw() != rhs.raw();
}

inline std::ostream& operator<<(std::ostream& os, const half& a){
os << std::to_string(float(a));
return os;
}
} // namespace nncase

namespace std {
Expand Down Expand Up @@ -276,12 +301,16 @@ using nncase::half;
inline bool isinf(const half &a) { return std::isinf(float(a)); }
inline bool isnan(const half &a) { return std::isnan(float(a)); }
inline bool isfinite(const half &a) { return std::isfinite(float(a)); }
inline half fabs(const half &a) { return half::round_to_half(fabs(float(a))); }
inline half abs(const half &a) { return half::round_to_half(fabsf(float(a))); }
inline half exp(const half &a) { return half::round_to_half(expf(float(a))); }
inline half log(const half &a) { return half::round_to_half(logf(float(a))); }
inline half log10(const half &a) {
return half::round_to_half(log10f(float(a)));
}
inline half fmod(const half &a, const half &b) {
return half::round_to_half(fmod(float(a), float(b)));
}
inline half sqrt(const half &a) { return half::round_to_half(sqrtf(float(a))); }
inline half pow(const half &a, const half &b) {
return half::round_to_half(powf(float(a), float(b)));
Expand Down
31 changes: 23 additions & 8 deletions ntt/include/nncase/ntt/arch/riscv64/arch_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*/
#pragma once
#include "../../native_vector.h"

#include "../../../half.h"
#ifdef __riscv_vector
#include <riscv_vector.h>

Expand Down Expand Up @@ -62,6 +62,10 @@
__attribute__((riscv_rvv_vector_bits(NTT_VLEN / 2))); \
typedef vuint32mf2_t fixed_vuint32mf2_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN / 2))); \
typedef vfloat16mf2_t fixed_vfloat16mf2_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN / 2))); \
typedef vfloat16mf4_t fixed_vfloat16mf4_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN / 4))); \
typedef vfloat32mf2_t fixed_vfloat32mf2_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN / 2)));

Expand All @@ -84,6 +88,8 @@
__attribute__((riscv_rvv_vector_bits(NTT_VLEN * lmul))); \
typedef vuint64m##lmul##_t fixed_vuint64m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN * lmul))); \
typedef vfloat16m##lmul##_t fixed_vfloat16m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN * lmul))); \
typedef vfloat32m##lmul##_t fixed_vfloat32m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(NTT_VLEN * lmul))); \
typedef vfloat64m##lmul##_t fixed_vfloat64m##lmul##_t \
Expand Down Expand Up @@ -142,6 +148,12 @@ REGISTER_RVV_FIXED_TYPE_WITH_LMUL_GE1(8)
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT( \
uint32_t, fixed_vuint32mf2_t, NTT_VLEN / 8 / sizeof(uint32_t) / 2) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(half, fixed_vfloat16mf2_t, \
NTT_VLEN / 8 / sizeof(half) / 2) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(half, fixed_vfloat16mf4_t, \
NTT_VLEN / 8 / sizeof(half) / 4) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(float, fixed_vfloat32mf2_t, \
NTT_VLEN / 8 / sizeof(float) / 2) \
NTT_END_DEFINE_NATIVE_VECTOR()
Expand Down Expand Up @@ -184,14 +196,17 @@ REGISTER_RVV_FIXED_TYPE_WITH_LMUL_GE1(8)
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT( \
float, fixed_vfloat32m##lmul##_t, NTT_VLEN / 8 / sizeof(float) * lmul) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT( \
half, fixed_vfloat16m##lmul##_t, NTT_VLEN / 8 / sizeof(half) * lmul) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(double, fixed_vfloat64m##lmul##_t, \
NTT_VLEN / 8 / sizeof(double) * \
lmul) \
NTT_END_DEFINE_NATIVE_VECTOR()
NTT_END_DEFINE_NATIVE_VECTOR()

NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_LT1
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(1)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(2)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(4)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(8)
#endif
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_LT1
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(1)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(2)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(4)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL_GE1(8)
#endif
55 changes: 55 additions & 0 deletions ntt/include/nncase/ntt/arch/riscv64/fp16_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@


#pragma once
#include "nncase/ntt/arch/riscv64/arch_types.h"
#include "nncase/ntt/vector.h"
#include "../../../half.h"
#include "rvv_mathfun.h"
#ifdef __riscv_vector
#include <riscv_vector.h>
#endif


namespace nncase::ntt::ops{

#ifdef __riscv_vector


#define RVV_UNARY16_OP(op, dtype, vl, kernel) \
template <> struct op<ntt::vector<dtype, vl>> { \
ntt::vector<dtype, vl> \
operator()(const ntt::vector<dtype, vl> &v) const noexcept { \
return kernel(v, vl); \
} \
};

// unary with hlaf
#define REGISTER_RVV_UNARY16_OP(OP, dtype, kernel) \
RVV_UNARY16_OP(OP, half, NTT_VL(sizeof(dtype) * 8, *, 1), kernel) \
RVV_UNARY16_OP(OP, half, NTT_VL(sizeof(dtype) * 8, *, 2), kernel) \
RVV_UNARY16_OP(OP, half, NTT_VL(sizeof(dtype) * 8, *, 4), kernel) \
RVV_UNARY16_OP(OP, half, NTT_VL(sizeof(dtype) * 8, *, 8), kernel)

#define ABS_FLOAT16(lmul, mlen) \
inline vfloat16m##lmul##_t abs_float16(const vfloat16m##lmul##_t &v, \
const size_t vl) { \
return __riscv_vfabs_v_f16m##lmul(v, vl); \
}

REGISTER_RVV_KERNEL(ABS_FLOAT16)
REGISTER_RVV_UNARY16_OP(abs, half, abs_float16)

// acos
#if 0
#define ACOS_FLOAT16(lmul, mlen) \
inline vfloat16m##lmul##_t acos_float16(const vfloat16m##lmul##_t &v , const size_t vl) { \
auto x = __riscv_vabs_v_f16m##lmul(v,vl);
auto c1 = __riscv_vfmv_v_f_f16m#lmul(-0.0187293f16, vl);
}
#else

#endif
//end acos

#endif
}
Loading
Loading