diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h
index 17c1471c38c2d5..474bc09e2c64c2 100644
--- a/paddle/cinn/utils/attribute_util.h
+++ b/paddle/cinn/utils/attribute_util.h
@@ -20,6 +20,7 @@
 #include "paddle/cinn/utils/type_defs.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/builtin_type.h"

 namespace cinn {
@@ -61,7 +62,9 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
   AttributeMap dst_attrs;
   for (auto& item : src_attrs) {
     VLOG(4) << "deal with " << item.first;
-    if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
+    if (item.first == ::pir::kStopGradientAttrName) {
+      continue;
+    } else if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
       auto is_cpu =
           item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
           phi::CPUPlace();
diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
index e86072fac58c32..df487f01e99307 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -594,6 +594,8 @@ def GenBuildOutputs(
     )
     build_output_str += "  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n"
+    # NOTE(Aurelius84): PassStopGradients must be placed after argument.AddOutputs.
+    build_output_str += "  ::pir::PassStopGradientsDefaultly(argument);\n"

     return build_output_str
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index e92d7894152056..ed8ec2f65d9d6c 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -370,13 +370,14 @@ bool GetOpResultBoolAttr(const OpResult &self, const std::string &attr_name) {
                      .AsVector();
     return attrs[self.GetResultIndex()].dyn_cast<pir::BoolAttribute>().data();
   } else {
-    return false;
+    return true;
   }
 }

 void SetOpResultBoolAttr(const OpResult &self,
                          const std::string &attr_name,
-                         bool value) {
+                         bool value,
+                         bool default_value) {
   auto *defining_op = self.owner();
   std::vector<pir::Attribute> attrs;
   if (defining_op->HasAttribute(attr_name)) {
@@ -386,7 +387,7 @@ void SetOpResultBoolAttr(const OpResult &self,
   } else {
     attrs = std::vector<pir::Attribute>(
         defining_op->num_results(),
-        pir::BoolAttribute::get(pir::IrContext::Instance(), false));
+        pir::BoolAttribute::get(pir::IrContext::Instance(), default_value));
   }
   attrs[self.GetResultIndex()] =
       pir::BoolAttribute::get(pir::IrContext::Instance(), value);
@@ -479,7 +480,12 @@
             return GetOpResultBoolAttr(self, kAttrStopGradients);
           },
           [](OpResult &self, bool stop_gradient) {
-            SetOpResultBoolAttr(self, kAttrStopGradients, stop_gradient);
+            // NOTE(Aurelius84): For other OpResult, set their stop_gradient
+            // default value as true.
+            SetOpResultBoolAttr(self,
+                                kAttrStopGradients,
+                                stop_gradient,
+                                /*default_value=*/true);
           })
       .def_property(
           "is_persistable",
@@ -487,7 +493,12 @@
             return GetOpResultBoolAttr(self, kAttrIsPersisable);
           },
           [](OpResult &self, bool is_persistable) {
-            SetOpResultBoolAttr(self, kAttrIsPersisable, is_persistable);
+            // NOTE(Aurelius84): For other OpResult, set their is_persistable
+            // default value as false.
+            SetOpResultBoolAttr(self,
+                                kAttrIsPersisable,
+                                is_persistable,
+                                /*default_value=*/false);
           })
       .def_property(
           "shape",
diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc
index aba3ff9b282e43..b63e8d1d3c53d2 100644
--- a/paddle/pir/core/builtin_op.cc
+++ b/paddle/pir/core/builtin_op.cc
@@ -22,6 +22,34 @@ namespace pir {

 const char *ModuleOp::attributes_name[attributes_num] = {"program"};  // NOLINT

+void PassStopGradientsDefaultly(OperationArgument &argument) {  // NOLINT
+  VLOG(4) << "Builder construction stop gradient for OpResults.";
+  bool stop_gradient = true;
+  for (auto &input : argument.inputs) {
+    if (input.Value::impl() == nullptr) continue;
+
+    auto *defining_op = input.owner();
+    bool input_stop_gradient = true;
+    if (defining_op->HasAttribute(kStopGradientAttrName)) {
+      auto attrs = defining_op->attribute(kStopGradientAttrName)
+                       .dyn_cast<pir::ArrayAttribute>()
+                       .AsVector();
+      input_stop_gradient =
+          attrs[input.GetResultIndex()].dyn_cast<pir::BoolAttribute>().data();
+    }
+    if (!input_stop_gradient) {
+      stop_gradient = false;
+      break;
+    }
+  }
+  std::vector<pir::Attribute> outs_stop_gradient(
+      argument.output_types.size(),
+      pir::BoolAttribute::get(pir::IrContext::Instance(), stop_gradient));
+  argument.AddAttribute(
+      kStopGradientAttrName,
+      pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
+}
+
 Program *ModuleOp::program() {
   const AttributeMap &attr = this->attributes();
   auto iter = attr.find("program");
@@ -79,6 +107,15 @@ void GetParameterOp::Build(Builder &builder,
   argument.attributes[attributes_name[0]] =
       pir::StrAttribute::get(builder.ir_context(), name);
   argument.output_types.emplace_back(type);
+  PassStopGradients(argument);
+}
+
+void GetParameterOp::PassStopGradients(OperationArgument &argument) {
+  std::vector<pir::Attribute> outs_stop_gradient(
+      1, pir::BoolAttribute::get(pir::IrContext::Instance(), false));
+  argument.AddAttribute(
+      kStopGradientAttrName,
+      pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
 }

 void GetParameterOp::Verify() const {
@@ -136,6 +173,7 @@ void CombineOp::Build(Builder &builder,
     argument.output_types.emplace_back(
         pir::VectorType::get(builder.ir_context(), inputs_type));
   }
+  PassStopGradientsDefaultly(argument);
 }

 void CombineOp::Verify() const {
@@ -177,6 +215,28 @@ void SliceOp::Build(Builder &builder,
   argument.output_types.emplace_back(input.type()
                                          .dyn_cast<pir::VectorType>()
                                          .data()[static_cast<size_t>(index)]);
+  PassStopGradients(argument, index);
+}
+
+void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
+  std::vector<pir::Attribute> outs_stop_gradient(
+      1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
+  auto &input = argument.inputs[0];
+  if (input.Value::impl() != nullptr) {
+    auto *defining_op = input.owner();
+    if (defining_op && defining_op->isa<CombineOp>()) {
+      IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
+                 "Required CombineOp must have attribute %s",
+                 kStopGradientAttrName);
+      auto attrs = defining_op->attribute(kStopGradientAttrName)
+                       .dyn_cast<pir::ArrayAttribute>()
+                       .AsVector();
+      outs_stop_gradient[0] = attrs[index];
+    }
+  }
+  argument.AddAttribute(
+      kStopGradientAttrName,
+      pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
 }

 void SliceOp::Verify() const {
@@ -233,6 +293,56 @@ void SplitOp::Build(Builder &builder,
     argument.output_types.emplace_back(
         input.type().dyn_cast<pir::VectorType>().data()[idx]);
   }
+
+  PassStopGradients(argument);
+}
+
+void SplitOp::PassStopGradients(OperationArgument &argument) {
+  std::vector<bool> defaut_stop_gradients(argument.output_types.size(), true);
+  auto &input = argument.inputs[0];
+  if (input.Value::impl() != nullptr) {
+    auto *defining_op = input.owner();
+    if (defining_op && defining_op->isa<CombineOp>()) {
+      IR_ENFORCE(argument.output_types.size(),
+                 defining_op->num_operands(),
+                 "Required SplitOp.output.size() == CombineOp.input.size(), "
+                 "but received %d != %d",
+                 argument.output_types.size(),
+                 defining_op->num_operands());
+      for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
+        auto value = defining_op->operand_source(i);
+        if (!value) continue;
+        auto *oprand_defining_op = value.GetDefiningOp();
+        if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
+          auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
+                           .dyn_cast<pir::ArrayAttribute>()
+                           .AsVector();
+          defaut_stop_gradients[i] =
+              attrs[value.dyn_cast<OpResult>().GetResultIndex()]
+                  .dyn_cast<pir::BoolAttribute>()
+                  .data();
+        }
+      }
+    } else if (defining_op &&
+               defining_op->HasAttribute(kStopGradientAttrName)) {
+      bool stop_gradient = defining_op->attribute(kStopGradientAttrName)
+                               .dyn_cast<pir::ArrayAttribute>()
+                               .AsVector()[0]
+                               .dyn_cast<pir::BoolAttribute>()
+                               .data();
+      defaut_stop_gradients.assign(defaut_stop_gradients.size(), stop_gradient);
+    }
+  }
+
+  std::vector<pir::Attribute> outs_stop_gradient;
+  outs_stop_gradient.reserve(argument.output_types.size());
+  for (auto stop_gradient : defaut_stop_gradients) {
+    outs_stop_gradient.push_back(
+        pir::BoolAttribute::get(pir::IrContext::Instance(), stop_gradient));
+  }
+  argument.AddAttribute(
+      kStopGradientAttrName,
+      pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
 }

 void SplitOp::Verify() const {
diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h
index fee0ca406a7410..34e78e36b9b387 100644
--- a/paddle/pir/core/builtin_op.h
+++ b/paddle/pir/core/builtin_op.h
@@ -21,6 +21,7 @@
 namespace pir {
 class Program;
 class Block;
+constexpr char kStopGradientAttrName[] = "stop_gradient";
 ///
 /// \brief ModuleOp
 ///
@@ -56,6 +57,9 @@ class IR_API GetParameterOp : public pir::Op<GetParameterOp> {
                     const std::string &name,
                     Type type);
   void Verify() const;
+
+ private:
+  static void PassStopGradients(OperationArgument &argument);  // NOLINT
 };

 ///
@@ -123,6 +127,10 @@ class IR_API SliceOp : public pir::Op<SliceOp> {
   void Verify() const;
   pir::Value input() { return operand_source(0); }
+
+ private:
+  static void PassStopGradients(OperationArgument &argument,  // NOLINT
+                                int index);
 };

 ///
@@ -151,6 +159,9 @@
     }
     return outputs;
   }
+
+ private:
+  static void PassStopGradients(OperationArgument &argument);  // NOLINT
 };

 class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
@@ -180,6 +191,8 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
   Attribute value() const;
 };

+void PassStopGradientsDefaultly(OperationArgument &argument);  // NOLINT
+
 }  // namespace pir

 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ModuleOp)
diff --git a/test/cpp/pir/core/ir_program_test.cc b/test/cpp/pir/core/ir_program_test.cc
index b4a2ebc2522dc1..cafca1c97bdb22 100644
--- a/test/cpp/pir/core/ir_program_test.cc
+++ b/test/cpp/pir/core/ir_program_test.cc
@@ -274,7 +274,7 @@ TEST(program_test, builder) {
   EXPECT_EQ(program.block()->back(), full_op.operation());
   EXPECT_EQ(full_op.num_operands(), 0u);
   EXPECT_EQ(full_op.num_results(), 1u);
-  EXPECT_EQ(full_op.attributes().size(), 4u);
+  EXPECT_EQ(full_op.attributes().size(), 5u);
   EXPECT_EQ(
       full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
       true);
diff --git a/test/ir/new_ir/CMakeLists.txt b/test/ir/new_ir/CMakeLists.txt
index be352019e2d50a..733ab1338b4d5c 100644
--- a/test/ir/new_ir/CMakeLists.txt
+++ b/test/ir/new_ir/CMakeLists.txt
@@ -5,7 +5,7 @@ file(
 string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")

 set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass
-                         test_symbol_overload)
+                         test_symbol_overload test_stop_gradient)
 list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})

 foreach(target ${TEST_INTERP_CASES})
diff --git a/test/ir/new_ir/test_stop_gradient.py b/test/ir/new_ir/test_stop_gradient.py
new file mode 100644
index 00000000000000..52623d2058fe86
--- /dev/null
+++ b/test/ir/new_ir/test_stop_gradient.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+
+
+class TestAPI(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def assert_api(self, api_func, stop_gradient):
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            x = api_func()
+            self.assertEqual(x.stop_gradient, stop_gradient)
+            # test for setter
+            x.stop_gradient = not stop_gradient
+            self.assertEqual(x.stop_gradient, not stop_gradient)
+
+    def test_full(self):
+        api = lambda: paddle.full(shape=[2, 3], fill_value=1.0)
+        self.assert_api(api, True)
+
+    def test_data(self):
+        api = lambda: paddle.static.data('x', [4, 4], dtype='float32')
+        self.assert_api(api, True)
+
+    # TODO(Aurelius84): Add more test cases after API is migrated.
+
+
+class TestParametes(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def test_create_param(self):
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            w = paddle.create_parameter(shape=[784, 200], dtype='float32')
+            self.assertEqual(w.stop_gradient, False)
+            self.assertEqual(w.is_persistable, True)
+
+            # test for setter
+            w.stop_gradient = True
+            w.is_persistable = False
+            self.assertEqual(w.stop_gradient, True)
+            self.assertEqual(w.is_persistable, False)
+
+
+if __name__ == '__main__':
+    unittest.main()