From e70e8a596539d861d7bda5cb66a101cd9d5e3c93 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:13:18 +0800 Subject: [PATCH] [Prim][PIR] Support composite rules of Llama ops (#58018) * pir prim support decomposite rule of ops * PIR Prim support ops * change unnecessary error raise * fix code * fix code * support prim op register * support max vjp * fix code * fixx test case --- .../op_generator/vjp_interface_gen_op_list.py | 2 + paddle/fluid/primitive/codegen/gen.py | 12 + paddle/fluid/primitive/primitive.yaml | 1 + paddle/fluid/primitive/rule/vjp/details.h | 300 ++++++++++++++++++ paddle/fluid/pybind/pir.cc | 8 + paddle/phi/api/yaml/legacy_backward.yaml | 2 +- python/paddle/autograd/ir_backward.py | 3 +- python/paddle/decomposition/decomp.py | 8 +- python/paddle/decomposition/rules.py | 138 +++++++- test/legacy_test/test_activation_op.py | 132 ++++++-- test/legacy_test/test_dropout_op.py | 4 +- test/legacy_test/test_erf_op.py | 23 +- test/legacy_test/test_expand_v2_op.py | 25 +- test/legacy_test/test_full_like_op.py | 4 +- test/legacy_test/test_gather_nd_op.py | 80 ++++- test/legacy_test/test_pad_op.py | 23 +- test/legacy_test/test_slice_op.py | 16 +- test/legacy_test/test_softmax_op.py | 96 +++++- test/legacy_test/test_squeeze2_op.py | 13 +- test/legacy_test/test_stack_op.py | 20 +- test/legacy_test/test_tile_op.py | 23 +- test/legacy_test/test_unsqueeze2_op.py | 13 +- 22 files changed, 870 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 3a559ef8dedf84..58abcbf1143b9f 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -82,6 +82,7 @@ 'maximum', 'argsort', 'min', + 'max', 'batch_norm', 'max_pool2d_with_index', 'pool2d', @@ -183,6 +184,7 @@ 'maximum', 'argsort', 'min', + 'max', 'batch_norm', 'max_pool2d_with_index', 'pool2d', diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 96630e50bd5cf6..f9f0d5c32b11c9 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -57,11 +57,23 @@ 'tanh_grad', 'transpose_grad', 'concat_grad', + 'erf_grad', + 'exp_grad', + 'expand_grad', + 'log_grad', + 'gather_nd_grad', + 'pad_grad', + 'max_grad', + 'slice_grad', + 'tile_grad', ] # vjp list of primitive op CUSTOM_VJP = [ 'gelu_grad', 'layer_norm_grad', 'dropout_grad', + 'silu_grad', + 'softmax_grad', + 'sqrt_grad', ] # custom vjp list of composite op VJP_COMPS = PRIM_VJP + CUSTOM_VJP diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index 794f1121da679a..85ffc28a20d20e 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -51,3 +51,4 @@ - full - cast - sign +- slice diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 4e2d7d4732b89a..3211af3bb6a98a 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -536,6 +536,306 @@ void dropout_grad(const Tensor& mask, } } +template +void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto m_2_sqrt_pi = full(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype()); + auto neg_one = full(phi::vectorize(x.dims()), -1.0, x.dtype()); + auto neg_tmp = neg_one * x * 
x; + auto mul_tmp = m_2_sqrt_pi * exp(neg_tmp); + set_output(out_grad * mul_tmp, x_grad); + } +} + +template +void expand_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& shape, + Tensor* x_grad) { + if (x_grad) { + auto out_dims = phi::make_ddim(shape.GetData()); + if (out_dims != x.dims()) { + auto axes = get_reduce_dims(x.dims(), out_dims); + if (!axes.size()) { + by_pass(out_grad, x_grad); + } else { + auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false); + if (reduced.dims().size() != x.dims().size()) { + reduced = reshape(reduced, x.shape()); + } + set_output(reduced, x_grad); + } + } else { + by_pass(out_grad, x_grad); + } + } +} + +template +void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + // dx = dout / x + set_output(out_grad / x, x_grad); + } +} + +template +void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + if (out.dtype() == phi::DataType::FLOAT16 || + out.dtype() == phi::DataType::BFLOAT16) { + Tensor out_promote = cast(out, phi::DataType::FLOAT32); + Tensor out_grad_promote = cast(out_grad, phi::DataType::FLOAT32); + set_output(cast(out_promote * out_grad_promote, out.dtype()), + x_grad); + } else { + set_output(out_grad * out, x_grad); + } + } +} + +template +void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + // This calculation is important for resnet. + auto x_grad_tmp = (0.5 / out) * out_grad; + set_output(x_grad_tmp, x_grad); + } +} + +template +void silu_grad(const Tensor& x, + const Tensor& out, + const Tensor& out_grad, + Tensor* x_grad) { + if (x_grad) { + auto org_dtype = x.dtype(); + bool need_cast = org_dtype == phi::DataType::FLOAT16 || + org_dtype == phi::DataType::BFLOAT16; + if (need_cast) { + auto x_cast = cast(x, phi::DataType::FLOAT32); + auto out_cast = cast(out, phi::DataType::FLOAT32); + auto out_grad_cast = cast(out_grad, phi::DataType::FLOAT32); + auto sigmoid = 1.0 / (1.0 + exp(-x_cast)); + auto res = out_grad_cast * sigmoid * (1.0 + x_cast - out_cast); + set_output(cast(res, org_dtype), x_grad); + } else { + auto sigmoid = 1.0 / (1.0 + exp(-x)); + auto res = out_grad * sigmoid * (1.0 + x - out); + set_output(res, x_grad); + } + } +} + +template +void softmax_grad(const Tensor& out, + const Tensor& out_grad, + int axis, + Tensor* x_grad) { + if (x_grad) { + if (out_grad.dims().size() > 0) { + if (axis >= 0) { + auto new_out_grad = out_grad * out; + auto tmp_x_grad = new_out_grad - + out * sum(new_out_grad, {axis}, out.dtype(), true); + set_output(tmp_x_grad, x_grad); + } else { + auto new_out_grad = out_grad * out; + auto tmp_x_grad = + new_out_grad - out * sum(new_out_grad, + {out.dims().size() + axis}, + out.dtype(), + true); + set_output(tmp_x_grad, x_grad); + } + } else { + set_output( + full(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()), + x_grad); + } + } +} + +template +void gather_nd_grad(const Tensor& x, + const Tensor& index, + const Tensor& out_grad, + Tensor* x_grad) { + if (x_grad) { + auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); + auto x_grad_tmp = scatter_nd_add(zero_tensor, index, out_grad); + set_output(x_grad_tmp, x_grad); + } +} + +template +void pad_grad(const Tensor& input, + const Tensor& out_grad, + const std::vector& paddings, + const Scalar& pad_value, + Tensor* input_grad) { + if (input_grad) { + size_t rank = input.dims().size(); + auto out_dims = out_grad.dims(); + + std::vector starts(rank, 0); + std::vector ends(rank, 0); + 
std::vector axes(rank, 0); + std::vector infer_flags(rank, 1); + std::vector decrease_axis({}); + for (size_t i = 0; i < rank; ++i) { + starts[i] = static_cast(paddings[2 * i]); + ends[i] = static_cast(out_dims[i] - paddings[2 * i + 1]); + axes[i] = i; + } + auto out_tmp = + slice(out_grad, axes, starts, ends, infer_flags, decrease_axis); + set_output(out_tmp, input_grad); + } +} + +template +void max_grad(const Tensor& x, + const Tensor& out, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + Tensor* x_grad) { + if (!x_grad) { + return; + } + auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); + std::vector x_dim = phi::vectorize(x.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + reduce_all = false; + if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { + reduce_all = true; + } else { + reduce_all = false; + } + auto x_grad_tmp = Tensor(); + if (x_dim_size == 0 || x_dim_size == 1 || keepdim) { + auto out_grad_tmp = out_grad.expand(IntArray(x_dim)); + auto out_tmp = out.expand(IntArray(x_dim)); + auto mask = equal(x, out_tmp); + x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); + } else { + auto axis_ = std::vector(); + if (reduce_all) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_ = reshape(out_grad, out_grad_shape); + auto out_ = reshape(out, out_grad_shape); + auto out_grad_tmp = out_grad_.expand(IntArray(x_dim)); + auto out_tmp = out_.expand(IntArray(x_dim)); + auto mask = equal(x, out_tmp); + x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); + } + set_output(x_grad_tmp, x_grad); +} + +template +void slice_grad(const Tensor& input, + const Tensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis, + Tensor* input_grad) { + if (input_grad) { + size_t rank = input.dims().size(); + auto out_dims = out_grad.dims(); + std::vector origin_out_shape; + auto in_dims = input.dims(); + + auto decrease_size = decrease_axis.size(); + if (decrease_size > 0) { + if (decrease_size == static_cast(in_dims.size())) { + // all dims decrease + out_dims = phi::make_ddim(std::vector(decrease_size, 1)); + } else { + origin_out_shape.resize(out_dims.size() + decrease_size, -1); + for (size_t i = 0; i < decrease_size; ++i) { + origin_out_shape[decrease_axis[i]] = 1; + } + + int index = 0; + for (size_t i = 0; i < origin_out_shape.size(); ++i) { + if (origin_out_shape[i] == -1) { + origin_out_shape[i] = out_dims[index]; + ++index; + } + } + out_dims = phi::make_ddim(origin_out_shape); + } + } + + std::vector offsets(rank, 0); + std::vector extents(rank, 0); + for (size_t i = 0; i < rank; ++i) { + offsets[i] = 0; + extents[i] = out_dims[i]; + } + for (size_t i = 0; i < axes.size(); ++i) { + int axis = axes[i]; + int64_t start = starts[i] < 0 ? 
(starts[i] + in_dims[axis]) : starts[i]; + start = std::max(start, static_cast(0)); + offsets[axis] = start; + } + + std::vector paddings; + for (size_t i = 0; i < rank; ++i) { + paddings.push_back(offsets[i]); + paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]); + } + if (decrease_size > 0 && + (decrease_size != static_cast(in_dims.size()))) { + auto out_tmp = + pad(reshape(out_grad, origin_out_shape), paddings, 0.0); + set_output(out_tmp, input_grad); + } else { + auto out_tmp = pad(out_grad, paddings, 0.0); + set_output(out_tmp, input_grad); + } + } +} + +template +void tile_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& repeat_times, + Tensor* x_grad) { + if (x_grad) { + auto repeat_times_data = repeat_times.GetData(); + auto out_grad_shape = phi::vectorize(out_grad.dims()); + auto result = out_grad; + for (int i = 0; i < static_cast(repeat_times_data.size()); i++) { + int size = out_grad_shape[i] / repeat_times_data[i]; + std::vector sections(repeat_times_data[i], size); + auto split_arr = split(result, IntArray(sections), i); + result = full(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype()); + for (int j = 0; j < static_cast(split_arr.size()); j++) { + result = split_arr[j] + result; + } + } + result = reshape(result, x.shape()); + set_output(result, x_grad); + } +} + } // namespace details } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index f32fe6f592218d..475b8d3d76074a 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -363,6 +363,10 @@ void BindValue(py::module *m) { .def("first_use", &Value::first_use, return_value_policy::reference) .def("has_one_use", &Value::HasOneUse) .def("use_empty", &Value::use_empty) + .def("replace_all_uses_with", + [](Value &self, Value &op_value) { + self.ReplaceAllUsesWith(op_value); + }) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -610,6 +614,10 @@ void BindOpResult(py::module *m) { return false; } }) + .def("replace_all_uses_with", + [](OpResult &self, OpResult &op_result) { + self.ReplaceAllUsesWith(op_result); + }) .def_property( "stop_gradient", [](OpResult &self) { diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index d95bc19c57bff2..73e508434697c3 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -747,7 +747,7 @@ kernel : func : tile_grad no_need_buffer : x - composite : tile_grad(x, outgrad, repeat_times, x_grad) + composite : tile_grad(x, out_grad, repeat_times, x_grad) backward : tile_double_grad - backward_op : trans_layout_grad diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 97a315c1010566..ddfba7a22b12bf 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -13,6 +13,7 @@ # limitations under the License. 
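# logging is imported below so that create_backward_prune_set can warn,
# rather than raise, when an input passed to backward has no uses.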
import collections +import logging from collections.abc import Sequence import paddle.pir @@ -556,7 +557,7 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state): if state.value_to_valuegrad[item] != []: outputs_set.add(state.value_to_valuegrad[item][0][0]) else: - raise ValueError("input privided by inputs has no use") + logging.warning("input privided by inputs has no use") inputs_set = set() for output in outputs: diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index bfb9b6e9ba2c66..3eafd915b254a1 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -225,14 +225,18 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): input_args = _prepare_python_api_arguments(op) pir.set_insertion_point(op) orig_outs = op.results() + if op.name() == "pd_op.stack": + input_args += (op.attrs()["axis"],) new_outs = _build_tensor_tuple(decom_rule(*input_args)) # Todo: To cover such case: some outputs are no longer needed after decomposition. _check_op_results( op_name, orig_outs, new_outs, orig_vars, dst_vars ) - - op.replace_all_uses_with(new_outs) + if op.name() in ("pd_op.unsqueeze", "pd_op.squeeze"): + orig_outs[0].replace_all_uses_with(new_outs[0]) + else: + op.replace_all_uses_with(new_outs) block.remove_op(op) if temp_op is not None: diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index 924ccf1756b0e0..6e59f9858e74a7 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -56,15 +56,32 @@ def gelu(x, approximate): tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) out = x * half * (one + tanh_out) return out - else: # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype))) out = x * cdf return out +@register_decomp('pd_op.sqrt') +def sqrt(x): + """ + define composite rule of op sqrt + res = pow(x, 0.5) + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + + y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype) + res = pow_composite(x, y) + return res if not is_amp else cast(res, dtype) + + @register_decomp('pd_op.rsqrt') def rsqrt(x): """define composite rule of op rsqrt.""" @@ -211,3 +228,120 @@ def add_n(x): for xi in x[1:]: ans = xi + ans return ans + + +@register_decomp('pd_op.silu') +def silu(x): + """ + define composite rule of op silu + res = x / (1 + exp(-x)) + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + + sum_temp = exp(-x) + 1 + res = x / sum_temp + return res if not is_amp else cast(res, dtype) + + +@register_decomp('pd_op.softmax') +def softmax(x, axis): + """define composite rule of op softmax""" + is_amp = False + from paddle.base.data_feeder import convert_dtype + + # Softmax need fp32 compute since it has sum op in + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + if not x.shape: + # do not return 1, to ensure gradients + res = exp(x - x) + if is_amp: + res = cast(res, "float16") + return res + max_temp = max(x, axis, keepdim=True) + max_temp.stop_gradient = True + molecular = exp(x - max_temp) + denominator = sum(molecular, axis=axis, keepdim=True) + res = divide(molecular, 
denominator) + if is_amp: + res = cast(res, dtype) + return res + + +@register_decomp('pd_op.full_like') +def full_like(x, fill_value, dtype, place=None): + """define composite rule of op full_like.""" + """op name: full_like op type name: fill_any_like.""" + """arg place is not used, add it here to keep same as python api.""" + fill_value = fill_value.get_defining_op().attrs()["value"] + val = full(x.shape, fill_value, dtype) + return val + + +@register_decomp('pd_op.stack') +def stack(x, axis): + """ + define composite rule of op stack + unsqueeze each dimension of the input (use reshape), and then concat + """ + x_shape = x[0].shape + if axis < 0: + axis += len(x_shape) + 1 + out_shape = x_shape[:axis] + [1] + x_shape[axis:] + out = concat([reshape(item, out_shape) for item in x], axis) + return out + + +@register_decomp('pd_op.squeeze') +def squeeze(x, axis): + """define composite rule of squeeze""" + """ + canonicalize dim within range 0 to rank and + determine new shape after squeeze op + if axis not specified, remove all dims equal to 1 + otherwise, remove dims equal to 1 in axis + axis can only be list, not int + """ + axis = axis.get_defining_op().attrs()["value"] + rank = len(x.shape) + if rank == 0: + return [assign(x), None] + if len(axis) == 0: + dims = set(range(rank)) + else: + dims = {ax % rank for ax in axis} + new_shape = [] + for d, s in enumerate(x.shape): + if not (s == 1 and (d in dims)): + new_shape.append(s) + out = reshape(x, new_shape) + return [out, None] + + +@register_decomp('pd_op.unsqueeze') +def unsqueeze(x, axis): + """define composite rule of op unsqueeze""" + """using reshape to implement unsqueeze op""" + axis = axis.get_defining_op().attrs()["value"] + x_shape = list(x.shape) + axis_list = list(axis) + for i in axis_list: + if i < 0: + i += len(x_shape) + 1 + x_shape = ( + x_shape[:i] + + [ + 1, + ] + + x_shape[i:] + ) + out = reshape(x, x_shape) + return [out, None] diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 12fd5ae8f09a5a..4488f9f177610c 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -130,7 +130,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_dtype(self): self.dtype = np.float32 @@ -501,14 +507,25 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + if self.dtype == np.complex64 or self.dtype == np.complex128: + self.check_output(check_new_ir=True) + else: + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): # TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True) + self.check_grad(['X'], 'Out', check_new_ir=True) else: - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestSilu_ZeroDim(TestSilu): @@ -1439,7 +1456,7 @@ def test_errors(self): class TestSqrt(TestActivation, TestParameter): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt 
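        # sqrt gains a composite rule in this change (a pd_op.sqrt decomposition
        # and a custom VJP), so the test marks it as "comp" instead of "prim"
        # and also runs the PIR composite checks via check_prim_pir.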
self.public_python_api = paddle.sqrt @@ -1461,16 +1478,24 @@ def if_enable_cinn(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) class TestSqrtPrimFp32(TestActivation): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt self.public_python_api = paddle.sqrt self.init_dtype() @@ -1486,10 +1511,16 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_new_ir=True, check_prim_pir=True) def init_dtype(self): self.dtype = np.float32 @@ -1510,7 +1541,7 @@ def init_shape(self): class TestSqrtBF16(OpTest): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt self.public_python_api = paddle.sqrt self.init_dtype() @@ -1537,12 +1568,19 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place( + place, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -1571,12 +1609,20 @@ def test_check_grad(self): if self.dtype == np.float16: return self.check_grad( - ['X'], 'Out', check_dygraph=True, check_prim=True, check_new_ir=True + ['X'], + 'Out', + check_dygraph=True, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) def test_check_output(self): self.check_output( - check_dygraph=True, check_prim=True, check_new_ir=True + check_dygraph=True, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -1603,12 +1649,20 @@ def test_check_grad(self): if self.dtype == np.float16: return self.check_grad( - ['X'], 'Out', check_dygraph=True, check_prim=True, check_new_ir=True + ['X'], + 'Out', + check_dygraph=True, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) def test_check_output(self): self.check_output( - check_dygraph=True, check_prim=True, check_new_ir=True + check_dygraph=True, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) def init_dtype(self): @@ -3266,7 +3320,13 @@ def test_check_output(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class Test_Log_Op_Fp16(unittest.TestCase): @@ -4576,10 +4636,14 @@ def test_check_grad(self): create_test_act_fp16_class(TestActivation) -create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestExpFp32_Prim, check_prim=True, enable_cinn=True, check_prim_pir=True +) create_test_act_fp16_class(TestExpm1) create_test_act_fp16_class(TestSigmoid, check_prim=True, 
enable_cinn=True) -create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True +) create_test_act_fp16_class(TestLogSigmoid) create_test_act_fp16_class( TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True @@ -4588,10 +4652,18 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardShrink) create_test_act_fp16_class(TestSoftshrink) create_test_act_fp16_class( - TestSqrt, check_prim=True, enable_cinn=True, check_new_ir=True + TestSqrt, + check_prim=True, + enable_cinn=True, + check_new_ir=True, + check_prim_pir=True, ) create_test_act_fp16_class( - TestSqrtComp, check_prim=True, enable_cinn=True, check_new_ir=True + TestSqrtComp, + check_prim=True, + enable_cinn=True, + check_new_ir=True, + check_prim_pir=True, ) create_test_act_fp16_class( TestAbs, check_prim=True, enable_cinn=True, check_new_ir=True @@ -4730,17 +4802,23 @@ def test_check_grad(self): create_test_act_bf16_class(TestActivation) -create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True) +create_test_act_bf16_class( + TestExpFp32_Prim, check_prim=True, check_prim_pir=True +) create_test_act_bf16_class(TestExpm1) create_test_act_bf16_class(TestSigmoid, check_prim=True) -create_test_act_bf16_class(TestSilu, check_prim=True) +create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestLogSigmoid) -create_test_act_bf16_class(TestTanh, check_prim=True) +create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestTanhshrink) create_test_act_bf16_class(TestHardShrink) create_test_act_bf16_class(TestSoftshrink) -create_test_act_bf16_class(TestSqrt, check_prim=True, check_new_ir=True) -create_test_act_bf16_class(TestSqrtComp, check_prim=True, check_new_ir=True) +create_test_act_bf16_class( + TestSqrt, check_prim=True, check_new_ir=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestSqrtComp, check_prim=True, check_new_ir=True, check_prim_pir=True +) create_test_act_bf16_class(TestAbs, check_prim=True, check_new_ir=True) create_test_act_bf16_class(TestCeil, grad_check=False, check_new_ir=True) create_test_act_bf16_class(TestFloor, grad_check=False, check_prim=True) diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py index f65e4d2b4b855b..2a2e13a8b984f6 100644 --- a/test/legacy_test/test_dropout_op.py +++ b/test/legacy_test/test_dropout_op.py @@ -2098,7 +2098,9 @@ def test_static_comp(self): ) core._set_prim_forward_enabled(True) - [output] = decompose(mp, [output]) # decompose forward + [output] = decompose( + mp, [output], whitelist={"pd_op.dropout"} + ) # decompose forward self.assertTrue( 'pd_op.dropout' not in [op.name() for op in mp.global_block().ops] diff --git a/test/legacy_test/test_erf_op.py b/test/legacy_test/test_erf_op.py index 24f32175151d65..36d785118686f0 100644 --- a/test/legacy_test/test_erf_op.py +++ b/test/legacy_test/test_erf_op.py @@ -49,6 +49,14 @@ def test_check_output(self): def test_check_grad(self): self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + def test_check_grad_prim_pir(self): + # Todo(CZ): float64 loss greater than 1e-8 + if self.dtype == "float64": + self.dtype = "float32" + self.rev_comp_atol = 1e-7 + self.rev_comp_rtol = 1e-7 + self.check_grad(['X'], 'Out', check_prim_pir=True) + class TestErfOp_ZeroDim(TestErfOp): def init_shape(self): @@ -96,7 +104,13 @@ def test_check_output(self): 
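        # erf_grad is registered as a primitive VJP in this change, so the
        # gradient check below also runs against the PIR composite rule
        # (check_prim_pir=True).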
self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) @unittest.skipIf( @@ -126,7 +140,12 @@ def test_check_output(self): def test_check_grad(self): place = paddle.base.core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_expand_v2_op.py b/test/legacy_test/test_expand_v2_op.py index f7ba37fb60cbb4..567c3acd2ef7aa 100644 --- a/test/legacy_test/test_expand_v2_op.py +++ b/test/legacy_test/test_expand_v2_op.py @@ -50,7 +50,13 @@ def test_check_output(self): self.check_output(check_cinn=True, check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestExpandV2OpRank1_ZeroDim1(TestExpandV2OpRank1): @@ -244,7 +250,13 @@ def test_check_output(self): self.check_output(check_cinn=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) # Situation 8: input x is BF16 @@ -273,7 +285,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -438,7 +455,7 @@ def test_check_output(self): self.check_output(check_prim=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_prim_pir=True) class TestExpandV2OpCompRank2_DimExpanding(TestExpandV2CompOpRank1): diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 137e536126bb46..5cbafdacb847d3 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -148,7 +148,9 @@ def init_data(self): self.dtype = np.float32 def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def if_enable_cinn(self): pass diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index a10faff2ac1f35..d1be1eddbc5073 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -56,7 +56,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex): @@ -80,7 +86,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -117,7 +128,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + 
self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithIndex1_ZeroDim(TestGatherNdOpWithIndex1): @@ -168,7 +185,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -205,7 +227,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex): @@ -235,6 +263,7 @@ def test_check_grad(self): check_prim=True, check_new_ir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -282,6 +311,7 @@ def test_check_grad(self): check_prim=True, check_new_ir=True, numeric_grad_delta=0.05, + check_prim_pir=True, ) @@ -312,6 +342,7 @@ def test_check_grad(self): check_prim=True, check_new_ir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -345,7 +376,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX): @@ -375,6 +412,7 @@ def test_check_grad(self): check_prim=True, check_new_ir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -410,7 +448,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame): @@ -434,7 +478,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -471,7 +520,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff): @@ -495,7 +550,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_pad_op.py b/test/legacy_test/test_pad_op.py index 8054d7c75ffb11..951d895416ba9b 100644 --- a/test/legacy_test/test_pad_op.py +++ b/test/legacy_test/test_pad_op.py @@ -60,7 +60,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def initTestCase(self): self.shape = (16, 16) @@ -101,7 +107,13 @@ def get_dtype(self): 
return np.float16 def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) cls_name = "{}_{}".format(parent.__name__, "Fp16") TestPadFp16.__name__ = cls_name @@ -258,7 +270,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_slice_op.py b/test/legacy_test/test_slice_op.py index 065251b246928e..986b2bc2b01b2e 100644 --- a/test/legacy_test/test_slice_op.py +++ b/test/legacy_test/test_slice_op.py @@ -76,6 +76,7 @@ def test_check_grad_normal(self): max_relative_error=0.006, check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -166,6 +167,7 @@ def test_check_grad_normal(self): max_relative_error=0.006, check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -507,7 +509,7 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, check_prim=True, check_new_ir=True + place, check_prim=True, check_new_ir=True, check_prim_pir=True ) def test_check_grad_normal(self): @@ -520,6 +522,7 @@ def test_check_grad_normal(self): 'Out', check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -555,7 +558,7 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, check_prim=True, check_new_ir=True + place, check_prim=True, check_new_ir=True, check_prim_pir=True ) def test_check_grad_normal(self): @@ -568,6 +571,7 @@ def test_check_grad_normal(self): numeric_grad_delta=0.5, check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -600,7 +604,13 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad_normal(self): - self.check_grad(['Input'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['Input'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) # Test python API diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py index e684daa695a23e..af65ee776c71db 100644 --- a/test/legacy_test/test_softmax_op.py +++ b/test/legacy_test/test_softmax_op.py @@ -84,9 +84,17 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + self.check_output_with_place( + place, + atol=1e-5, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode @@ -100,6 +108,7 @@ def test_check_grad(self): max_relative_error=0.01, check_dygraph=(not self.use_mkldnn), check_new_ir=True, + check_prim_pir=True, ) else: self.check_grad( @@ -109,6 +118,7 @@ def test_check_grad(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -146,9 +156,36 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + 
self.check_output_with_place( + place, atol=1e-5, check_new_ir=True, check_prim_pir=True + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) + + def test_check_grad(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.use_cudnn or self.dtype == np.float16: + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, + ["X"], + "Out", + max_relative_error=0.01, + check_dygraph=(not self.use_mkldnn), + check_new_ir=True, + ) + else: + self.check_grad( + ["X"], + "Out", + max_relative_error=0.01, + check_dygraph=(not self.use_mkldnn), + check_prim=True, + check_new_ir=True, + ) @unittest.skipIf( @@ -158,6 +195,8 @@ class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp): def setUp(self): self.op_type = "softmax" self.python_api = F.softmax + self.public_python_api = F.softmax + self.prim_op_type = "comp" self.use_cudnn = True self.use_mkldnn = False # explicilty use float32 for ROCm, as MIOpen does not yet support float64 @@ -180,9 +219,40 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + self.check_output_with_place( + place, + check_prim=True, + atol=1e-5, + check_new_ir=True, + check_prim_pir=True, + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) + + def test_check_grad(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.use_cudnn or self.dtype == np.float16: + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, + ["X"], + "Out", + max_relative_error=0.01, + check_dygraph=(not self.use_mkldnn), + check_new_ir=True, + ) + else: + self.check_grad( + ["X"], + "Out", + max_relative_error=0.01, + check_dygraph=(not self.use_mkldnn), + check_prim=True, + check_new_ir=True, + ) class TestSoftmaxOp2(TestSoftmaxOp): @@ -357,7 +427,11 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=1e-3, check_new_ir=True + place, + atol=1e-3, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) # FIXME: If the x_shape is [10, 10], gradient failed. 
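    # With check_prim=True / check_prim_pir=True the output here is also
    # verified against the composite softmax lowering added in this change,
    # exp(x - max(x)) / sum(exp(x - max(x))), computed in float32 for
    # low-precision inputs.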
@@ -386,7 +460,11 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=1e-3, check_new_ir=True + place, + atol=1e-3, + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -438,6 +516,7 @@ def test_check_output(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_new_ir=(not self.use_mkldnn), + check_prim_pir=(not self.use_mkldnn), ) def test_check_grad(self): @@ -450,6 +529,7 @@ def test_check_grad(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_new_ir=(not self.use_mkldnn), + check_prim_pir=(not self.use_mkldnn), ) diff --git a/test/legacy_test/test_squeeze2_op.py b/test/legacy_test/test_squeeze2_op.py index 1ee72ad2a39e0b..f1138dd2460025 100755 --- a/test/legacy_test/test_squeeze2_op.py +++ b/test/legacy_test/test_squeeze2_op.py @@ -56,11 +56,20 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( - no_check_set=['XShape'], check_prim=True, check_new_ir=True + no_check_set=['XShape'], + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad( + ["X"], + "Out", + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_dtype(self): self.dtype = np.float64 diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index 44abff4dafeb58..5057db55ed9846 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -63,11 +63,17 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): self.check_grad( - self.get_x_names(), 'Y', check_prim=True, check_new_ir=True + self.get_x_names(), + 'Y', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -189,11 +195,17 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): self.check_grad( - self.get_x_names(), 'Y', check_prim=True, check_new_ir=True + self.get_x_names(), + 'Y', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_tile_op.py b/test/legacy_test/test_tile_op.py index 40dc04b0537707..308167b31710b3 100644 --- a/test/legacy_test/test_tile_op.py +++ b/test/legacy_test/test_tile_op.py @@ -50,7 +50,13 @@ def test_check_output(self): self.check_output(check_cinn=self.check_cinn, check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestTileOpRank_ZeroDim1(TestTileOpRank1): @@ -265,7 +271,13 @@ def test_check_output(self): self.check_output(check_cinn=self.check_cinn, check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) @unittest.skipIf( @@ -305,7 +317,12 @@ def init_data(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', 
+ check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_unsqueeze2_op.py b/test/legacy_test/test_unsqueeze2_op.py index 36fa88cb1035ac..289591285b7055 100755 --- a/test/legacy_test/test_unsqueeze2_op.py +++ b/test/legacy_test/test_unsqueeze2_op.py @@ -44,11 +44,20 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( - no_check_set=["XShape"], check_prim=True, check_new_ir=True + no_check_set=["XShape"], + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad( + ["X"], + "Out", + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_test_case(self): self.ori_shape = (3, 40)
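Note (reading aid): the VJP bodies added to paddle/fluid/primitive/rule/vjp/details.h follow the standard gradient identities sketched below, where y is the forward output, dy the upstream gradient, dx the input gradient, sigma(x) = 1/(1+e^{-x}) the logistic sigmoid, and \odot the elementwise product; the silu and softmax forms also match the composite rules registered in python/paddle/decomposition/rules.py.

\begin{aligned}
\mathrm{erf}:\;    & dx = dy \cdot \tfrac{2}{\sqrt{\pi}}\, e^{-x^{2}} \\
\exp:\;            & dx = dy \cdot y \\
\log:\;            & dx = dy / x \\
\mathrm{sqrt}:\;   & dx = dy \cdot \tfrac{1}{2\sqrt{x}} = dy \cdot \tfrac{0.5}{y} \\
\mathrm{silu}:\;   & dx = dy \cdot \sigma(x)\,(1 + x - y), \qquad y = x\,\sigma(x) \\
\mathrm{softmax}:\;& dx = y \odot \bigl(dy - \textstyle\sum_{\mathrm{axis}} (dy \odot y)\bigr)
\end{aligned}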