Skip to content

Commit

Permalink
brgemm -> gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Jan 29, 2025
1 parent a8eaf9b commit c954669
Show file tree
Hide file tree
Showing 24 changed files with 70 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "utils.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <cpu/x64/amx_tile_configure.hpp>

#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define INNER_K_BLK(dtype) static_cast<dnnl_dim_t>((brgemm_utils::repacking::compute_inner_k_block(in0_dtype)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

using namespace Xbyak;
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include "transformations/cpu_opset/x64/op/mha.hpp"
#include "transformations/cpu_opset/x64/op/qkv_proj.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#else
# include "emitters/snippets/x64/cpu_generator.hpp"
# include "executors/x64/subgraph.hpp"
# include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
# include "transformations/snippets/x64/pass/brgemm_to_gemm_cpu.h"
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
# include "transformations/snippets/x64/pass/enforce_precision.hpp"
# include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "snippets/lowered/expressions/buffer_expression.hpp"
#include "snippets/op/buffer.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.h"
#include "utils/general_utils.h"

using namespace Xbyak;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_cpu.hpp"
#include "gemm_cpu.hpp"

#include "snippets/itt.hpp"
#include "snippets/lowered/port_descriptor.hpp"
Expand All @@ -14,7 +14,7 @@ namespace ov {
namespace intel_cpu {
using namespace brgemm_utils;

BrgemmCPU::BrgemmCPU(const Output<Node>& A,
GemmCPU::GemmCPU(const Output<Node>& A,
const Output<Node>& B,
BRGEMM_TYPE type,
const size_t offset_a,
Expand All @@ -35,7 +35,7 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A,
custom_constructor_validate_and_infer_types(layout_a, layout_b, layout_c);
}

BrgemmCPU::BrgemmCPU(const Output<Node>& A,
GemmCPU::GemmCPU(const Output<Node>& A,
const Output<Node>& B,
const Output<Node>& scratch,
BRGEMM_TYPE type,
Expand All @@ -58,7 +58,7 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A,
custom_constructor_validate_and_infer_types(layout_a, layout_b, layout_c);
}

BrgemmCPU::BrgemmCPU(const Output<Node>& A,
GemmCPU::GemmCPU(const Output<Node>& A,
const Output<Node>& B,
BRGEMM_TYPE type,
const PortDescriptor& desc_a,
Expand All @@ -76,7 +76,7 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A,
custom_constructor_validate_and_infer_types(layout_a, layout_b, layout_c);
}

BrgemmCPU::BrgemmCPU(const Output<Node>& A,
GemmCPU::GemmCPU(const Output<Node>& A,
const Output<Node>& B,
const Output<Node>& scratch,
BRGEMM_TYPE type,
Expand All @@ -96,10 +96,10 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A,
custom_constructor_validate_and_infer_types(layout_a, layout_b, layout_c);
}

void BrgemmCPU::custom_constructor_validate_and_infer_types(const std::vector<size_t>& layout_a,
void GemmCPU::custom_constructor_validate_and_infer_types(const std::vector<size_t>& layout_a,
const std::vector<size_t>& layout_b,
const std::vector<size_t>& layout_c) {
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
INTERNAL_OP_SCOPE(GemmCPU_constructor_validate_and_infer_types);
validate_inputs();

const std::vector<ov::PartialShape> planar_input_shapes{
Expand All @@ -112,8 +112,8 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(const std::vector<si
validate_with_scratchpad();
}

void BrgemmCPU::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types);
void GemmCPU::validate_and_infer_types() {
INTERNAL_OP_SCOPE(GemmCPU_validate_and_infer_types);
validate_inputs();

const auto planar_input_shapes = get_planar_input_shapes({input(0), input(1)});
Expand All @@ -124,7 +124,7 @@ void BrgemmCPU::validate_and_infer_types() {
validate_with_scratchpad();
}

void BrgemmCPU::validate_with_scratchpad() const {
void GemmCPU::validate_with_scratchpad() const {
// Additional check for 3rd input
if (with_compensations(m_type)) {
OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32,
Expand All @@ -135,21 +135,21 @@ void BrgemmCPU::validate_with_scratchpad() const {
}
}

void BrgemmCPU::validate_inputs() const {
void GemmCPU::validate_inputs() const {
OPENVINO_ASSERT(
implication(one_of(m_type, BRGEMM_TYPE::STAND_ALONE, BRGEMM_TYPE::REPACKING_ONLY), get_input_size() == 2),
"BrgemmCPU expects 2 inputs in cases, when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)");
"GemmCPU expects 2 inputs in cases, when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)");
OPENVINO_ASSERT(
implication(one_of(m_type, BRGEMM_TYPE::WITH_COMPENSATIONS, BRGEMM_TYPE::WITH_AMX), get_input_size() == 3),
"BrgemmCPU expects 3 inputs with input precisions i8|i8 and bf16|bf16 on AMX system");
"GemmCPU expects 3 inputs with input precisions i8|i8 and bf16|bf16 on AMX system");
}

std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(BrgemmCPU_clone_with_new_inputs);
std::shared_ptr<Node> GemmCPU::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(GemmCPU_clone_with_new_inputs);
check_new_args_count(this, new_args);
std::shared_ptr<BrgemmCPU> brgemm;
std::shared_ptr<GemmCPU> brgemm;
if (!with_scratchpad(m_type)) {
return std::make_shared<BrgemmCPU>(
return std::make_shared<GemmCPU>(
new_args.at(0),
new_args.at(1),
m_type,
Expand All @@ -160,7 +160,7 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
} else {
return std::make_shared<BrgemmCPU>(
return std::make_shared<GemmCPU>(
new_args.at(0),
new_args.at(1),
new_args.at(2),
Expand All @@ -175,13 +175,13 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a
}
}

size_t BrgemmCPU::get_offset_scratch() const {
size_t GemmCPU::get_offset_scratch() const {
OPENVINO_ASSERT(with_scratchpad(m_type) && get_input_size() == 3,
"Offset of scratchpad must be only in Brgemm with scratchpad on 3rd input");
return get_input_offset(2);
}

bool BrgemmCPU::visit_attributes(AttributeVisitor& visitor) {
bool GemmCPU::visit_attributes(AttributeVisitor& visitor) {
Brgemm::visit_attributes(visitor);
visitor.on_attribute("type", m_type);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ namespace ov {
namespace intel_cpu {

/**
* @interface BrgemmCPU
* @brief BrgemmCPU is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
* @interface GemmCPU
* @brief GemmCPU is a batch-reduced matrix multiplication that supports arbitrary strides between matrix rows
* and several precisions at the plugin level
* @ingroup snippets
*/
class BrgemmCPU : public snippets::op::Brgemm {
class GemmCPU : public snippets::op::Brgemm {
public:
using BRGEMM_TYPE = brgemm_utils::BRGEMM_TYPE;
OPENVINO_OP("BrgemmCPU", "SnippetsOpset", snippets::op::Brgemm);
OPENVINO_OP("GemmCPU", "SnippetsOpset", snippets::op::Brgemm);

BrgemmCPU(const Output<Node>& A,
GemmCPU(const Output<Node>& A,
const Output<Node>& B,
BRGEMM_TYPE type,
const size_t offset_a = 0,
Expand All @@ -32,7 +32,7 @@ class BrgemmCPU : public snippets::op::Brgemm {
const std::vector<size_t>& layout_a = {},
const std::vector<size_t>& layout_b = {},
const std::vector<size_t>& layout_c = {});
BrgemmCPU(const Output<Node>& A,
GemmCPU(const Output<Node>& A,
const Output<Node>& B,
const Output<Node>& scratch,
BRGEMM_TYPE type,
Expand All @@ -43,7 +43,7 @@ class BrgemmCPU : public snippets::op::Brgemm {
const std::vector<size_t>& layout_a = {},
const std::vector<size_t>& layout_b = {},
const std::vector<size_t>& layout_c = {});
BrgemmCPU(const Output<Node>& A,
GemmCPU(const Output<Node>& A,
const Output<Node>& B,
BRGEMM_TYPE type,
const PortDescriptor& desc_a,
Expand All @@ -52,7 +52,7 @@ class BrgemmCPU : public snippets::op::Brgemm {
const std::vector<size_t>& layout_a = {},
const std::vector<size_t>& layout_b = {},
const std::vector<size_t>& layout_c = {});
BrgemmCPU(const Output<Node>& A,
GemmCPU(const Output<Node>& A,
const Output<Node>& B,
const Output<Node>& scratch,
BRGEMM_TYPE type,
Expand All @@ -63,7 +63,7 @@ class BrgemmCPU : public snippets::op::Brgemm {
const std::vector<size_t>& layout_a = {},
const std::vector<size_t>& layout_b = {},
const std::vector<size_t>& layout_c = {});
BrgemmCPU() = default;
GemmCPU() = default;

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_to_brgemm_cpu.hpp"
#include "brgemm_to_gemm_cpu.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu_shape.h"
Expand All @@ -14,7 +14,7 @@
#include "snippets/op/buffer.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/modifiers.hpp"
#include "utils/general_utils.h"
Expand All @@ -34,21 +34,21 @@ void set_full_port_desc(const T& port) {
}
} // namespace

pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
MATCHER_SCOPE(BrgemmToBrgemmCPU);
pass::BrgemmToGemmCPU::BrgemmToGemmCPU() {
MATCHER_SCOPE(BrgemmToGemmCPU);
auto is_not_tpp = [](const Output<Node>& out) {
return !std::dynamic_pointer_cast<const intel_cpu::tpp::modifier::TensorProcessingPrimitive>(
out.get_node_shared_ptr());
};
auto m_brgemm = ov::pass::pattern::wrap_type<snippets::op::Brgemm>(is_not_tpp);

auto callback = [=](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToBrgemmCPU")
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToGemmCPU")
const auto node = m.get_match_root();
const auto brgemm = ov::as_type_ptr<snippets::op::Brgemm>(node);
const auto brgemm_plugin = ov::as_type_ptr<BrgemmCPU>(node);
const auto brgemm_plugin = ov::as_type_ptr<GemmCPU>(node);
if (!brgemm || brgemm_plugin) {
OPENVINO_THROW("BrgemmCPU cannot be in body before BrgemmToBrgemmCPU pass");
OPENVINO_THROW("GemmCPU cannot be in body before BrgemmToGemmCPU pass");
}

const auto& brgemm_in0_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(0));
Expand All @@ -69,10 +69,10 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
const auto offset_b = brgemm->get_offset_b();
const auto offset_c = brgemm->get_offset_c();

std::shared_ptr<BrgemmCPU> brgemm_cpu = nullptr;
std::shared_ptr<GemmCPU> brgemm_cpu = nullptr;
std::shared_ptr<BrgemmCopyB> brgemm_repacking = nullptr;
if (stand_alone(brgemm_type)) {
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0),
brgemm_cpu = std::make_shared<GemmCPU>(brgemm->input_value(0),
brgemm->input_value(1),
brgemm_type,
offset_a,
Expand All @@ -99,8 +99,8 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
}

if (with_amx(brgemm_type)) {
const auto scratch = std::make_shared<snippets::op::Buffer>(ov::Shape{BrgemmCPU::SCRATCH_BYTE_SIZE});
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0),
const auto scratch = std::make_shared<snippets::op::Buffer>(ov::Shape{GemmCPU::SCRATCH_BYTE_SIZE});
brgemm_cpu = std::make_shared<GemmCPU>(brgemm->input_value(0),
brgemm_repacking->output(0),
scratch,
brgemm_type,
Expand All @@ -114,7 +114,7 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
set_full_port_desc(scratch->output(0));
set_full_port_desc(brgemm_cpu->input(2));
} else if (with_compensations(brgemm_type)) {
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0),
brgemm_cpu = std::make_shared<GemmCPU>(brgemm->input_value(0),
brgemm_repacking->output(0),
brgemm_repacking->output(1),
brgemm_type,
Expand All @@ -126,7 +126,7 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
std::vector<size_t>{},
layout_c);
} else if (repacking_only(brgemm_type)) {
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0),
brgemm_cpu = std::make_shared<GemmCPU>(brgemm->input_value(0),
brgemm_repacking->output(0),
brgemm_type,
offset_a,
Expand Down
Loading

0 comments on commit c954669

Please sign in to comment.