diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index f9a920730967d3..e0eeeb10a3a4d2 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -72,8 +72,20 @@ ] -PRIM_VJP = ['divide_grad', 'sum_grad'] # vjp list of primitive op -CUSTOM_VJP = ['gelu_grad'] # custom vjp list of composite op +PRIM_VJP = [ + 'divide_grad', + 'sum_grad', + 'cast_grad', + 'add_grad', + 'multiply_grad', + 'elementwise_pow_grad', + 'reshape_grad', + 'split_grad', + 'tanh_grad', + 'transpose_grad', + 'concat_grad', +] # vjp list of primitive op +CUSTOM_VJP = ['gelu_grad', 'layer_norm_grad'] # custom vjp list of composite op VJP_COMPS = PRIM_VJP + CUSTOM_VJP BACKENDS = [ @@ -149,6 +161,8 @@ 'embedding_grad', 'sqrt', 'uniform', + 'split', + 'transpose', ] diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index 1ab275ceaecbf3..6737a73d69eb50 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -106,7 +106,7 @@ paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{ {% else %} std::vector {{api.outputs[i].name}}(stop_gradients[{{i}}].size(), nullptr); for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) { - {{api.outputs[i].name}} = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr; + {{api.outputs[i].name}}[i] = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr; } {% endif %} {% endfor %} diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index a42e2503e31baf..ccf9673bafba07 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -49,3 +49,4 @@ - erf - tanh - full +- cast diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index eb640a4643ed3b..96b4d051b7cde5 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -134,32 +134,371 @@ void gelu_grad(const Tensor& x, // Promote to fp32 when the input type is fp16 for keeping consistent with // phi kernel - // Scale only support fp32 attr in static graph mode, use elementwise_xx - // when precision is over fp32. 
- if (approximate) { - auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; - auto kKappa = 0.044715; - auto x_sq = x * x; - auto x_cube = x_sq * x; - auto inner = kBeta * (x + kKappa * x_cube); - auto tanh_inner = tanh(inner); - - auto left = scale(x, 0.5); - auto right = scale(tanh_inner, 1., 1.); - - auto left_derivative = scale(right, 0.5); - - auto tanh_derivative = scale(tanh_inner * tanh_inner, -1., 1.); - auto inner_derivative = kBeta * (scale(3 * kKappa * x_sq, 1., 1.)); - auto right_derivative = left * tanh_derivative * inner_derivative; - - set_output(out_grad * (left_derivative + right_derivative), x_grad); + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + auto promoted_x = cast(x, phi::DataType::FLOAT32); + auto promoted_out_grad = cast(out_grad, phi::DataType::FLOAT32); + if (approximate) { + float kbeta = M_SQRT2 * M_2_SQRTPI * 0.5; + float kkappa = 0.044715; + auto x_sq = promoted_x * promoted_x; + auto x_cube = x_sq * promoted_x; + auto inner = kbeta * (promoted_x + kkappa * x_cube); + auto tanh_inner = tanh(inner); + + auto left = scale(promoted_x, 0.5); + auto right = scale(tanh_inner, 1., 1.); + + auto left_derivative = scale(right, 0.5); + + auto tanh_derivative = scale(tanh_inner * tanh_inner, -1., 1.); + auto inner_derivative = kbeta * (scale(3 * kkappa * x_sq, 1., 1.)); + auto right_derivative = left * tanh_derivative * inner_derivative; + + set_output( + cast(promoted_out_grad * (left_derivative + right_derivative), + x.type()), + x_grad); + } else { + float kalpha = M_SQRT1_2; + float kbeta = M_2_SQRTPI * M_SQRT1_2 * 0.5; + auto cdf = scale(scale(erf(kalpha * promoted_x), 1., 1.), 0.5); + auto pdf = kbeta * exp(scale(promoted_x * promoted_x, -0.5)); + set_output( + cast(promoted_out_grad * (cdf + promoted_x * pdf), x.type()), + x_grad); + } } else { - auto kAlpha = M_SQRT1_2; - auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5; - auto cdf = scale(scale(erf(kAlpha * x), 1., 1.), 0.5); - auto pdf = kBeta * exp(scale(x * x, -0.5)); - set_output(out_grad * (cdf + x * pdf), x_grad); + // Scale only support fp32 attr in static graph mode, use elementwise_xx + // when precision is over fp32. 
+ if (approximate) { + auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; + auto kKappa = 0.044715; + auto x_sq = x * x; + auto x_cube = x_sq * x; + auto inner = kBeta * (x + kKappa * x_cube); + auto tanh_inner = tanh(inner); + + auto left = scale(x, 0.5); + auto right = scale(tanh_inner, 1., 1.); + + auto left_derivative = scale(right, 0.5); + + auto tanh_derivative = scale(tanh_inner * tanh_inner, -1., 1.); + auto inner_derivative = kBeta * (scale(3 * kKappa * x_sq, 1., 1.)); + auto right_derivative = left * tanh_derivative * inner_derivative; + + set_output(out_grad * (left_derivative + right_derivative), x_grad); + } else { + auto kAlpha = M_SQRT1_2; + auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5; + auto cdf = scale(scale(erf(kAlpha * x), 1., 1.), 0.5); + auto pdf = kBeta * exp(scale(x * x, -0.5)); + set_output(out_grad * (cdf + x * pdf), x_grad); + } + } +} + +template +void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) { + if (grad_x) { + auto grad_x_tmp = reshape(grad_out, phi::vectorize(x.dims())); + set_output(grad_x_tmp, grad_x); + } +} + +template +void transpose_grad(const Tensor& grad_out, + const std::vector& perm, + Tensor* grad_x) { + if (grad_x) { + std::vector reverse_perm(perm); + // make origin ranks + for (int i = 0; i < static_cast(perm.size()); ++i) { + if (perm[i] >= 0) { + reverse_perm[perm[i]] = i; + } else { + reverse_perm[perm[i] + perm.size()] = i; + } + } + auto grad_x_tmp = transpose(grad_out, reverse_perm); + set_output(grad_x_tmp, grad_x); + } +} + +template +void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { + if (!grad_x) return; + auto grad_x_tmp = grad_out * (1 - out * out); + set_output(grad_x_tmp, grad_x); +} + +template +void concat_grad(const std::vector& x, + const Tensor& out_grad, + const Scalar& axis, + std::vector x_grad) { + int axis_value = axis.to(); + int rank = x[0].dims().size(); + if (axis_value < 0) { + axis_value = axis_value + rank; + } + axis_value = axis_value > 0 ? 
axis_value : 0; + std::vector sections; + int x_num = x.size(); + for (int i = 0; i < x_num; ++i) { + sections.push_back(x[i].dims()[axis_value]); + } + std::vector x_grad_tmp = + split(out_grad, IntArray(sections), axis_value); + for (int i = 0; i < x_num; ++i) { + if (x_grad[i]) { + set_output(x_grad_tmp.at(i), x_grad.at(i)); + } + } +} + +template +void split_grad(const std::vector& out_grad, + const Scalar& axis, + Tensor* x_grad) { + if (x_grad) { + auto grad = concat(out_grad, axis); + set_output(grad, x_grad); + } +} + +template +void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto res = cast(out_grad, x.dtype()); + set_output(res, x_grad); + } +} + +template +void add_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + if (x.dims() != y.dims()) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + set_output(out_grad, dy); + } else { + auto dy_reduce_res = + out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + set_output(dy_tmp, dy); + } + + } else { + set_output(out_grad, dy); + } + } + if (dx) { + if (y.dims() != x.dims()) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + set_output(out_grad, dx); + } else { + auto dx_reduce_res = + out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + set_output(dx_tmp, dx); + } + } else { + set_output(out_grad, dx); + } + } +} + +template +void multiply_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis, + Tensor* x_grad, + Tensor* y_grad) { + if (x_grad) { + auto x_grad_unreduce = out_grad * y; + if (x_grad_unreduce.dims() != x.dims()) { + auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); + if (!axes.size()) { + set_output(x_grad_unreduce, x_grad); + } else { + auto x_grad_reduced = x_grad_unreduce.sum( + phi::vectorize(axes), x_grad_unreduce.dtype(), false); + if (x_grad_reduced.dims().size() != x.dims().size()) { + x_grad_reduced = reshape(x_grad_reduced, x.shape()); + } + set_output(x_grad_reduced, x_grad); + } + } else { + set_output(x_grad_unreduce, x_grad); + } + } + if (y_grad) { + auto y_grad_unreduce = out_grad * x; + if (y_grad_unreduce.dims() != y.dims()) { + auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); + if (!axes.size()) { + set_output(y_grad_unreduce, y_grad); + } else { + auto y_grad_reduced = y_grad_unreduce.sum( + phi::vectorize(axes), y_grad_unreduce.dtype(), false); + if (y_grad_reduced.dims().size() != y.dims().size()) { + y_grad_reduced = reshape(y_grad_reduced, y.shape()); + } + set_output(y_grad_reduced, y_grad); + } + } else { + set_output(y_grad_unreduce, y_grad); + } + } +} + +template +void elementwise_pow_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + Tensor* dx, + Tensor* dy) { + if (dy) { + // dy = lnx * x^y + auto lnx = log(x); + auto x_pow_y = elementwise_pow(x, y); + auto dy_res = lnx * x_pow_y * out_grad; + if (x.dims() != y.dims()) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + set_output(dy_res, dy); + } else { + auto dy_reduce_res = + dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, 
phi::vectorize(y.dims())); + set_output(dy_tmp, dy); + } + } else { + set_output(dy_res, dy); + } + } // indicate we will compute dy + if (dx) { + // dx = y * x^(y-1) + auto tmp_z = y - 1.0; + auto x_pow_z = elementwise_pow(x, tmp_z); + auto dx_res = y * x_pow_z * out_grad; + if (y.dims() != x.dims()) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + set_output(dx_res, dx); + } else { + auto dx_reduce_res = + dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + set_output(dx_tmp, dx); + } + + } else { + set_output(dx_res, dx); + } + } // indicate we will compute dx +} + +template +void layer_norm_grad(const Tensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + const Tensor& mean, + const Tensor& variance, + const Tensor& out_grad, + float epsilon, + int begin_norm_axis, + Tensor* x_grad, + Tensor* scale_grad, + Tensor* bias_grad) { + auto x_dims = x.dims(); + auto shape_1 = 1; // front part + auto shape_2 = 1; // back part + for (int i = 0; i < begin_norm_axis; ++i) { + shape_1 *= x_dims[i]; + } + for (int i = begin_norm_axis; i < x.dims().size(); ++i) { + shape_2 *= x_dims[i]; + } + auto scale_ptr = scale.get_ptr(); + auto bias_ptr = bias.get_ptr(); + + auto x_cast = reshape(x, std::vector({shape_1, shape_2})); + auto out_grad_cast = + reshape(out_grad, std::vector({shape_1, shape_2})); + auto mean_ = reshape(mean, std::vector({shape_1, 1})); + auto variance_ = reshape(variance, std::vector({shape_1, 1})); + + Tensor scale_cast; + if (scale_ptr) { + scale_cast = reshape(*scale_ptr, std::vector({1, shape_2})); + } + + // cast dtype to float32 if dtype =float16 or bfloat16 + + auto x_sub_mean = x_cast - mean_; // M,N + auto tmp = (1.0 / (variance_ + epsilon)); // M,1 + // auto sqrt_var_1 = sqrt(tmp); // M,1 + auto sqrt_var_1 = elementwise_pow( + tmp, full(phi::vectorize(tmp.dims()), 0.5, tmp.dtype())); + auto x_sub_mean_mul_sqrt_var_1 = x_sub_mean * sqrt_var_1; + + if (x_grad) { + auto out_grad_scale = out_grad_cast; // M,N + if (scale_ptr) { + out_grad_scale = out_grad_cast * scale_cast; // M,N * 1,N = M,N + } + + auto dx_end = sqrt_var_1 * out_grad_scale; + auto d_mean = + dx_end.sum(std::vector({1}), x_cast.dtype(), true); // M,1 + + auto d_std_1 = + (tmp * x_sub_mean * out_grad_scale) + .sum(std::vector({1}), x_cast.dtype(), true); // M,1 + auto d_std = d_std_1 * x_sub_mean_mul_sqrt_var_1; // M,1 * M,N = M,N + + auto d_mean_d_std = (1.0 / shape_2) * (d_mean + d_std); + auto x_grad_tmp = dx_end - d_mean_d_std; + x_grad_tmp = reshape(x_grad_tmp, phi::vectorize(x.dims())); + + set_output(x_grad_tmp, x_grad); + } + + if (scale_grad) { + if (scale_ptr) { + auto scale_grad_tmp = + (x_sub_mean_mul_sqrt_var_1 * out_grad_cast) + .sum(std::vector({0}), x_cast.dtype(), true); + scale_grad_tmp = reshape(scale_grad_tmp, scale_ptr->shape()); + set_output(scale_grad_tmp, scale_grad); + } else { + scale_grad = nullptr; + } + } + + if (bias_grad) { + if (bias_ptr) { + auto bias_grad_tmp = + out_grad_cast.sum(std::vector({0}), x_cast.dtype(), true); + bias_grad_tmp = reshape(bias_grad_tmp, bias_ptr->shape()); + set_output(bias_grad_tmp, bias_grad); + } else { + bias_grad = nullptr; + } } } diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 22fd0f40a36b5c..80ecad93997db8 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -464,6 +464,10 @@ void BindOpResult(py::module *m) { [](OpResult &self, OpResult 
&other) { return paddle::dialect::add(self, other); }) + .def("__add__", + [](OpResult &self, float &bias) { + return paddle::dialect::scale(self, 1.0, bias, false); + }) .def("__sub__", [](OpResult &self, OpResult &other) { return paddle::dialect::subtract(self, other); diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index a8260bb8168652..9b5db92c547002 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -224,7 +224,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param: [x, y] - composite : elementwise_pow_grad(x, y, out_grad, axis, x_grad, y_grad) + composite : elementwise_pow_grad(x, y, out_grad, x_grad, y_grad) kernel : func : elementwise_pow_grad diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index e33c3a38bff749..f8a2aae71b0cd7 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -94,7 +94,6 @@ def prepare_grad_outputs(grad_outputs, outputs, state): dtype=output.dtype, ) fillop = output_grad.get_defining_op() - update_bwdop_structure( backward_ops, state.op_to_opgrad[output.get_defining_op()], @@ -138,14 +137,14 @@ def prepare_grad_outputs(grad_outputs, outputs, state): 0.0, opresult.dtype, ) - fillop = grad.get_defining_op() + fillop = grad_value.get_defining_op() update_bwdop_structure( backward_ops, state.op_to_opgrad[opresult.get_defining_op()], fillop, ) - state.value_to_valuegrad[opresult] = [grad_value] + state.value_to_valuegrad[opresult] = [[grad_value]] visited_output.add(opresult) diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index e9d04ede061ce9..26a4ae73debd01 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -63,3 +63,83 @@ def gelu_composite(x, approximate): cdf = half * (one + _ir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype))) out = x * cdf return out + + +@register_decomp('pd_op.rsqrt') +def rsqrt_composite(x): + """define composite rule of op rsqrt.""" + # rsqrt(x) = x^(-0.5) + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype) + res = pow(x, y) + return res if not is_amp else cast(res, dtype) + + +@register_decomp('pd_op.pow') +def pow_composite(x, y): + """ + define composite rule of op pow + res = x^y + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + + if isinstance(y, (int, float)): + y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype) + res = pow(x, y) + if is_amp: + res = cast(res, dtype) + return res + + +@register_decomp('pd_op.layer_norm') +def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): + """ + define composite rule of op layer_norm + out = (x - mean(x)) / sqrt(var + epsilon)) + var = mean((x-mean(x))^2) + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + scale = cast(scale, "float32") if scale else scale + bias = cast(bias, "float32") if bias else bias + + axis = tuple(range(begin_norm_axis, len(x.shape))) + mean_ = mean(x, axis=axis, keepdim=True) + difference = x - mean_ + var_tmp1 = 
difference * difference + variance = mean(var_tmp1, axis=axis, keepdim=True) + var_tmp3 = variance + epsilon + rsqrt_var = rsqrt(var_tmp3) + out = difference * rsqrt_var + + if scale is not None: + if x.shape[begin_norm_axis:] != scale.shape: + scale = reshape(scale, x.shape[begin_norm_axis:]) + out = out * scale + if bias is not None: + if x.shape[begin_norm_axis:] != bias.shape: + bias = reshape(bias, x.shape[begin_norm_axis:]) + out = out + bias + + mean_ = reshape(mean_, [-1]) + variance = reshape(variance, [-1]) + if is_amp: + out = cast(out, dtype) + return out, mean_, variance diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c3e814cc906d4d..f764fbb45996d2 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -833,8 +833,7 @@ def full_like(x, fill_value, dtype=None, name=None): if in_dynamic_mode(): return _C_ops.full_like(x, fill_value, dtype, x.place) elif in_pir_mode(): - place = _current_expected_place() - return _C_ops.full_like(x, fill_value, dtype, place) + return _C_ops.full_like(x, fill_value, dtype, core.Place()) else: helper = LayerHelper("full_like", **locals()) check_variable_and_dtype( @@ -881,7 +880,11 @@ def full_like(x, fill_value, dtype=None, name=None): def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if in_dynamic_or_pir_mode(): - place = _current_expected_place() + place = ( + _current_expected_place() + if not in_pir_mode() + else paddle.base.core.Place() + ) if force_cpu: place = core.CPUPlace() if isinstance(shape, (list, tuple)): diff --git a/test/legacy_test/prim_op_test.py b/test/legacy_test/prim_op_test.py index e472c70813c737..f28957cdc89bec 100644 --- a/test/legacy_test/prim_op_test.py +++ b/test/legacy_test/prim_op_test.py @@ -22,7 +22,8 @@ import paddle from paddle.autograd.ir_backward import grad as ir_grad -from paddle.base import core +from paddle.base import Scope, core +from paddle.base.executor import scope_guard from paddle.base.framework import ( OpProtoHolder, _dygraph_tracer, @@ -409,7 +410,8 @@ def check(self): self.check_jit_comp_with_cinn() else: if self.enable_check_static_comp: - self.check_static_comp() + with scope_guard(Scope()): + self.check_static_comp() def get_kernel_sig(self): with dygraph_guard(): @@ -870,7 +872,8 @@ def check(self): self.check_jit_comp_with_cinn() else: if self.enable_check_static_comp: - self.check_static_comp() + with scope_guard(Scope()): + self.check_static_comp() def get_output_dict(self, np_outputs, api_outputs, outputs_sig): assert len(api_outputs) <= len(outputs_sig), ( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 8b16ee5750eacb..8d1ee1ac5091ad 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -693,9 +693,21 @@ def test_check_grad(self): return # TODO(ScottWong98): set `check_prim=False` when `fill_any_like` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=False) + self.check_grad( + ['X'], + 'Out', + check_prim=False, + check_prim_pir=False, + check_new_ir=False, + ) else: - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_dtype(self): # TODO If dtype is float64, the output (Out) has diff at CPUPlace @@ -1615,7 +1627,9 @@ def if_enable_cinn(self): pass def 
test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=True + ) def test_check_grad(self): if self.dtype == np.float16: @@ -1626,6 +1640,7 @@ def test_check_grad(self): max_relative_error=0.0005, check_prim=True, check_new_ir=True, + check_prim_pir=True, ) @@ -2480,12 +2495,22 @@ def setUp(self): self.cinn_atol = 1e-8 def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, + check_new_ir=True, + check_prim_pir=False, + ) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGelu(TestActivation): @@ -2518,12 +2543,20 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_new_ir=True, check_prim_pir=False + ) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestGelu_ZeroDim(TestGelu): @@ -3575,12 +3608,20 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_prim_pir=True, check_new_ir=True + ) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) class TestPow_ZeroDim(TestPow): @@ -4397,6 +4438,7 @@ def create_test_act_fp16_class( grad_check=True, check_dygraph=True, check_prim=False, + check_prim_pir=False, enable_cinn=False, grad_atol=1e-2, **kwargs @@ -4425,6 +4467,7 @@ def test_check_output(self): atol=atol, check_dygraph=check_dygraph, check_prim=check_prim, + check_prim_pir=check_prim_pir, ) def test_check_grad(self): @@ -4437,6 +4480,7 @@ def test_check_grad(self): 'Out', check_dygraph=check_dygraph, check_prim=check_prim, + check_prim_pir=check_prim_pir, max_relative_error=grad_atol, ) @@ -4451,7 +4495,9 @@ def test_check_grad(self): create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True) create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True) create_test_act_fp16_class(TestLogSigmoid) -create_test_act_fp16_class(TestTanh, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True +) create_test_act_fp16_class(TestTanhshrink) create_test_act_fp16_class(TestHardShrink) create_test_act_fp16_class(TestSoftshrink) @@ -4478,6 +4524,7 @@ def test_check_grad(self): create_test_act_fp16_class( TestGelu, check_prim=True, + check_prim_pir=True, check_new_ir=True, enable_cinn=True, rev_comp_rtol=1e-3, @@ -4499,7 +4546,7 @@ def test_check_grad(self): create_test_act_fp16_class(TestLog10) create_test_act_fp16_class(TestLog1p) create_test_act_fp16_class(TestSquare) -create_test_act_fp16_class(TestPow, check_prim=True) +create_test_act_fp16_class(TestPow, check_prim=True, check_prim_pir=True) create_test_act_fp16_class(TestPow_API) create_test_act_fp16_class(TestSTanh) create_test_act_fp16_class(TestSoftplus) @@ -4521,7 +4568,11 @@ def 
test_check_grad(self): ) create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True) create_test_act_fp16_class( - TestRsqrt, check_prim=True, enable_cinn=True, check_new_ir=True + TestRsqrt, + check_prim=True, + enable_cinn=True, + check_new_ir=True, + check_prim_pir=True, ) @@ -4645,7 +4696,9 @@ def test_check_grad(self): create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True) -create_test_act_bf16_class(TestRsqrt, check_prim=True, check_new_ir=True) +create_test_act_bf16_class( + TestRsqrt, check_prim=True, check_new_ir=True, check_prim_pir=True +) if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py index 47bc23d76f601b..448629431d0b1a 100644 --- a/test/legacy_test/test_cast_op.py +++ b/test/legacy_test/test_cast_op.py @@ -52,10 +52,16 @@ def init_shapes(self): self.input_shape = [10, 10] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_prim_pir=True, check_new_ir=True) def test_grad(self): - self.check_grad(['X'], ['Out'], check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + ['Out'], + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) class TestCastOpFp32ToFp64_ZeroDim(TestCastOpFp32ToFp64): diff --git a/test/legacy_test/test_concat_op.py b/test/legacy_test/test_concat_op.py index dc9702beeb0146..153e1cc06d3085 100644 --- a/test/legacy_test/test_concat_op.py +++ b/test/legacy_test/test_concat_op.py @@ -61,18 +61,51 @@ def test_check_grad(self): if self.dtype == np.uint16: place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['x0'], 'Out', check_prim=True, check_new_ir=True + place, + ['x0'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x1'], 'Out', check_prim=True, check_new_ir=True + place, + ['x1'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x2'], 'Out', check_prim=True, check_new_ir=True + place, + ['x2'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, ) else: - self.check_grad(['x0'], 'Out', check_prim=True, check_new_ir=True) - self.check_grad(['x1'], 'Out', check_prim=True, check_new_ir=True) - self.check_grad(['x2'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) + self.check_grad( + ['x1'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) + self.check_grad( + ['x2'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_test_data(self): if self.dtype == np.uint16: @@ -213,9 +246,27 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['x0'], 'Out', check_prim=True, check_new_ir=True) - self.check_grad(['x1'], 'Out', check_prim=True, check_new_ir=True) - self.check_grad(['x2'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) + self.check_grad( + ['x1'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) + self.check_grad( + ['x2'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def init_test_data(self): if self.dtype == np.uint16: @@ -301,8 +352,10 @@ def 
create_test_fp16(parent): class TestConcatFp16(parent): def setUp(self): self.op_type = "concat" + self.prim_op_type = "prim" self.python_api = paddle.concat self.public_python_api = paddle.concat + self.enable_cinn = False self.dtype = self.get_dtype() self.init_test_data() self.inputs = { @@ -332,18 +385,51 @@ def test_check_grad(self): if self.dtype == np.uint16: place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['x0'], 'Out', check_new_ir=True + place, + ['x0'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x1'], 'Out', check_new_ir=True + place, + ['x1'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x2'], 'Out', check_new_ir=True + place, + ['x2'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) else: - self.check_grad(['x0'], 'Out', check_new_ir=True) - self.check_grad(['x1'], 'Out', check_new_ir=True) - self.check_grad(['x2'], 'Out', check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) + self.check_grad( + ['x1'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) + self.check_grad( + ['x2'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) def get_dtype(self): return np.float16 @@ -371,6 +457,7 @@ def create_test_bf16(parent): class TestConcatBf16(parent): def setUp(self): self.op_type = "concat" + self.prim_op_type = "prim" self.python_api = paddle.concat self.public_python_api = paddle.concat self.enable_cinn = False @@ -403,18 +490,51 @@ def test_check_grad(self): if self.dtype == np.uint16: place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['x0'], 'Out', check_new_ir=True + place, + ['x0'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x1'], 'Out', check_new_ir=True + place, + ['x1'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) self.check_grad_with_place( - place, ['x2'], 'Out', check_new_ir=True + place, + ['x2'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, ) else: - self.check_grad(['x0'], 'Out', check_new_ir=True) - self.check_grad(['x1'], 'Out', check_new_ir=True) - self.check_grad(['x2'], 'Out', check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) + self.check_grad( + ['x1'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) + self.check_grad( + ['x2'], + 'Out', + check_new_ir=True, + check_prim=True, + check_prim_pir=True, + ) def get_dtype(self): return np.uint16 diff --git a/test/legacy_test/test_elementwise_add_op.py b/test/legacy_test/test_elementwise_add_op.py index 279d1997f160e3..6d4faa1e5df327 100644 --- a/test/legacy_test/test_elementwise_add_op.py +++ b/test/legacy_test/test_elementwise_add_op.py @@ -57,6 +57,7 @@ def test_check_output(self): self.check_output( check_dygraph=self.check_dygraph(), check_prim=self.check_prim, + check_prim_pir=self.check_dygraph(), check_new_ir=self.check_dygraph(), ) @@ -69,6 +70,7 @@ def test_check_grad_normal(self): 'Out', check_dygraph=self.check_dygraph(), check_prim=self.check_prim, + check_prim_pir=self.check_dygraph(), check_new_ir=self.check_dygraph(), ) @@ -82,6 +84,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_dygraph=self.check_dygraph(), check_prim=self.check_prim, + 
check_prim_pir=self.check_dygraph(), check_new_ir=self.check_dygraph(), ) @@ -95,6 +98,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_dygraph=self.check_dygraph(), check_prim=self.check_prim, + check_prim_pir=self.check_dygraph(), check_new_ir=self.check_dygraph(), ) @@ -152,6 +156,7 @@ def test_check_output(self): atol=1e-3, check_dygraph=self.check_dygraph(), check_prim=self.check_prim, + check_prim_pir=self.check_dygraph(), check_new_ir=self.check_dygraph(), ) @@ -167,6 +172,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -178,6 +184,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -221,6 +228,7 @@ def test_check_grad_normal(self): ['X', 'Y'], 'Out', check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -232,6 +240,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -243,6 +252,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) diff --git a/test/legacy_test/test_elementwise_mul_op.py b/test/legacy_test/test_elementwise_mul_op.py index 86f4e764916e07..43b9fbe7b9dde4 100644 --- a/test/legacy_test/test_elementwise_mul_op.py +++ b/test/legacy_test/test_elementwise_mul_op.py @@ -49,6 +49,7 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode self.check_output( check_dygraph=(not self.use_mkldnn), + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -59,6 +60,7 @@ def test_check_grad_normal(self): 'Out', check_dygraph=(not self.use_mkldnn), check_prim=True, + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -70,6 +72,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_dygraph=(not self.use_mkldnn), check_prim=True, + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -81,6 +84,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_dygraph=(not self.use_mkldnn), check_prim=True, + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -102,6 +106,7 @@ def if_enable_cinn(self): class TestComplexElementwiseMulOpWithCheckGrad(ElementwiseMulOp): def setUp(self): self.op_type = "elementwise_mul" + self.prim_op_type = "prim" self.python_api = paddle.multiply self.public_python_api = paddle.multiply self.dtype = np.complex128 @@ -199,7 +204,13 @@ def test_check_output(self): self.check_output() def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X', 'Y'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) def test_check_grad_ingore_x(self): self.check_grad( @@ -207,6 +218,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -216,6 +228,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -427,6 +440,7 @@ def test_check_grad_normal(self): 'Out', check_dygraph=(not self.use_mkldnn), check_prim=True, + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -438,6 +452,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_dygraph=(not self.use_mkldnn), check_prim=True, + 
check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -449,6 +464,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_dygraph=(not self.use_mkldnn), check_prim=True, + check_prim_pir=(not self.use_mkldnn), check_new_ir=(not self.use_mkldnn), ) @@ -503,6 +519,7 @@ def setUp(self): class TestComplexElementwiseMulOp(OpTest): def setUp(self): self.op_type = "elementwise_mul" + self.prim_op_type = "prim" self.python_api = paddle.multiply self.init_base_dtype() self.init_input_output() diff --git a/test/legacy_test/test_elementwise_pow_op.py b/test/legacy_test/test_elementwise_pow_op.py index e406845960abc5..c718ce16292b9b 100644 --- a/test/legacy_test/test_elementwise_pow_op.py +++ b/test/legacy_test/test_elementwise_pow_op.py @@ -44,7 +44,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_prim_pir=True, check_new_ir=True) def test_check_grad_normal(self): if hasattr(self, 'attrs'): @@ -53,7 +53,11 @@ def test_check_grad_normal(self): ) else: self.check_grad( - ['X', 'Y'], 'Out', check_prim=True, check_new_ir=True + ['X', 'Y'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, ) @@ -190,6 +194,8 @@ class TestElementwisePowOpInt(OpTest): def setUp(self): self.op_type = "elementwise_pow" self.python_api = paddle.pow + self.public_python_api = paddle.pow + self.prim_op_type = "prim" self.inputs = {'X': np.asarray([1, 3, 6]), 'Y': np.asarray([1, 1, 1])} self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} @@ -198,7 +204,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_prim_pir=True, check_new_ir=True) class TestElementwisePowGradOpInt(unittest.TestCase): @@ -254,7 +260,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_prim_pir=True, check_new_ir=True) def test_check_grad(self): self.check_grad( @@ -264,6 +270,7 @@ def test_check_grad(self): self.inputs['X'], self.inputs['Y'], 1 / self.inputs['X'].size ), check_prim=True, + check_prim_pir=True, check_new_ir=True, ) @@ -290,7 +297,7 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(out)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_prim_pir=True, check_new_ir=True) def test_check_grad(self): self.check_grad(['X', 'Y'], 'Out') @@ -301,7 +308,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_layer_norm_op.py b/test/legacy_test/test_layer_norm_op.py index b023ff6488e481..3fb01bb3d0b62a 100644 --- a/test/legacy_test/test_layer_norm_op.py +++ b/test/legacy_test/test_layer_norm_op.py @@ -141,8 +141,9 @@ def test_check_output(self): no_check_set=["Mean", "Variance"], atol=self.ori_atol, rtol=self.ori_rtol, - check_prim=True, - check_new_ir=True, + check_prim=self.check_prim, + check_prim_pir=self.check_prim_pir, + check_new_ir=self.check_new_ir, ) def test_check_grad(self): @@ -150,8 +151,9 @@ def test_check_grad(self): self.check_grad_input_list, ['Y'], max_relative_error=self.max_relative_error, - check_prim=True, - check_new_ir=True, + check_prim=self.check_prim, + check_prim_pir=self.check_prim_pir, + check_new_ir=self.check_new_ir, ) def 
initConfig(self): @@ -173,6 +175,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = True + self.check_prim = True + self.check_prim_pir = True + self.check_new_ir = True def initTestCase(self): np.random.seed(123) @@ -240,8 +245,9 @@ def test_check_output(self): no_check_set=["Mean", "Variance"], atol=self.ori_atol, rtol=self.ori_rtol, - check_prim=True, - check_new_ir=True, + check_prim=self.check_prim, + check_prim_pir=self.check_prim_pir, + check_new_ir=self.check_new_ir, ) def test_check_grad(self): @@ -250,8 +256,9 @@ def test_check_grad(self): self.check_grad_input_list, ['Y'], max_relative_error=self.max_relative_error, - check_prim=True, - check_new_ir=True, + check_prim=self.check_prim, + check_prim_pir=self.check_prim_pir, + check_new_ir=self.check_new_ir, ) def initConfig(self): @@ -266,6 +273,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = True + self.check_prim = True + self.check_prim_pir = True + self.check_new_ir = True def initTestCase(self): np.random.seed(123) @@ -335,6 +345,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True @unittest.skipIf( @@ -356,6 +369,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True @unittest.skipIf( @@ -382,6 +398,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True @unittest.skipIf( @@ -403,6 +422,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True @unittest.skipIf( @@ -429,6 +451,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = True + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True @unittest.skipIf( @@ -450,6 +475,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = True + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest): @@ -467,6 +495,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = True + self.check_prim = True + self.check_prim_pir = True + self.check_new_ir = True class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest): @@ -484,6 +515,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest): @@ -501,6 +535,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = True self.has_bias = False + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest): @@ -518,6 +555,9 @@ def initConfig(self): self.begin_norm_axis = 1 self.has_scale = False self.has_bias = True + self.check_prim = False + self.check_prim_pir = False + self.check_new_ir = True class TestLayerNormOp(unittest.TestCase): diff --git a/test/legacy_test/test_reshape_op.py b/test/legacy_test/test_reshape_op.py index c9ab6baf41ef66..0a9132ca55b49a 100755 --- 
a/test/legacy_test/test_reshape_op.py +++ b/test/legacy_test/test_reshape_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from op_test import OpTest, convert_float_to_uint16 +from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci import paddle from paddle import base @@ -43,11 +43,17 @@ def init_data(self): self.new_shape = (12, 10) self.infered_shape = (12, 10) - def test_check_output(self): + def _test_check_output(self): self.check_output(no_check_set=['XShape'], check_new_ir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad( + ["X"], + "Out", + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) class TestReshapeOp_ZeroDim1(TestReshapeOp): @@ -120,7 +126,7 @@ def test_check_output(self): self.check_output(no_check_set=['XShape']) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True) + self.check_grad(["X"], "Out", check_prim=True, check_prim_pir=True) class TestReshapeFP16Op(OpTest): @@ -148,7 +154,7 @@ def test_check_output(self): self.check_output(no_check_set=['XShape']) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True) + self.check_grad(["X"], "Out", check_prim=True, check_prim_pir=True) class TestReshapeOpDimInfer1(TestReshapeOp): @@ -340,6 +346,9 @@ def init_dtype(self): self.dtype = np.uint8 +@skip_check_grad_ci( + "we don't need to check grad for the bool type of reshape op" +) class TestReshapeOpBool(TestReshapeOp): def setUp(self): self.init_data() diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 964e127aafb819..92dfe72f8443e3 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -61,7 +61,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True + ['X'], + ['out0', 'out1', 'out2'], + check_prim=True, + check_prim_pir=True, + check_new_ir=True, ) @@ -117,7 +121,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True + ['X'], + ['out0', 'out1', 'out2'], + check_prim=True, + check_prim_pir=True, + check_new_ir=True, ) @@ -243,7 +251,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True + ['X'], + ['out0', 'out1', 'out2'], + check_prim=True, + check_prim_pir=True, + check_new_ir=True, ) @@ -291,7 +303,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'out2', check_prim=True, check_new_ir=True + place, + ['X'], + 'out2', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, ) cls_name = "{}_{}".format(parent.__name__, "BF16Op") diff --git a/test/legacy_test/test_sum_op.py b/test/legacy_test/test_sum_op.py index 63a68442936ab2..c154625fb51f44 100644 --- a/test/legacy_test/test_sum_op.py +++ b/test/legacy_test/test_sum_op.py @@ -58,11 +58,20 @@ def init_kernel_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_prim=True, check_cinn=True, check_new_ir=True) + self.check_output( + check_prim=True, + check_cinn=True, + check_new_ir=True, + ) def test_check_grad(self): self.check_grad( - ['x0'], 'Out', check_prim=True, check_cinn=True, check_new_ir=True + ['x0'], + 'Out', + check_prim=True, + check_cinn=True, + check_prim_pir=True, + check_new_ir=True, 
) @@ -304,7 +313,13 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad(['x0'], 'Out', check_cinn=True, check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_cinn=True, + check_prim_pir=True, + check_new_ir=True, + ) def create_test_sum_fp16_class(parent): @@ -330,7 +345,9 @@ def test_w_is_selected_rows(self): class TestSumBF16Op(OpTest): def setUp(self): self.op_type = "sum" + self.prim_op_type = "prim" self.python_api = paddle.add_n + self.public_python_api = paddle.add_n self.init_kernel_type() x0 = np.random.random((3, 40)).astype(np.float32) x1 = np.random.random((3, 40)).astype(np.float32) @@ -354,7 +371,13 @@ def test_check_output(self): def test_check_grad(self): # new dynamic graph mode does not support unit16 type - self.check_grad(['x0'], 'Out', check_dygraph=False, check_new_ir=True) + self.check_grad( + ['x0'], + 'Out', + check_dygraph=False, + check_prim_pir=True, + check_new_ir=True, + ) class API_Test_Add_n(unittest.TestCase): diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py index c8d91f59f8c490..52f85ef1e0a708 100644 --- a/test/legacy_test/test_transpose_op.py +++ b/test/legacy_test/test_transpose_op.py @@ -52,7 +52,13 @@ def test_check_output(self): self.check_output(no_check_set=['XShape'], check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_new_ir=True, + check_prim_pir=True, + ) def if_enable_cinn(self): pass @@ -209,7 +215,13 @@ def test_check_output(self): base.core.disable_autotune() def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) class TestAutoTuneTransposeFP16Op(OpTest): @@ -246,7 +258,13 @@ def test_check_output(self): base.core.disable_autotune() def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) class TestAutoTuneTransposeBF16Op(OpTest): @@ -290,7 +308,13 @@ def test_check_output(self): base.core.disable_autotune() def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) class TestTransposeFP16Op(OpTest): @@ -325,7 +349,13 @@ def test_check_output(self): self.check_output(no_check_set=['XShape'], check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_prim_pir=True, + check_new_ir=True, + ) def initTestCase(self): self.shape = (3, 40)
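
Note (reviewer aid, not part of the patch): the new VJP rules added to paddle/fluid/primitive/rule/vjp/details.h follow the standard analytic gradients. Below is a minimal NumPy sketch of the same formulas for cross-checking the C++ above; all names are illustrative, inputs are assumed to be plain dense float arrays, and the broadcast-reduction paths are omitted.

# numpy_vjp_check.py -- illustrative only; mirrors the formulas in details.h, not a Paddle API.
import numpy as np

def tanh_grad(out, grad_out):
    # tanh_grad in details.h: grad_x = grad_out * (1 - out * out)
    return grad_out * (1.0 - out * out)

def transpose_grad(grad_out, perm):
    # transpose_grad in details.h: transpose back with the inverse permutation,
    # normalizing negative axes the same way the C++ loop does.
    reverse_perm = [0] * len(perm)
    for i, p in enumerate(perm):
        reverse_perm[p if p >= 0 else p + len(perm)] = i
    return np.transpose(grad_out, reverse_perm)

def elementwise_pow_grad(x, y, grad_out):
    # elementwise_pow_grad in details.h (same-shape case):
    #   dx = y * x^(y-1) * dout,   dy = ln(x) * x^y * dout
    dx = y * np.power(x, y - 1.0) * grad_out
    dy = np.log(x) * np.power(x, y) * grad_out
    return dx, dy

For broadcasted inputs, the C++ rules additionally reduce dx/dy over the broadcast axes (via get_reduce_dims) and reshape back to the input shapes, which this sketch leaves out.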
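
Similarly, a hedged NumPy sketch of the layer_norm decomposition added to python/paddle/decomposition/rules.py (layernorm_composite); the scale/bias handling and the flattened mean/variance outputs follow the Python rule above, while the AMP cast-to-float32 branch is omitted.

# numpy_layer_norm_check.py -- illustrative only; mirrors layernorm_composite, not a Paddle API.
import numpy as np

def layer_norm_composite(x, scale, bias, epsilon, begin_norm_axis):
    # out = (x - mean(x)) / sqrt(var(x) + epsilon), normalized over the trailing axes.
    axis = tuple(range(begin_norm_axis, x.ndim))
    mean_ = x.mean(axis=axis, keepdims=True)
    diff = x - mean_
    variance = (diff * diff).mean(axis=axis, keepdims=True)
    out = diff * (1.0 / np.sqrt(variance + epsilon))
    if scale is not None:
        out = out * scale.reshape(x.shape[begin_norm_axis:])
    if bias is not None:
        out = out + bias.reshape(x.shape[begin_norm_axis:])
    return out, mean_.reshape(-1), variance.reshape(-1)

In the patch itself, fp16/bf16 inputs are first cast to float32 (both in the Python rule and in layer_norm_grad in details.h) so the reduction math stays numerically consistent with the phi kernel.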