diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 16070fac12ae8..55791605ee49d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -61,20 +61,10 @@ void IfOp::Build(pir::Builder &builder, // NOLINT std::vector outs_stop_gradient; for (size_t i = 0; i < op.num_operands(); ++i) { argument.AddOutput(op.operand(i).type()); - bool input_stop_gradient = true; - auto *defining_op = op.operand_source(i).defining_op(); - if (defining_op && defining_op->HasAttribute(kStopGradientAttrName)) { - auto attrs = defining_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - input_stop_gradient = - attrs[op.operand_source(i).dyn_cast().index()] - .dyn_cast() - .data(); - } else { - input_stop_gradient = false; - } - outs_stop_gradient.push_back(builder.bool_attr(input_stop_gradient)); + auto bool_attr = op.operand_source(i).attribute( + kStopGradientAttrName); + outs_stop_gradient.push_back(bool_attr ? bool_attr + : builder.bool_attr(false)); } argument.AddAttribute( @@ -110,19 +100,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT argument.AddRegion().push_back(true_block.release()); argument.AddRegion().push_back(false_block.release()); argument.AddInput(cond); - - auto cond_op = cond.defining_op(); - if (cond_op && cond_op->HasAttribute(kStopGradientAttrName)) { - auto attrs = cond_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - attrs[cond.dyn_cast().index()] = - pir::BoolAttribute::get(pir::IrContext::Instance(), true); - - cond_op->set_attribute( - kStopGradientAttrName, - pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs)); - } + cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true)); } pir::Block &IfOp::true_block() { diff --git a/paddle/pir/core/block_argument.cc b/paddle/pir/core/block_argument.cc index 66a18964280d3..014537c867c36 100644 --- a/paddle/pir/core/block_argument.cc +++ b/paddle/pir/core/block_argument.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/pir/core/block_argument.h" #include "paddle/common/enforce.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation_utils.h" #include "paddle/pir/core/value_impl.h" #define CHECK_NULL_IMPL(func_name) \ @@ -32,6 +34,18 @@ class BlockArgumentImpl : public ValueImpl { return value.kind() == BLOCK_ARG_IDX; } + /// + /// \brief attribute related public interfaces + /// + Attribute attribute(const std::string &key) const { + auto iter = attributes_.find(key); + return iter == attributes_.end() ? nullptr : iter->second; + } + + void set_attribute(const std::string &key, Attribute value) { + attributes_[key] = value; + } + private: BlockArgumentImpl(Type type, Block *owner, uint32_t index) : ValueImpl(type, BLOCK_ARG_IDX), owner_(owner), index_(index) {} @@ -39,6 +53,8 @@ class BlockArgumentImpl : public ValueImpl { ~BlockArgumentImpl(); // access construction and owner friend BlockArgument; + + AttributeMap attributes_; Block *owner_; uint32_t index_; }; @@ -63,10 +79,18 @@ Block *BlockArgument::owner() const { } uint32_t BlockArgument::index() const { - CHECK_NULL_IMPL(arg_index); + CHECK_NULL_IMPL(index); return IMPL_->index_; } +Attribute BlockArgument::attribute(const std::string &key) const { + return impl_ ? IMPL_->attribute(key) : nullptr; +} +void BlockArgument::set_attribute(const std::string &key, Attribute value) { + CHECK_NULL_IMPL(set_attribute); + return IMPL_->set_attribute(key, value); +} + BlockArgument BlockArgument::Create(Type type, Block *owner, uint32_t index) { return new detail::BlockArgumentImpl(type, owner, index); } diff --git a/paddle/pir/core/block_argument.h b/paddle/pir/core/block_argument.h index 890e37234b131..3c2afd273f552 100644 --- a/paddle/pir/core/block_argument.h +++ b/paddle/pir/core/block_argument.h @@ -33,6 +33,9 @@ class IR_API BlockArgument : public Value { Block *owner() const; uint32_t index() const; + Attribute attribute(const std::string &key) const; + void set_attribute(const std::string &key, Attribute value); + private: /// constructor BlockArgument(detail::BlockArgumentImpl *impl); // NOLINT diff --git a/paddle/pir/core/builtin_attribute.cc b/paddle/pir/core/builtin_attribute.cc index 0b7138e027605..a817fb48c55fc 100644 --- a/paddle/pir/core/builtin_attribute.cc +++ b/paddle/pir/core/builtin_attribute.cc @@ -54,6 +54,9 @@ bool ArrayAttribute::empty() const { return storage()->empty(); } Attribute ArrayAttribute::at(size_t index) const { return storage()->at(index); } +Attribute ArrayAttribute::operator[](size_t index) const { + return storage()->operator[](index); +} ArrayAttribute ArrayAttribute::get(IrContext* ctx, const std::vector& value) { diff --git a/paddle/pir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h index 24efb529c7f62..a1751a8c248b8 100644 --- a/paddle/pir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -118,8 +118,12 @@ class IR_API ArrayAttribute : public Attribute { bool empty() const; + // Returns element at specified location pos, with bounds checking. Attribute at(size_t index) const; + // Returns element at specified location pos. No bounds checking is performed. + Attribute operator[](size_t index) const; + static ArrayAttribute get(IrContext* ctx, const std::vector& value); }; diff --git a/paddle/pir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h index c35d17e2544e6..533b0a4ad03e9 100644 --- a/paddle/pir/core/builtin_attribute_storage.h +++ b/paddle/pir/core/builtin_attribute_storage.h @@ -142,6 +142,7 @@ struct ArrayAttributeStorage : public AttributeStorage { size_); return data_[index]; } + Attribute operator[](size_t index) const { return data_[index]; } private: Attribute *data_; diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc index 9f80b7a93a419..59f0e3cd856cf 100644 --- a/paddle/pir/core/builtin_op.cc +++ b/paddle/pir/core/builtin_op.cc @@ -25,18 +25,8 @@ void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT VLOG(4) << "Builder construction stop gradient for OpResults."; bool stop_gradient = true; for (auto value : argument.inputs) { - auto input = value.dyn_cast(); - if (!input) continue; - auto *defining_op = input.owner(); - bool input_stop_gradient = true; - if (defining_op->HasAttribute(kStopGradientAttrName)) { - auto attrs = defining_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - input_stop_gradient = - attrs[input.index()].dyn_cast().data(); - } - if (!input_stop_gradient) { + auto attr = value.attribute(kStopGradientAttrName); + if (attr && !attr.data()) { stop_gradient = false; break; } @@ -52,18 +42,8 @@ void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT void RefreshStopGradientsDefaultly(Operation *op) { bool stop_gradient = true; for (auto value : op->operands_source()) { - auto input = value.dyn_cast(); - if (!input) continue; - auto *defining_op = input.owner(); - bool input_stop_gradient = true; - if (defining_op->HasAttribute(kStopGradientAttrName)) { - auto attrs = defining_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - input_stop_gradient = - attrs[input.index()].dyn_cast().data(); - } - if (!input_stop_gradient) { + auto attr = value.attribute(kStopGradientAttrName); + if (attr && !attr.data()) { stop_gradient = false; break; } @@ -303,10 +283,9 @@ void SliceOp::RefreshStopGradients() { IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName), "Required CombineOp must have attribute %s", kStopGradientAttrName); - auto attrs = defining_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - outs_stop_gradient[0] = attrs[static_cast(index)]; + auto attr = defining_op->attribute(kStopGradientAttrName) + .dyn_cast(); + outs_stop_gradient[0] = attr.at(static_cast(index)); } } (*this)->set_attribute( @@ -384,23 +363,18 @@ void SplitOp::PassStopGradients(OperationArgument &argument) { argument.output_types.size(), defining_op->num_operands()); for (uint32_t i = 0; i < defining_op->num_operands(); ++i) { - auto value = defining_op->operand_source(i); - if (!value) continue; - auto *oprand_defining_op = value.dyn_cast().owner(); - if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) { - auto attrs = oprand_defining_op->attribute(kStopGradientAttrName) - .dyn_cast() - .AsVector(); - defaut_stop_gradients[i] = attrs[value.dyn_cast().index()] - .dyn_cast() - .data(); + auto attr = + defining_op->operand_source(i).attribute( + kStopGradientAttrName); + if (attr) { + defaut_stop_gradients[i] = attr.data(); } } } else if (defining_op && defining_op->HasAttribute(kStopGradientAttrName)) { bool stop_gradient = defining_op->attribute(kStopGradientAttrName) .dyn_cast() - .AsVector()[0] + .at(0) .dyn_cast() .data(); defaut_stop_gradients.assign(defaut_stop_gradients.size(), stop_gradient); diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index c7f82954844d7..f0b0451f86461 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -49,7 +49,15 @@ class IR_API OpBase { Block *parent() const { return operation()->GetParent(); } + // Attribtue related interfaces const AttributeMap &attributes() const { return operation()->attributes(); } + Attribute attribute(const std::string &key) const { + return operation()->attribute(key); + } + template + T attribute(const std::string &key) const { + return operation()->attribute(key); + } Value operand_source(uint32_t index) const { return operation()->operand_source(index); @@ -57,15 +65,6 @@ class IR_API OpBase { OpResult result(uint32_t index) const { return operation()->result(index); } - pir::Attribute attribute(const std::string &name) const { - return operation()->attribute(name); - } - - template - T attribute(const std::string &name) const { - return operation()->attribute(name); - } - void VerifySig() {} void VerifyRegion() {} diff --git a/paddle/pir/core/op_result.cc b/paddle/pir/core/op_result.cc index 30c6ec97d8fba..868cddf630454 100644 --- a/paddle/pir/core/op_result.cc +++ b/paddle/pir/core/op_result.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/core/op_result.h" #include "paddle/common/enforce.h" +#include "paddle/pir/core/attribute.h" #include "paddle/pir/core/op_result_impl.h" #define CHECK_OPRESULT_NULL_IMPL(func_name) \ @@ -46,6 +47,14 @@ bool OpResult::operator==(const OpResult &other) const { return impl_ == other.impl_; } +Attribute OpResult::attribute(const std::string &key) const { + return impl_ ? IMPL_->attribute(key) : nullptr; +} +void OpResult::set_attribute(const std::string &key, Attribute value) { + CHECK_OPRESULT_NULL_IMPL(set_attribute); + return IMPL_->set_attribute(key, value); +} + OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {} } // namespace pir diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index 5ca9164a04a23..dfb128037bb04 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -34,6 +34,9 @@ class IR_API OpResult : public Value { uint32_t index() const; bool operator==(const OpResult &other) const; + Attribute attribute(const std::string &key) const; + void set_attribute(const std::string &key, Attribute value); + private: friend Operation; OpResult(detail::OpResultImpl *impl); // NOLINT diff --git a/paddle/pir/core/op_result_impl.cc b/paddle/pir/core/op_result_impl.cc index d731de937bd5d..a9ca0a16ace9f 100644 --- a/paddle/pir/core/op_result_impl.cc +++ b/paddle/pir/core/op_result_impl.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_result_impl.h" +#include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/operation.h" namespace pir { @@ -31,21 +32,56 @@ OpResultImpl::~OpResultImpl() { } } +int32_t OpResultImpl::ComputeOperationOffset() const { + auto kind = this->kind(); + // Compute inline op result offset. + if (kind < OUTLINE_RESULT_IDX) { + return static_cast((kind + 1u) * sizeof(OpInlineResultImpl)); + } + // Compute outline op result offset. + constexpr int32_t outline_size = + static_cast(sizeof(OpOutlineResultImpl)); + constexpr int32_t inline_size = + static_cast(sizeof(OpInlineResultImpl)); + constexpr int32_t diff = OUTLINE_RESULT_IDX * (outline_size - inline_size); + + auto index = static_cast(this)->index(); + + return static_cast(index + 1) * outline_size - diff; +} + +const Operation *OpResultImpl::owner() const { + int32_t offset = ComputeOperationOffset(); + return reinterpret_cast( + reinterpret_cast(this) + offset); +} + Operation *OpResultImpl::owner() { - // For inline result, pointer offset index to obtain the address of op. - if (auto *result = dyn_cast(this)) { - result += result->index() + 1; - return reinterpret_cast(result); + int32_t offset = ComputeOperationOffset(); + return reinterpret_cast(reinterpret_cast(this) + offset); +} + +Attribute OpResultImpl::attribute(const std::string &key) const { + auto array = owner()->attribute(key); + auto index = this->index(); + return array && array.size() > index ? array[index] : nullptr; +} + +void OpResultImpl::set_attribute(const std::string &key, Attribute value) { + auto owner = this->owner(); + auto attr = owner->attribute(key); + if (attr && !attr.isa()) { + IR_THROW( + "The %s attribute has existed as operation attriubute. Can't set it as " + "value attribute. "); } - // For outline result, pointer offset outline_index to obtain the address of - // maximum inline result. - auto *outline_result = static_cast(this); - outline_result += (outline_result->index() - MAX_INLINE_RESULT_IDX); - // The offset of the maximum inline result distance op is - // GetMaxInlineResultIndex. - auto *inline_result = reinterpret_cast(outline_result); - inline_result += OUTLINE_RESULT_IDX; - return reinterpret_cast(inline_result); + auto array_attr = attr.dyn_cast(); + auto index = this->index(); + std::vector vec; + if (array_attr) vec = array_attr.AsVector(); + vec.resize(owner->num_results()); + vec[index] = value; + owner->set_attribute(key, ArrayAttribute::get(owner->ir_context(), vec)); } } // namespace detail diff --git a/paddle/pir/core/op_result_impl.h b/paddle/pir/core/op_result_impl.h index 8183eb9ef0283..c90e065b1c7bd 100644 --- a/paddle/pir/core/op_result_impl.h +++ b/paddle/pir/core/op_result_impl.h @@ -33,6 +33,8 @@ class OpResultImpl : public ValueImpl { /// \brief Get the parent operation of this result.(op_ptr = value_ptr + /// index) /// + const Operation *owner() const; + Operation *owner(); /// @@ -41,6 +43,15 @@ class OpResultImpl : public ValueImpl { uint32_t index() const; ~OpResultImpl(); + + /// + /// \brief attribute related public interfaces + /// + Attribute attribute(const std::string &key) const; + void set_attribute(const std::string &key, Attribute value); + + private: + int32_t ComputeOperationOffset() const; }; /// diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 0c3f213adab50..90558f05fa674 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -65,15 +65,16 @@ class IR_API alignas(8) Operation final /// /// \brief op attribute related public interfaces /// + const AttributeMap &attributes() const { return attributes_; } + // return nullptr if attribute not found. Attribute attribute(const std::string &key) const { - return attributes_.at(key); + auto iter = attributes_.find(key); + return iter == attributes_.end() ? nullptr : iter->second; } - const AttributeMap &attributes() const { return attributes_; } + template - T attribute(const std::string &name) { - Attribute attr = attribute(name); - IR_ENFORCE(attr.isa(), "Attribute (%s) type is not right.", name); - return attr.dyn_cast(); + T attribute(const std::string &key) const { + return attribute(key).dyn_cast(); } void set_attribute(const std::string &key, Attribute value) { attributes_[key] = value; diff --git a/paddle/pir/core/value.cc b/paddle/pir/core/value.cc index 8bdda56a5d75e..03345c63702bf 100644 --- a/paddle/pir/core/value.cc +++ b/paddle/pir/core/value.cc @@ -97,4 +97,16 @@ void Value::ReplaceAllUsesWith(Value new_value) const { } } +Attribute Value::attribute(const std::string &key) const { + auto op_result = dyn_cast(); + if (op_result) return op_result.attribute(key); + return dyn_cast().attribute(key); +} + +void Value::set_attribute(const std::string &key, Attribute value) { + auto op_result = dyn_cast(); + if (op_result) return op_result.set_attribute(key, value); + return dyn_cast().set_attribute(key, value); +} + } // namespace pir diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index 11d1193bbc068..3a62cab47ee89 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/pir/core/attribute.h" #include "paddle/pir/core/iterator.h" #include "paddle/pir/core/op_operand.h" #include "paddle/pir/core/type.h" @@ -100,6 +101,19 @@ class IR_API Value { void ReplaceAllUsesWith(Value new_value) const; detail::ValueImpl *impl() const { return impl_; } + /// + /// \brief attribute related public interfaces + /// + // return nullptr if value is null or attribute not found. + Attribute attribute(const std::string &key) const; + + template + T attribute(const std::string &name) { + return attribute(name).dyn_cast(); + } + + void set_attribute(const std::string &key, Attribute value); + protected: detail::ValueImpl *impl_{nullptr}; }; diff --git a/paddle/pir/core/value_impl.h b/paddle/pir/core/value_impl.h index 0720360d563bc..bfdd08b4d1316 100644 --- a/paddle/pir/core/value_impl.h +++ b/paddle/pir/core/value_impl.h @@ -67,7 +67,7 @@ class alignas(8) ValueImpl { } template - bool isa() { + bool isa() const { return T::classof(*this); }