Skip to content

Commit

Permalink
[IR] rectify the verify api (#54895)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Jun 27, 2023
1 parent e49c17d commit 9665226
Show file tree
Hide file tree
Showing 18 changed files with 427 additions and 389 deletions.
4 changes: 1 addition & 3 deletions paddle/fluid/ir/dialect/kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ namespace dialect {
const char *PhiKernelOp::attributes_name[attributes_num] = {
"base_op", "infermeta_fn", "kernel_fn"};

void PhiKernelOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void PhiKernelOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";

// Verify inputs type:
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/ir/dialect/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
static const char *name() { return "phi.kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};

} // namespace dialect
Expand Down
247 changes: 16 additions & 231 deletions paddle/fluid/ir/dialect/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os

import yaml
from op_verify_gen import gen_verify_func_str

# =====================================
# String Template for h file code gen
Expand Down Expand Up @@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static OpInfoTuple GetOpInfo();
static void Build({build_args});
{build_mutable_attr_is_input}
static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
void Verify();
{get_inputs_and_outputs}
{exclusive_interface}
}};
Expand Down Expand Up @@ -141,105 +142,6 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
{build_outputs}
}}
"""

# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
// Verify inputs type:
PADDLE_ENFORCE_EQ(inputs.size(), {inputs_size},
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", inputs.size()));
{inputs_type_check}
// Verify outputs type:
PADDLE_ENFORCE_EQ(outputs.size(), {outputs_size},
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", outputs.size()));
{outputs_type_check}
// Verify if attributes contain attribute name in attributes_name:
{attributes_check}
}}
"""

GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""

INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
"""

OUTPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
"""

ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}
"""
OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
Expand Down Expand Up @@ -1004,8 +906,8 @@ def GenBuildOutputs(
}}
"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""

CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
Expand Down Expand Up @@ -1557,135 +1459,18 @@ def OpGenerator(
view=view_str,
)

# =================================== #
# gen Verify func str #
# =================================== #
# generate op verify function: inputs_type_check_str
if (
len(op_input_type_list) + len(op_mutable_attribute_name_list)
) == 0:
inputs_type_check_str = (
"// Inputs num is 0, not need to check inputs type."
)
else:
inputs_type_check_str = ""
for idx in range(len(op_input_type_list)):
input_type = op_input_type_list[idx]
is_optional = op_input_optional_list[idx]
is_vector = False
if input_type.startswith("ir::VectorType<"):
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
if is_vector:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
inputs_type_check_str += check_str

for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
# generate op verify function: outputs_type_check_str
if len(op_output_type_list) == 0:
outputs_type_check_str = (
"// Outputs num is 0, not need to check outputs type."
)
else:
outputs_type_check_str = ""
for idx in range(len(op_output_type_list)):
output_type = op_output_type_list[idx]
is_optional = op_output_optional_list[idx]
is_vector = False
if output_type.startswith("ir::VectorType<"):
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
)
else:
check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
if is_vector:
check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
# generate op verify function: attributes_check_str
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = (
"// Attributes num is 0, not need to check attributes type."
)
else:
attributes_check_str = ""
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += (
ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
)
else:
op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str,
attributes_check=attributes_check_str,
)
# generate op verify function str
op_verify_str = gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
)

op_infer_shape_str = ""
if op_info.infer_shape_func:
Expand Down
Loading

0 comments on commit 9665226

Please sign in to comment.