Skip to content

Commit

Permalink
[CPU] Fixed Reduce kernel with bf16 destination precision (#28732)
Browse files: browse the repository at this point in the history
### Details:
 - Cherry-picks: #28731
  • Loading branch information
dmitry-gorokhov authored Jan 31, 2025
1 parent 7e8c945 commit c8d66a7
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions src/plugins/intel_cpu/src/nodes/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
data_type::f32);
}

- if (mayiuse(avx512_core)) {
- uni_vcvtneps2bf16 = std::make_shared<jit_uni_vcvtneps2bf16>(this, isa);
- }
+ uni_vcvtneps2bf16 = std::make_shared<jit_uni_vcvtneps2bf16>(this, isa);

this->preamble();

Expand Down Expand Up @@ -188,9 +186,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene

this->postamble();

- if (mayiuse(avx512_core)) {
- uni_vcvtneps2bf16->emit_data();
- }
+ uni_vcvtneps2bf16->emit_data();

if (jcp_.reduce_mode == Algorithm::ReduceAnd || jcp_.reduce_mode == Algorithm::ReduceL1 ||
jcp_.reduce_mode == Algorithm::ReduceMax || jcp_.reduce_mode == Algorithm::ReduceMin ||
Expand Down Expand Up @@ -1017,9 +1013,15 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
uni_vmovups(op, vmm_dst);
break;
case memory::data_type::bf16:
- uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
- {static_cast<size_t>(ymm_dst.getIdx())});
- vmovdqu16(op, ymm_dst);
+ if (isa == cpu::x64::avx512_core) {
+ uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
+ {static_cast<size_t>(ymm_dst.getIdx())});
+ vmovdqu16(op, ymm_dst);
+ } else {
+ uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
+ {static_cast<size_t>(xmm_dst.getIdx())});
+ uni_vmovdqu(op, xmm_dst);
+ }
break;
case memory::data_type::f16:
vcvtps2ph(op, vmm_dst, 0x4);
Expand Down Expand Up @@ -1253,9 +1255,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
data_type::f32);
}

- if (mayiuse(avx512_core)) {
- uni_vcvtneps2bf16 = std::make_shared<jit_uni_vcvtneps2bf16>(this, isa);
- }
+ uni_vcvtneps2bf16 = std::make_shared<jit_uni_vcvtneps2bf16>(this, isa);

this->preamble();

Expand Down Expand Up @@ -1312,9 +1312,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi

this->postamble();

- if (mayiuse(avx512_core)) {
- uni_vcvtneps2bf16->emit_data();
- }
+ uni_vcvtneps2bf16->emit_data();

if (jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) {
log_injector->prepare_table();
Expand Down Expand Up @@ -1770,9 +1768,15 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
uni_vmovups(op, vmm_dst);
break;
case memory::data_type::bf16:
- uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
- {static_cast<size_t>(ymm_dst.getIdx())});
- vmovdqu16(op, ymm_dst);
+ if (isa == cpu::x64::avx512_core) {
+ uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
+ {static_cast<size_t>(ymm_dst.getIdx())});
+ vmovdqu16(op, ymm_dst);
+ } else {
+ uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
+ {static_cast<size_t>(xmm_dst.getIdx())});
+ uni_vmovdqu(op, xmm_dst);
+ }
break;
case memory::data_type::f16:
vcvtps2ph(op, vmm_dst, 0x4);
Expand Down

0 comments on commit c8d66a7

Please sign in to comment.