Skip to content

Commit

Permalink
fix code
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Oct 12, 2023
1 parent b23afd5 commit 8dd152d
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions python/paddle/decomposition/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ def gelu(x, approximate):
tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
out = x * half * (one + tanh_out)
return out

else:
# gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
out = x * cdf
return out
Expand All @@ -71,7 +69,6 @@ def sqrt(x):
define composite rule of op sqrt
res = pow(x, 0.5)
"""
# breakpoint()
is_amp = False
from paddle.base.data_feeder import convert_dtype

Expand Down Expand Up @@ -255,7 +252,6 @@ def silu(x):
@register_decomp('pd_op.softmax')
def softmax(x, axis):
"""define composite rule of op softmax"""
# breakpoint()
is_amp = False
from paddle.base.data_feeder import convert_dtype

Expand Down Expand Up @@ -285,9 +281,8 @@ def full_like(x, fill_value, dtype, place=None):
"""define composite rule of op full_like."""
"""op name: full_like op type name: fill_any_like."""
"""arg place is not used, add it here to keep same as python api."""
val = full(
x.shape, fill_value, dtype
) # x.shape = [10, 10], val.shape=[-1, -1]
fill_value = fill_value.get_defining_op().attrs()["value"]
val = full(x.shape, fill_value, dtype)
return val


Expand All @@ -297,11 +292,9 @@ def stack(x, axis):
define composite rule of op stack
unsqueeze each dimension of the input (use reshape), and then concat
"""

x_shape = x[0].shape
if axis < 0:
axis += len(x_shape) + 1
# breakpoint()
out_shape = x_shape[:axis] + [1] + x_shape[axis:]
out = concat([reshape(item, out_shape) for item in x], axis)
return out
Expand Down

0 comments on commit 8dd152d

Please sign in to comment.