Skip to content

Commit

Permalink
[CPU][RV64] Implemented table
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 10, 2025
1 parent e203f66 commit d9e453c
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace riscv64 {

using namespace Xbyak_riscv;

#define CONST_1_F 0x3f800000 // 1.f

namespace {
ov::element::Type get_arithmetic_binary_exec_precision(const std::shared_ptr<ov::Node>& n) {
std::vector<ov::element::Type> input_precisions;
Expand Down Expand Up @@ -63,10 +65,13 @@ jit_clamp_emitter::jit_clamp_emitter(ov::intel_cpu::riscv64::jit_generator* host
} else {
OPENVINO_THROW("Incompatible node!");
}
prepare_table();
}

jit_clamp_emitter::jit_clamp_emitter(ov::intel_cpu::riscv64::jit_generator* host, float min, float max, const ov::element::Type exec_prc)
: jit_emitter(host, exec_prc), min(min), max(max) {}
: jit_emitter(host, exec_prc), min(min), max(max) {
prepare_table();
}

size_t jit_clamp_emitter::get_inputs_num() const {
return 1;
Expand All @@ -79,28 +84,28 @@ size_t jit_clamp_emitter::aux_fp_gprs_count() const {
void jit_clamp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
VReg src = VReg(in_vec_idxs[0]);
VReg dst = VReg(out_vec_idxs[0]);
FReg fmin = FReg(aux_fp_gpr_idxs[0]), fmax = FReg(aux_fp_gpr_idxs[0]);
FReg bound = FReg(aux_fp_gpr_idxs[0]);

h->flw(fmin, p_table, 0);
h->vfmax_vf(dst, src, fmin);
load_table_val("min", bound);
h->vfmax_vf(dst, src, bound);

h->flw(fmax, p_table, sizeof(float));
h->vfmin_vf(dst, dst, fmin);
load_table_val("max", bound);
h->vfmin_vf(dst, dst, bound);
}

std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
void jit_clamp_emitter::register_table_entries() {
push_arg_entry_of("min", dnnl::impl::float2int(min));
push_arg_entry_of("max", dnnl::impl::float2int(max));
}

bool jit_clamp_emitter::need_table() const {
return true;
const jit_clamp_emitter::table_entry_val_t* jit_clamp_emitter::get_table() const {
static uint32_t tbl[2];
fill_table(tbl, 2);
return tbl;
}

const void* jit_clamp_emitter::get_table() const {
static float tbl[2];
tbl[0] = min; // use explicit assignment to change dynamically array in runtime
tbl[1] = max;
return tbl;
std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

/// DIV ///
Expand Down Expand Up @@ -128,10 +133,14 @@ std::set<std::vector<element::Type>> jit_div_emitter::get_supported_precisions(c

/// Exp ///
jit_exp_emitter::jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, get_arithmetic_binary_exec_precision(node)) {}
: jit_emitter(host, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
}

jit_exp_emitter::jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const ov::element::Type exec_prc)
: jit_emitter(host, exec_prc) {}
: jit_emitter(host, exec_prc) {
prepare_table();
}

size_t jit_exp_emitter::get_inputs_num() const {
return 1;
Expand All @@ -155,51 +164,48 @@ void jit_exp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const st
VReg aux0 = VReg(aux_vec_idxs[0]);
VReg aux1 = VReg(aux_vec_idxs[1]);
VReg aux2 = VReg(aux_vec_idxs[2]);
FReg fp0 = FReg(aux_fp_gpr_idxs[0]);
FReg fp1 = FReg(aux_fp_gpr_idxs[1]);
Reg tmp = Reg(aux_gpr_idxs[0]);

// save src
h->vmv_v_v(aux2, src);

// get mask of values lower than log(FLT_MIN) to zero them in the output
FReg ln_flt_min_f = FReg(aux_fp_gpr_idxs[0]);
h->flw(ln_flt_min_f, p_table, 10 * sizeof(uint32_t));
FReg ln_flt_min_f = fp0;
load_table_val("ln_flt_min_f", ln_flt_min_f);
h->vfmax_vf(dst, src, ln_flt_min_f);

FReg ln_flt_max_f = FReg(aux_fp_gpr_idxs[1]);
h->flw(ln_flt_max_f, p_table, 9 * sizeof(uint32_t));
h->vfmin_vf(dst, dst, ln_flt_max_f);
load_table_val("ln_flt_max_f", fp1);
h->vfmin_vf(dst, dst, fp1);

// keep dst = x for further computations
h->vmv_v_v(aux0, dst);

// calculate exp(x)
// fx = x * log2ef + 0.5
FReg log2ef = FReg(aux_fp_gpr_idxs[1]);
h->flw(log2ef, p_table, 8 * sizeof(uint32_t));
h->vfmul_vf(dst, dst, log2ef);
FReg half = FReg(aux_fp_gpr_idxs[1]);
h->flw(half, p_table, 6 * sizeof(uint32_t));
h->vfadd_vf(dst, dst, log2ef);
load_table_val("log2ef", fp1);
h->vfmul_vf(dst, dst, fp1);
load_table_val("half", fp1);
h->vfadd_vf(dst, dst, fp1);

// aux1 = floorf(fx)
h->vfcvt_x_f_v(aux1, dst); // fp32 -> int32
h->vfcvt_f_x_v(aux1, aux1); // int32 -> fp32
h->vmfgt_vv(mask_vreg(), aux1, dst);
FReg one = FReg(aux_fp_gpr_idxs[1]);
h->flw(one, p_table, 5 * sizeof(uint32_t)); // one
h->vfsub_vf(aux1, aux1, one, VM::masked);
load_table_val("one", fp1);
h->vfsub_vf(aux1, aux1, fp1, VM::masked);

// keep dst = floorf(fx) for further computations
h->vmv_v_v(dst, aux1);

// x = x - fx * ln2
FReg ln2 = FReg(aux_fp_gpr_idxs[1]);
h->flw(ln2, p_table, 7 * sizeof(uint32_t));
h->vfnmsac_vf(aux0, ln2, aux1);
load_table_val("ln2f", fp1);
h->vfnmsac_vf(aux0, fp1, aux1);

// compute 2^n
Reg tmp = Reg(aux_gpr_idxs[0]);
h->vfcvt_x_f_v(aux1, dst);
h->lw(tmp, p_table, 11 * sizeof(uint32_t)); // exponent_bias
load_table_val("exponent_bias", tmp);
h->vadd_vx(aux1, aux1, tmp);
const int n_mantissa_bits = 23;
h->vsll_vi(aux1, aux1, n_mantissa_bits);
Expand All @@ -210,27 +216,27 @@ void jit_exp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const st
h->vand_vx(aux1, aux1, zero, VM::masked);

// compute polynomial
FReg pol = FReg(aux_fp_gpr_idxs[1]);
h->flw(pol, p_table, 4 * sizeof(uint32_t)); // pol5
FReg pol = fp1;
load_table_val("pol5", pol);
h->vfmv_v_f(dst, pol);

h->flw(pol, p_table, 3 * sizeof(uint32_t)); // pol4
load_table_val("pol4", pol);
h->vfmv_v_f(aux2, pol);
h->vfmadd_vv(dst, aux0, aux2);

h->flw(pol, p_table, 2 * sizeof(uint32_t)); // pol3
load_table_val("pol3", pol);
h->vfmv_v_f(aux2, pol);
h->vfmadd_vv(dst, aux0, aux2);

h->flw(pol, p_table, 1 * sizeof(uint32_t)); // pol2
load_table_val("pol2", pol);
h->vfmv_v_f(aux2, pol);
h->vfmadd_vv(dst, aux0, aux2);

h->flw(pol, p_table, 0 * sizeof(uint32_t)); // pol1
load_table_val("pol1", pol);
h->vfmv_v_f(aux2, pol);
h->vfmadd_vv(dst, aux0, aux2);

h->flw(pol, p_table, 5 * sizeof(uint32_t)); // one
load_table_val("one", pol);
h->vfmv_v_f(aux2, pol);
h->vfmadd_vv(dst, aux0, aux2);

Expand All @@ -242,25 +248,25 @@ std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(c
return {{element::f32}};
}

bool jit_exp_emitter::need_table() const {
return true;
}

const void* jit_exp_emitter::get_table() const {
static uint32_t tbl[12] = {
0x3f7ffffb, // pol1
0x3efffee3, // pol2
0x3e2aad40, // pol3
0x3d2b9d0d, // pol4
0x3c07cfce, // pol5
0x3f800000, // one
0x3f000000, // 0.5f
0x3f317218, // ln2f
0x3fb8aa3b, // log2ef
0x42b17218, // ln_flt_max_f
0xc2aeac50, // ln_flt_min_f
0x0000007f // exponent_bias
};
void jit_exp_emitter::register_table_entries() {
push_arg_entry_of("pol1", 0x3f7ffffb); // p1 = 0.999999701f
push_arg_entry_of("pol2", 0x3efffee3); // p2 = 0.499991506f
push_arg_entry_of("pol3", 0x3e2aad40); // p3 = 0.166676521f
push_arg_entry_of("pol4", 0x3d2b9d0d); // p4 = 0.0418978221f
push_arg_entry_of("pol5", 0x3c07cfce); // p5 = 0.00828929059f

push_arg_entry_of("one", CONST_1_F);
push_arg_entry_of("half", 0x3f000000);
push_arg_entry_of("ln2f", 0x3f317218);
push_arg_entry_of("log2ef", 0x3fb8aa3b);
push_arg_entry_of("ln_flt_max_f", 0x42b17218);
push_arg_entry_of("ln_flt_min_f", 0xc2aeac50);
push_arg_entry_of("exponent_bias", 0x0000007f);
}

const jit_exp_emitter::table_entry_val_t* jit_exp_emitter::get_table() const {
static uint32_t tbl[12];
fill_table(tbl, 12);
return tbl;
}

Expand Down Expand Up @@ -361,21 +367,22 @@ void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const s
h->vmflt_vf(mask_vreg(), dst, fzero);

FReg alpha_reg = fzero;
h->flw(alpha_reg, p_table);
load_table_val("alpha", alpha_reg);
h->vfmul_vf(dst, dst, alpha_reg, VM::masked);
}

std::set<std::vector<element::Type>> jit_relu_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

bool jit_relu_emitter::need_table() const {
return alpha != 0;
void jit_relu_emitter::register_table_entries() {
if (alpha != 0)
push_arg_entry_of("alpha", dnnl::impl::float2int(alpha));
}

const void* jit_relu_emitter::get_table() const {
static float tbl[1];
tbl[0] = alpha; // use explicit assignment to change dynamically array in runtime
const jit_relu_emitter::table_entry_val_t* jit_relu_emitter::get_table() const {
static uint32_t tbl[1];
fill_table(tbl, 1);
return tbl;
}

Expand All @@ -402,6 +409,8 @@ std::set<std::vector<element::Type>> jit_sub_emitter::get_supported_precisions(c
return {{element::f32, element::f32}};
}

#undef CONST_1_F

} // namespace riscv64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class jit_clamp_emitter : public jit_emitter {

private:
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
bool need_table() const override;
const void* get_table() const override;
void register_table_entries() override;
const table_entry_val_t* get_table() const override;

float min = 0.f;
float max = 0.f;
Expand Down Expand Up @@ -73,8 +73,8 @@ class jit_exp_emitter : public jit_emitter {

private:
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
bool need_table() const override;
const void* get_table() const override;
void register_table_entries() override;
const table_entry_val_t* get_table() const override;
};

class jit_mul_emitter : public jit_emitter {
Expand Down Expand Up @@ -119,8 +119,9 @@ class jit_relu_emitter : public jit_emitter {

private:
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
bool need_table() const override;
const void* get_table() const override;

void register_table_entries() override;
const table_entry_val_t* get_table() const override;

float alpha = 0.f;
};
Expand Down
33 changes: 29 additions & 4 deletions src/plugins/intel_cpu/src/emitters/plugin/riscv64/jit_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ size_t jit_emitter::aux_vecs_count() const {

size_t jit_emitter::aux_gprs_count() const {
// We need one gpr to load table address
return need_table() ? 1 : 0;
return entry_map_.empty() ? 0 : 1;
}

size_t jit_emitter::aux_fp_gprs_count() const {
Expand Down Expand Up @@ -169,16 +169,16 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
}
OPENVINO_ASSERT(aux_gprs_count() <= aux_gpr_idxs.size(), "Failed to allocate required number of general-purpose registers");

if (need_table()) {
if (!entry_map_.empty()) {
// last aux_gpr_idx is for p_table, we can use aux_gpr_idxs from idx 0 for other purpose
p_table = Reg(aux_gpr_idxs[aux_gprs_count() - 1]);
aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1);
}

store_context(preserved_gpr_idxs, preserved_fp_gpr_idxs, preserved_vec_idxs);

if (need_table()) {
load_table_addr(get_table());
if (!entry_map_.empty()) {
load_table_addr();
}
}

Expand Down Expand Up @@ -280,6 +280,31 @@ void jit_emitter::restore_context(const std::vector<size_t>& gpr_regs,
}
}

void jit_emitter::prepare_table() {
register_table_entries();

// Now that we registered the entries, we set the offsets. No
// entries should be registered after this point. This allows to
// expect the same order when injecting the table entries in
// prepare_table.
size_t off = 0;
for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) {
auto& te = (*it).second;
te.off = off;
off += sizeof(table_entry_val_t);
}
}

void jit_emitter::fill_table(table_entry_val_t* tbl, size_t size) const {
OPENVINO_ASSERT(entry_map_.size() == size, "Incorrect table size for filling");
for (const auto& te : entry_map_) {
const auto& elem = te.second;
const auto idx = elem.off / sizeof(table_entry_val_t);
OPENVINO_ASSERT(idx < size, "Incorrect index of table");
tbl[idx] = elem.val;
}
}

} // namespace riscv64
} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit d9e453c

Please sign in to comment.