Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR]support set value attribute by value. #59656

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 5 additions & 27 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,10 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
std::vector<pir::Attribute> outs_stop_gradient;
for (size_t i = 0; i < op.num_operands(); ++i) {
argument.AddOutput(op.operand(i).type());
bool input_stop_gradient = true;
auto *defining_op = op.operand_source(i).defining_op();
if (defining_op && defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
input_stop_gradient =
attrs[op.operand_source(i).dyn_cast<pir::OpResult>().index()]
.dyn_cast<pir::BoolAttribute>()
.data();
} else {
input_stop_gradient = false;
}
outs_stop_gradient.push_back(builder.bool_attr(input_stop_gradient));
auto bool_attr = op.operand_source(i).attribute<pir::BoolAttribute>(
kStopGradientAttrName);
outs_stop_gradient.push_back(bool_attr ? bool_attr
: builder.bool_attr(false));
}

argument.AddAttribute(
Expand Down Expand Up @@ -110,19 +100,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
argument.AddRegion().push_back(true_block.release());
argument.AddRegion().push_back(false_block.release());
argument.AddInput(cond);

auto cond_op = cond.defining_op();
if (cond_op && cond_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = cond_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
attrs[cond.dyn_cast<pir::OpResult>().index()] =
pir::BoolAttribute::get(pir::IrContext::Instance(), true);

cond_op->set_attribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs));
}
cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true));
}

pir::Block &IfOp::true_block() {
Expand Down
26 changes: 25 additions & 1 deletion paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/pir/core/block_argument.h"
#include "paddle/common/enforce.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/core/value_impl.h"

#define CHECK_NULL_IMPL(func_name) \
Expand All @@ -32,13 +34,27 @@ class BlockArgumentImpl : public ValueImpl {
return value.kind() == BLOCK_ARG_IDX;
}

///
/// \brief attribute related public interfaces
///
Attribute attribute(const std::string &key) const {
auto iter = attributes_.find(key);
return iter == attributes_.end() ? nullptr : iter->second;
}

void set_attribute(const std::string &key, Attribute value) {
attributes_[key] = value;
}

private:
BlockArgumentImpl(Type type, Block *owner, uint32_t index)
: ValueImpl(type, BLOCK_ARG_IDX), owner_(owner), index_(index) {}

~BlockArgumentImpl();
// access construction and owner
friend BlockArgument;

AttributeMap attributes_;
Block *owner_;
uint32_t index_;
};
Expand All @@ -63,10 +79,18 @@ Block *BlockArgument::owner() const {
}

uint32_t BlockArgument::index() const {
CHECK_NULL_IMPL(arg_index);
CHECK_NULL_IMPL(index);
return IMPL_->index_;
}

Attribute BlockArgument::attribute(const std::string &key) const {
return impl_ ? IMPL_->attribute(key) : nullptr;
}
void BlockArgument::set_attribute(const std::string &key, Attribute value) {
CHECK_NULL_IMPL(set_attribute);
return IMPL_->set_attribute(key, value);
}

BlockArgument BlockArgument::Create(Type type, Block *owner, uint32_t index) {
return new detail::BlockArgumentImpl(type, owner, index);
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class IR_API BlockArgument : public Value {
Block *owner() const;
uint32_t index() const;

Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

private:
/// constructor
BlockArgument(detail::BlockArgumentImpl *impl); // NOLINT
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/core/builtin_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ bool ArrayAttribute::empty() const { return storage()->empty(); }
Attribute ArrayAttribute::at(size_t index) const {
return storage()->at(index);
}
Attribute ArrayAttribute::operator[](size_t index) const {
return storage()->operator[](index);
}

ArrayAttribute ArrayAttribute::get(IrContext* ctx,
const std::vector<Attribute>& value) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/core/builtin_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,12 @@ class IR_API ArrayAttribute : public Attribute {

bool empty() const;

// Returns element at specified location pos, with bounds checking.
Attribute at(size_t index) const;

// Returns element at specified location pos. No bounds checking is performed.
Attribute operator[](size_t index) const;

static ArrayAttribute get(IrContext* ctx,
const std::vector<Attribute>& value);
};
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/core/builtin_attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct ArrayAttributeStorage : public AttributeStorage {
size_);
return data_[index];
}
Attribute operator[](size_t index) const { return data_[index]; }

private:
Attribute *data_;
Expand Down
52 changes: 13 additions & 39 deletions paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,8 @@ void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT
VLOG(4) << "Builder construction stop gradient for OpResults.";
bool stop_gradient = true;
for (auto value : argument.inputs) {
auto input = value.dyn_cast<OpResult>();
if (!input) continue;
auto *defining_op = input.owner();
bool input_stop_gradient = true;
if (defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
input_stop_gradient =
attrs[input.index()].dyn_cast<pir::BoolAttribute>().data();
}
if (!input_stop_gradient) {
auto attr = value.attribute<BoolAttribute>(kStopGradientAttrName);
if (attr && !attr.data()) {
stop_gradient = false;
break;
}
Expand All @@ -52,18 +42,8 @@ void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT
void RefreshStopGradientsDefaultly(Operation *op) {
bool stop_gradient = true;
for (auto value : op->operands_source()) {
auto input = value.dyn_cast<OpResult>();
if (!input) continue;
auto *defining_op = input.owner();
bool input_stop_gradient = true;
if (defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
input_stop_gradient =
attrs[input.index()].dyn_cast<pir::BoolAttribute>().data();
}
if (!input_stop_gradient) {
auto attr = value.attribute<BoolAttribute>(kStopGradientAttrName);
if (attr && !attr.data()) {
stop_gradient = false;
break;
}
Expand Down Expand Up @@ -303,10 +283,9 @@ void SliceOp::RefreshStopGradients() {
IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
"Required CombineOp must have attribute %s",
kStopGradientAttrName);
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
outs_stop_gradient[0] = attrs[static_cast<int>(index)];
auto attr = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>();
outs_stop_gradient[0] = attr.at(static_cast<size_t>(index));
}
}
(*this)->set_attribute(
Expand Down Expand Up @@ -384,23 +363,18 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
argument.output_types.size(),
defining_op->num_operands());
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
auto value = defining_op->operand_source(i);
if (!value) continue;
auto *oprand_defining_op = value.dyn_cast<OpResult>().owner();
if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
defaut_stop_gradients[i] = attrs[value.dyn_cast<OpResult>().index()]
.dyn_cast<pir::BoolAttribute>()
.data();
auto attr =
defining_op->operand_source(i).attribute<pir::BoolAttribute>(
kStopGradientAttrName);
if (attr) {
defaut_stop_gradients[i] = attr.data();
}
}
} else if (defining_op &&
defining_op->HasAttribute(kStopGradientAttrName)) {
bool stop_gradient = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector()[0]
.at(0)
.dyn_cast<pir::BoolAttribute>()
.data();
defaut_stop_gradients.assign(defaut_stop_gradients.size(), stop_gradient);
Expand Down
17 changes: 8 additions & 9 deletions paddle/pir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,22 @@ class IR_API OpBase {

Block *parent() const { return operation()->GetParent(); }

// Attribtue related interfaces
const AttributeMap &attributes() const { return operation()->attributes(); }
Attribute attribute(const std::string &key) const {
return operation()->attribute(key);
}
template <typename T>
T attribute(const std::string &key) const {
return operation()->attribute<T>(key);
}

Value operand_source(uint32_t index) const {
return operation()->operand_source(index);
}

OpResult result(uint32_t index) const { return operation()->result(index); }

pir::Attribute attribute(const std::string &name) const {
return operation()->attribute(name);
}

template <typename T>
T attribute(const std::string &name) const {
return operation()->attribute<T>(name);
}

void VerifySig() {}

void VerifyRegion() {}
Expand Down
9 changes: 9 additions & 0 deletions paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/pir/core/op_result.h"
#include "paddle/common/enforce.h"
#include "paddle/pir/core/attribute.h"
#include "paddle/pir/core/op_result_impl.h"

#define CHECK_OPRESULT_NULL_IMPL(func_name) \
Expand Down Expand Up @@ -46,6 +47,14 @@ bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

Attribute OpResult::attribute(const std::string &key) const {
return impl_ ? IMPL_->attribute(key) : nullptr;
}
void OpResult::set_attribute(const std::string &key, Attribute value) {
CHECK_OPRESULT_NULL_IMPL(set_attribute);
return IMPL_->set_attribute(key, value);
}

OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
3 changes: 3 additions & 0 deletions paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class IR_API OpResult : public Value {
uint32_t index() const;
bool operator==(const OpResult &other) const;

Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

private:
friend Operation;
OpResult(detail::OpResultImpl *impl); // NOLINT
Expand Down
62 changes: 49 additions & 13 deletions paddle/pir/core/op_result_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/core/op_result_impl.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/operation.h"

namespace pir {
Expand All @@ -31,21 +32,56 @@ OpResultImpl::~OpResultImpl() {
}
}

int32_t OpResultImpl::ComputeOperationOffset() const {
auto kind = this->kind();
// Compute inline op result offset.
if (kind < OUTLINE_RESULT_IDX) {
return static_cast<int32_t>((kind + 1u) * sizeof(OpInlineResultImpl));
}
// Compute outline op result offset.
constexpr int32_t outline_size =
static_cast<int32_t>(sizeof(OpOutlineResultImpl));
constexpr int32_t inline_size =
static_cast<int32_t>(sizeof(OpInlineResultImpl));
constexpr int32_t diff = OUTLINE_RESULT_IDX * (outline_size - inline_size);

auto index = static_cast<const OpOutlineResultImpl *>(this)->index();

return static_cast<int32_t>(index + 1) * outline_size - diff;
}

const Operation *OpResultImpl::owner() const {
int32_t offset = ComputeOperationOffset();
return reinterpret_cast<const Operation *>(
reinterpret_cast<const char *>(this) + offset);
}

Operation *OpResultImpl::owner() {
// For inline result, pointer offset index to obtain the address of op.
if (auto *result = dyn_cast<OpInlineResultImpl>(this)) {
result += result->index() + 1;
return reinterpret_cast<Operation *>(result);
int32_t offset = ComputeOperationOffset();
return reinterpret_cast<Operation *>(reinterpret_cast<char *>(this) + offset);
}

Attribute OpResultImpl::attribute(const std::string &key) const {
auto array = owner()->attribute<ArrayAttribute>(key);
auto index = this->index();
return array && array.size() > index ? array[index] : nullptr;
}

void OpResultImpl::set_attribute(const std::string &key, Attribute value) {
auto owner = this->owner();
auto attr = owner->attribute(key);
if (attr && !attr.isa<ArrayAttribute>()) {
IR_THROW(
"The %s attribute has existed as operation attriubute. Can't set it as "
"value attribute. ");
}
// For outline result, pointer offset outline_index to obtain the address of
// maximum inline result.
auto *outline_result = static_cast<OpOutlineResultImpl *>(this);
outline_result += (outline_result->index() - MAX_INLINE_RESULT_IDX);
// The offset of the maximum inline result distance op is
// GetMaxInlineResultIndex.
auto *inline_result = reinterpret_cast<OpInlineResultImpl *>(outline_result);
inline_result += OUTLINE_RESULT_IDX;
return reinterpret_cast<Operation *>(inline_result);
auto array_attr = attr.dyn_cast<ArrayAttribute>();
auto index = this->index();
std::vector<Attribute> vec;
if (array_attr) vec = array_attr.AsVector();
vec.resize(owner->num_results());
vec[index] = value;
owner->set_attribute(key, ArrayAttribute::get(owner->ir_context(), vec));
}

} // namespace detail
Expand Down
11 changes: 11 additions & 0 deletions paddle/pir/core/op_result_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class OpResultImpl : public ValueImpl {
/// \brief Get the parent operation of this result.(op_ptr = value_ptr +
/// index)
///
const Operation *owner() const;

Operation *owner();

///
Expand All @@ -41,6 +43,15 @@ class OpResultImpl : public ValueImpl {
uint32_t index() const;

~OpResultImpl();

///
/// \brief attribute related public interfaces
///
Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

private:
int32_t ComputeOperationOffset() const;
};

///
Expand Down
Loading