From 8dd152d45c6a3fe1509ee5b4e6bebb607df06e06 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Thu, 12 Oct 2023 12:48:40 +0000 Subject: [PATCH] fix code --- python/paddle/decomposition/rules.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index 4de865497f67f7..6e59f9858e74a7 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -56,10 +56,8 @@ def gelu(x, approximate): tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) out = x * half * (one + tanh_out) return out - else: # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype))) out = x * cdf return out @@ -71,7 +69,6 @@ def sqrt(x): define composite rule of op sqrt res = pow(x, 0.5) """ - # breakpoint() is_amp = False from paddle.base.data_feeder import convert_dtype @@ -255,7 +252,6 @@ def silu(x): @register_decomp('pd_op.softmax') def softmax(x, axis): """define composite rule of op softmax""" - # breakpoint() is_amp = False from paddle.base.data_feeder import convert_dtype @@ -285,9 +281,8 @@ def full_like(x, fill_value, dtype, place=None): """define composite rule of op full_like.""" """op name: full_like op type name: fill_any_like.""" """arg place is not used, add it here to keep same as python api.""" - val = full( - x.shape, fill_value, dtype - ) # x.shape = [10, 10], val.shape=[-1, -1] + fill_value = fill_value.get_defining_op().attrs()["value"] + val = full(x.shape, fill_value, dtype) return val @@ -297,11 +292,9 @@ def stack(x, axis): define composite rule of op stack unsqueeze each dimension of the input (use reshape), and then concat """ - x_shape = x[0].shape if axis < 0: axis += len(x_shape) + 1 - # breakpoint() out_shape = x_shape[:axis] + [1] + x_shape[axis:] out = concat([reshape(item, out_shape) for item in x], axis) return out