[PIR] Support PassStopGradients mechanism and Set stop_gradient as True (PaddlePaddle#57255)

* [PIR]Set stop_gradient as True and is_persistable as False

* add PassStopGradients for build op

* fix unittest

* fix lint

* fix uint32_t

* fix ut
Aurelius84 authored Sep 15, 2023
1 parent 3a75f1b commit 4a3e44d
Showing 8 changed files with 210 additions and 8 deletions.
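
For orientation before the per-file diffs: the user-visible effect of this change is that plain op results in the new IR default to stop_gradient=True, while parameters stay trainable and persistable. Below is a condensed Python sketch of those defaults; it mirrors the new test/ir/new_ir/test_stop_gradient.py added at the end of this diff and assumes a static-mode program built through the new IR (PIR) bindings patched below.

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    # Plain op results now default to stop_gradient=True.
    x = paddle.full(shape=[2, 3], fill_value=1.0)
    assert x.stop_gradient

    # Parameters keep gradients flowing and are persistable by default.
    w = paddle.create_parameter(shape=[784, 200], dtype='float32')
    assert not w.stop_gradient
    assert w.is_persistable

    # Both flags remain settable per result.
    x.stop_gradient = False
    assert not x.stop_gradient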
5 changes: 4 additions & 1 deletion paddle/cinn/utils/attribute_util.h
@@ -20,6 +20,7 @@
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"

namespace cinn {
@@ -61,7 +62,9 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
AttributeMap dst_attrs;
for (auto& item : src_attrs) {
VLOG(4) << "deal with " << item.first;
if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
if (item.first == ::pir::kStopGradientAttrName) {
continue;
} else if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
auto is_cpu =
item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
phi::CPUPlace();
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -594,6 +594,8 @@ def GenBuildOutputs(
)

build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n"
# NOTE(Aurelius84): PassStopGradients must be placed after argument.AddOutputs.
build_output_str += " ::pir::PassStopGradientsDefaultly(argument);\n"

return build_output_str

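The NOTE in this hunk pins the ordering because PassStopGradientsDefaultly (added to builtin_op.cc below) sizes the stop_gradient attribute from argument.output_types, the container that AddOutputs fills. A condensed Python sketch of just the two generated lines, with the surrounding GenBuildOutputs codegen omitted and the variable name taken from the hunk above:

build_output_str = ""
# Register the output types first ...
build_output_str += "  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n"
# ... then emit the call that derives one stop_gradient flag per registered output.
build_output_str += "  ::pir::PassStopGradientsDefaultly(argument);\n"
print(build_output_str)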
21 changes: 16 additions & 5 deletions paddle/fluid/pybind/ir.cc
@@ -370,13 +370,14 @@ bool GetOpResultBoolAttr(const OpResult &self, const std::string &attr_name) {
.AsVector();
return attrs[self.GetResultIndex()].dyn_cast<pir::BoolAttribute>().data();
} else {
return false;
return true;
}
}

void SetOpResultBoolAttr(const OpResult &self,
const std::string &attr_name,
bool value) {
bool value,
bool default_value) {
auto *defining_op = self.owner();
std::vector<pir::Attribute> attrs;
if (defining_op->HasAttribute(attr_name)) {
@@ -386,7 +387,7 @@ void SetOpResultBoolAttr(const OpResult &self,
} else {
attrs = std::vector<pir::Attribute>(
defining_op->num_results(),
pir::BoolAttribute::get(pir::IrContext::Instance(), false));
pir::BoolAttribute::get(pir::IrContext::Instance(), default_value));
}
attrs[self.GetResultIndex()] =
pir::BoolAttribute::get(pir::IrContext::Instance(), value);
@@ -479,15 +480,25 @@ void BindOpResult(py::module *m) {
return GetOpResultBoolAttr(self, kAttrStopGradients);
},
[](OpResult &self, bool stop_gradient) {
SetOpResultBoolAttr(self, kAttrStopGradients, stop_gradient);
// NOTE(Aurelius84): For other OpResult, set theirs stop_gradient
// default value as true.
SetOpResultBoolAttr(self,
kAttrStopGradients,
stop_gradient,
/*default_value=*/true);
})
.def_property(
"is_persistable",
[](OpResult &self) {
return GetOpResultBoolAttr(self, kAttrIsPersisable);
},
[](OpResult &self, bool is_persistable) {
SetOpResultBoolAttr(self, kAttrIsPersisable, is_persistable);
// NOTE(Aurelius84): For other OpResult, set theirs is_persistable
// default value as false.
SetOpResultBoolAttr(self,
kAttrIsPersisable,
is_persistable,
/*default_value=*/false);
})
.def_property(
"shape",
110 changes: 110 additions & 0 deletions paddle/pir/core/builtin_op.cc
@@ -22,6 +22,34 @@ namespace pir {

const char *ModuleOp::attributes_name[attributes_num] = {"program"}; // NOLINT

void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT
VLOG(4) << "Builder construction stop gradient for OpResults.";
bool stop_gradient = true;
for (auto &input : argument.inputs) {
if (input.Value::impl() == nullptr) continue;

auto *defining_op = input.owner();
bool input_stop_gradient = true;
if (defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
input_stop_gradient =
attrs[input.GetResultIndex()].dyn_cast<pir::BoolAttribute>().data();
}
if (!input_stop_gradient) {
stop_gradient = false;
break;
}
}
std::vector<pir::Attribute> outs_stop_gradient(
argument.output_types.size(),
pir::BoolAttribute::get(pir::IrContext::Instance(), stop_gradient));
argument.AddAttribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

Program *ModuleOp::program() {
const AttributeMap &attr = this->attributes();
auto iter = attr.find("program");
@@ -79,6 +107,15 @@ void GetParameterOp::Build(Builder &builder,
argument.attributes[attributes_name[0]] =
pir::StrAttribute::get(builder.ir_context(), name);
argument.output_types.emplace_back(type);
PassStopGradients(argument);
}

void GetParameterOp::PassStopGradients(OperationArgument &argument) {
std::vector<pir::Attribute> outs_stop_gradient(
1, pir::BoolAttribute::get(pir::IrContext::Instance(), false));
argument.AddAttribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void GetParameterOp::Verify() const {
@@ -136,6 +173,7 @@ void CombineOp::Build(Builder &builder,
argument.output_types.emplace_back(
pir::VectorType::get(builder.ir_context(), inputs_type));
}
PassStopGradientsDefaultly(argument);
}

void CombineOp::Verify() const {
@@ -177,6 +215,28 @@ void SliceOp::Build(Builder &builder,
argument.output_types.emplace_back(input.type()
.dyn_cast<pir::VectorType>()
.data()[static_cast<size_t>(index)]);
PassStopGradients(argument, index);
}

void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
std::vector<pir::Attribute> outs_stop_gradient(
1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
auto &input = argument.inputs[0];
if (input.Value::impl() != nullptr) {
auto *defining_op = input.owner();
if (defining_op && defining_op->isa<CombineOp>()) {
IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
"Required CombineOp must have attribute %s",
kStopGradientAttrName);
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
outs_stop_gradient[0] = attrs[index];
}
}
argument.AddAttribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SliceOp::Verify() const {
@@ -233,6 +293,56 @@ void SplitOp::Build(Builder &builder,
argument.output_types.emplace_back(
input.type().dyn_cast<pir::VectorType>().data()[idx]);
}

PassStopGradients(argument);
}

void SplitOp::PassStopGradients(OperationArgument &argument) {
std::vector<bool> defaut_stop_gradients(argument.output_types.size(), true);
auto &input = argument.inputs[0];
if (input.Value::impl() != nullptr) {
auto *defining_op = input.owner();
if (defining_op && defining_op->isa<CombineOp>()) {
IR_ENFORCE(argument.output_types.size(),
defining_op->num_operands(),
"Required SplitOp.output.size() == CombineOp.input.size(), "
"but received %d != %d",
argument.output_types.size(),
defining_op->num_operands());
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
auto value = defining_op->operand_source(i);
if (!value) continue;
auto *oprand_defining_op = value.GetDefiningOp();
if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
defaut_stop_gradients[i] =
attrs[value.dyn_cast<OpResult>().GetResultIndex()]
.dyn_cast<pir::BoolAttribute>()
.data();
}
}
} else if (defining_op &&
defining_op->HasAttribute(kStopGradientAttrName)) {
bool stop_gradient = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector()[0]
.dyn_cast<pir::BoolAttribute>()
.data();
defaut_stop_gradients.assign(defaut_stop_gradients.size(), stop_gradient);
}
}

std::vector<pir::Attribute> outs_stop_gradient;
outs_stop_gradient.reserve(argument.output_types.size());
for (auto stop_gradient : defaut_stop_gradients) {
outs_stop_gradient.push_back(
pir::BoolAttribute::get(pir::IrContext::Instance(), stop_gradient));
}
argument.AddAttribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SplitOp::Verify() const {
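To summarize the mechanism added to builtin_op.cc: PassStopGradientsDefaultly marks an op's results as stop_gradient=true only when every input already stops gradients, so any trainable input makes all results trainable; GetParameterOp always marks its single result trainable, and SliceOp/SplitOp forward the flags recorded on the CombineOp they unpack. A short Python sketch of that propagation rule follows; it assumes static mode with the new IR (PIR) API enabled so the patched Build functions run, and matmul is only an illustrative op choice.

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.full(shape=[2, 3], fill_value=1.0)               # stop_gradient: True
    w = paddle.create_parameter(shape=[3, 4], dtype='float32')  # stop_gradient: False
    # One input (w) still requires a gradient, so the result must not stop gradients.
    y = paddle.matmul(x, w)
    print(y.stop_gradient)  # expected: False
    # Every input of this op stops gradients, so the result inherits True.
    z = paddle.matmul(x, x)
    print(z.stop_gradient)  # expected: True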
13 changes: 13 additions & 0 deletions paddle/pir/core/builtin_op.h
@@ -21,6 +21,7 @@ namespace pir {

class Program;
class Block;
constexpr char kStopGradientAttrName[] = "stop_gradient";
///
/// \brief ModuleOp
///
@@ -56,6 +57,9 @@ class IR_API GetParameterOp : public pir::Op<GetParameterOp> {
const std::string &name,
Type type);
void Verify() const;

private:
static void PassStopGradients(OperationArgument &argument); // NOLINT
};

///
@@ -123,6 +127,10 @@ class IR_API SliceOp : public pir::Op<SliceOp> {

void Verify() const;
pir::Value input() { return operand_source(0); }

private:
static void PassStopGradients(OperationArgument &argument, // NOLINT
int index);
};

///
Expand Down Expand Up @@ -151,6 +159,9 @@ class IR_API SplitOp : public pir::Op<SplitOp> {
}
return outputs;
}

private:
static void PassStopGradients(OperationArgument &argument); // NOLINT
};

class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
@@ -180,6 +191,8 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute value() const;
};

void PassStopGradientsDefaultly(OperationArgument &argument); // NOLINT

} // namespace pir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ModuleOp)
2 changes: 1 addition & 1 deletion test/cpp/pir/core/ir_program_test.cc
@@ -274,7 +274,7 @@ TEST(program_test, builder) {
EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op.num_operands(), 0u);
EXPECT_EQ(full_op.num_results(), 1u);
EXPECT_EQ(full_op.attributes().size(), 4u);
EXPECT_EQ(full_op.attributes().size(), 5u);
EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true);
2 changes: 1 addition & 1 deletion test/ir/new_ir/CMakeLists.txt
@@ -5,7 +5,7 @@ file(
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")

set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass
test_symbol_overload)
test_symbol_overload test_stop_gradient)
list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})

foreach(target ${TEST_INTERP_CASES})
63 changes: 63 additions & 0 deletions test/ir/new_ir/test_stop_gradient.py
@@ -0,0 +1,63 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle


class TestAPI(unittest.TestCase):
def setUp(self):
paddle.enable_static()

def assert_api(self, api_func, stop_gradient):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = api_func()
self.assertEqual(x.stop_gradient, stop_gradient)
# test for setter
x.stop_gradient = not stop_gradient
self.assertEqual(x.stop_gradient, not stop_gradient)

def test_full(self):
api = lambda: paddle.full(shape=[2, 3], fill_value=1.0)
self.assert_api(api, True)

def test_data(self):
api = lambda: paddle.static.data('x', [4, 4], dtype='float32')
self.assert_api(api, True)

# TODO(Aurelius84): Add more test cases after API is migrated.


class TestParametes(unittest.TestCase):
def setUp(self):
paddle.enable_static()

def test_create_param(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
w = paddle.create_parameter(shape=[784, 200], dtype='float32')
self.assertEqual(w.stop_gradient, False)
self.assertEqual(w.is_persistable, True)

# test for setter
w.stop_gradient = True
w.is_persistable = False
self.assertEqual(w.stop_gradient, True)
self.assertEqual(w.is_persistable, False)


if __name__ == '__main__':
unittest.main()
