【PIR API adaptor No.140-142】 Migrate lstsq/lu/lu_unpack into pir (Pad…
DrRyanHuang authored and SecretXV committed Nov 28, 2023
1 parent 9a73eef commit a0c28fb
Showing 5 changed files with 52 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/paddle/device/__init__.py
@@ -308,7 +308,7 @@ def get_device():
"""
device = ''
place = framework._current_expected_place()
place = framework._current_expected_place_()
if isinstance(place, core.CPUPlace):
device = 'cpu'
elif isinstance(place, core.CUDAPlace):
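Note: `get_device()` is the public wrapper around the place lookup changed above; a minimal check of its output (illustrative only, not part of the commit):

    import paddle

    # Reports the device backing the current expected place,
    # e.g. 'cpu' or 'gpu:0' on a CUDA build.
    print(paddle.device.get_device())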
35 changes: 28 additions & 7 deletions python/paddle/tensor/linalg.py
@@ -2451,7 +2451,7 @@ def lu(x, pivot=True, get_infos=False, name=None):
             >>> # one can verify : X = P @ L @ U ;
     """

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         lu, p, info = _C_ops.lu(x, pivot)
     else:
         check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
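Note: with the dispatch widened from `in_dynamic_mode()` to `in_dynamic_or_pir_mode()`, the same `_C_ops.lu` call now serves both dygraph and PIR. A minimal dygraph sketch of the public API (values illustrative):

    import paddle

    x = paddle.to_tensor([[4.0, 3.0], [6.0, 3.0]])
    lu, p = paddle.linalg.lu(x)                        # packed LU factors and pivots
    lu, p, info = paddle.linalg.lu(x, get_infos=True)  # optionally also the info tensor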
@@ -2554,7 +2554,7 @@ def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None):
         raise ValueError(
             f"The shape of Pivots should be (*, K), but received ndim is [{y.ndim} < 1]"
         )
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         P, L, U = _C_ops.lu_unpack(x, y, unpack_ludata, unpack_pivots)
         return P, L, U
     else:
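Note: `lu_unpack` takes the packed factors and pivots produced by `lu`; the docstring identity `X = P @ L @ U` gives a quick correctness check. A hedged sketch (values illustrative):

    import paddle

    x = paddle.to_tensor([[4.0, 3.0], [6.0, 3.0]])
    lu, p = paddle.linalg.lu(x)
    P, L, U = paddle.linalg.lu_unpack(lu, p)
    # One can verify X = P @ L @ U (up to float tolerance).
    print(paddle.allclose(x, P @ L @ U))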
@@ -3454,7 +3454,16 @@ def lstsq(x, y, rcond=None, driver=None, name=None):
     else:
         raise RuntimeError("Only support lstsq api for CPU or CUDA device.")

-    if not (x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64)):
+    if not (
+        x.dtype == y.dtype
+        and x.dtype
+        in (
+            paddle.float32,
+            paddle.float64,
+            paddle.base.core.DataType.FLOAT32,
+            paddle.base.core.DataType.FLOAT64,
+        )
+    ):
         raise ValueError(
             "Only support x and y have the same dtype such as 'float32' and 'float64'."
         )
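Note: the whitelist doubles up because the two modes report dtypes from different enum families: a dygraph Tensor's `.dtype` compares equal to `paddle.float32`, while a PIR `Value` reports `paddle.base.core.DataType.FLOAT32`, and the two families do not compare equal to each other. A small illustration, assuming a PIR-enabled build:

    import paddle

    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data(name="x", shape=[3, 2], dtype="float32")
        # Under PIR this prints DataType.FLOAT32 rather than paddle.float32,
        # hence lstsq must accept both spellings of the same dtype.
        print(x.dtype)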
@@ -3475,17 +3484,29 @@
         )

     if rcond is None:
-        if x.dtype == paddle.float32:
+        if (
+            x.dtype == paddle.float32
+            or x.dtype == paddle.base.core.DataType.FLOAT32
+        ):
             rcond = 1e-7 * max(x.shape[-2], x.shape[-1])
-        elif x.dtype == paddle.float64:
+        elif (
+            x.dtype == paddle.float64
+            or x.dtype == paddle.base.core.DataType.FLOAT64
+        ):
             rcond = 1e-15 * max(x.shape[-2], x.shape[-1])

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         solution, residuals, rank, singular_values = _C_ops.lstsq(
             x, y, rcond, driver
         )
         if driver == "gels":
-            rank = paddle.empty(shape=[0], dtype=paddle.int32)
+            if in_dynamic_mode():
+                rank = paddle.empty(shape=[0], dtype=paddle.int32)
+
+            else:
+                rank = paddle.empty(
+                    shape=[0], dtype=paddle.base.core.DataType.INT32
+                )
             singular_values = paddle.empty(shape=[0], dtype=x.dtype)
         elif driver == "gelsy":
             singular_values = paddle.empty(shape=[0], dtype=x.dtype)
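Note: the branch above also fixes the `rcond` default before dispatch, 1e-7 * max(m, n) for float32 inputs and 1e-15 * max(m, n) for float64, in either enum family. A dygraph usage sketch (data illustrative):

    import paddle

    x = paddle.to_tensor([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
    y = paddle.to_tensor([[6.0], [9.0], [12.0]])
    # Returns (solution, residuals, rank, singular_values).
    solution, residuals, rank, singular_values = paddle.linalg.lstsq(x, y)
    print(solution)  # least-squares coefficients of the fit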
10 changes: 7 additions & 3 deletions test/legacy_test/test_linalg_lstsq_op.py
@@ -19,14 +19,15 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api


 class LinalgLstsqTestCase(unittest.TestCase):
     def setUp(self):
         self.devices = ["cpu"]
         self.init_config()
         if core.is_compiled_with_cuda() and self.driver == "gels":
-            self.devices.append("gpu:0")
+            self.devices.append("gpu")
         self.generate_input()
         self.generate_output()
         np.random.seed(2022)
@@ -91,12 +92,15 @@ def test_eager_dygraph(self):
             self._result_sg_values = results[3].numpy()
             self.assert_np_close()

+    @test_with_pir_api
     def test_static(self):
         paddle.enable_static()
         for dev in self.devices:
             paddle.set_device(dev)
             place = base.CPUPlace() if dev == "cpu" else base.CUDAPlace(0)
-            with base.program_guard(base.Program(), base.Program()):
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
                 x = paddle.static.data(
                     name="x",
                     shape=self._input_shape_1,
@@ -112,7 +116,6 @@ def test_static(self):
                 )
                 exe = base.Executor(place)
                 fetches = exe.run(
-                    base.default_main_program(),
                     feed={"x": self._input_data_1, "y": self._input_data_2},
                     fetch_list=[results],
                 )
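Note: the same three-part pattern recurs in every test file of this commit: `@test_with_pir_api` reruns a static-graph test body under the PIR program as well (as its use here suggests), `paddle.static.program_guard` replaces the legacy `base.program_guard`, and the explicit `base.default_main_program()` argument to `exe.run` is dropped so the executor resolves whichever default program (legacy or PIR) is current. A condensed, self-contained sketch of the pattern (class name and data are hypothetical):

    import unittest

    import numpy as np

    import paddle
    from paddle import base
    from paddle.pir_utils import test_with_pir_api


    class SomePirMigratedTest(unittest.TestCase):
        @test_with_pir_api  # runs the body under both program kinds
        def test_static(self):
            paddle.enable_static()
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
                lu, p = paddle.linalg.lu(x)
                exe = base.Executor(base.CPUPlace())
                # No explicit program argument: exe.run picks up the
                # guarded default program, legacy or PIR alike.
                fetches = exe.run(
                    feed={"x": np.eye(3, dtype="float32")},
                    fetch_list=[lu, p],
                )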
@@ -282,6 +285,7 @@ class TestLinalgLstsqAPIError(unittest.TestCase):
     def setUp(self):
         pass

+    @test_with_pir_api
     def test_api_errors(self):
         def test_x_bad_shape():
             x = paddle.to_tensor(np.random.random(size=(5)), dtype=np.float32)
11 changes: 7 additions & 4 deletions test/legacy_test/test_lu_op.py
@@ -24,6 +24,7 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api


 def scipy_lu(A, pivot):
@@ -156,10 +157,10 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def test_check_grad(self):
-        self.check_grad(['X'], ['Out'])
+        self.check_grad(['X'], ['Out'], check_pir=True)


 # m = n 2D
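Note: on the OpTest side, `check_pir=True` is the matching switch (as its name and use here suggest): the operator's forward and gradient checks are additionally executed and compared under the PIR path. Restated with comments (same calls as the hunk above):

    def test_check_output(self):
        # Numeric forward check, additionally run through PIR.
        self.check_output(check_pir=True)

    def test_check_grad(self):
        # Gradient check for input 'X' w.r.t. output 'Out', also under PIR.
        self.check_grad(['X'], ['Out'], check_pir=True)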
@@ -238,6 +239,7 @@ def run_lu_dygraph(shape, dtype):
         for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
             run_lu_dygraph(tensor_shape, dtype)

+    @test_with_pir_api
     def test_static(self):
         paddle.enable_static()

@@ -257,7 +259,9 @@ def run_lu_static(shape, dtype):
             if core.is_compiled_with_cuda():
                 places.append(base.CUDAPlace(0))
             for place in places:
-                with base.program_guard(base.Program(), base.Program()):
+                with paddle.static.program_guard(
+                    paddle.static.Program(), paddle.static.Program()
+                ):
                     batch_size = a.size // (a.shape[-1] * a.shape[-2])
                     sP, sl, sU = scipy_lu(a, pivot)
                     sL = np.tril(sl, -1)
@@ -284,7 +288,6 @@ def run_lu_static(shape, dtype):
                     lu, p = paddle.linalg.lu(x, pivot=pivot)
                     exe = base.Executor(place)
                     fetches = exe.run(
-                        base.default_main_program(),
                         feed={"input": a},
                         fetch_list=[lu, p],
                     )
14 changes: 9 additions & 5 deletions test/legacy_test/test_lu_unpack_op.py
@@ -24,6 +24,7 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api


 def scipy_lu_unpack(A):
@@ -138,7 +139,9 @@ def setUp(self):
             lu = lu.numpy()
             pivots = pivots.numpy()
         else:
-            with base.program_guard(base.Program(), base.Program()):
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
                 place = base.CPUPlace()
                 if core.is_compiled_with_cuda():
                     place = base.CUDAPlace(0)
@@ -148,7 +151,6 @@
                 lu, p = paddle.linalg.lu(xv)
                 exe = base.Executor(place)
                 fetches = exe.run(
-                    base.default_main_program(),
                     feed={"input": x},
                     fetch_list=[lu, p],
                 )
@@ -168,7 +170,7 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def test_check_grad(self):
         self.check_grad(['X'], ['L', 'U'])
@@ -258,6 +260,7 @@ def run_lu_unpack_dygraph(shape, dtype):
         for tensor_shape, dtype in itertools.product(tensor_shapes, dtypes):
             run_lu_unpack_dygraph(tensor_shape, dtype)

+    @test_with_pir_api
     def test_static(self):
         paddle.enable_static()

@@ -275,7 +278,9 @@ def run_lu_static(shape, dtype):
             if core.is_compiled_with_cuda():
                 places.append(base.CUDAPlace(0))
             for place in places:
-                with base.program_guard(base.Program(), base.Program()):
+                with paddle.static.program_guard(
+                    paddle.static.Program(), paddle.static.Program()
+                ):
                     sP, sL, sU = scipy_lu_unpack(a)

                     x = paddle.static.data(
@@ -285,7 +290,6 @@
                     pP, pL, pU = paddle.linalg.lu_unpack(lu, p)
                     exe = base.Executor(place)
                     fetches = exe.run(
-                        base.default_main_program(),
                         feed={"input": a},
                         fetch_list=[pP, pL, pU],
                     )
