Add yaml for matrix rank op (#41466)
* modify matrix_rank

* add matrix_rank shape

* add matrix_rank shape

* Add yaml for matrix_rank OP

* Add UT

Co-authored-by: zhoujianqian <[email protected]>
From00 and Zjq9409 authored Apr 7, 2022
1 parent 5516f18 commit c77a263
Showing 7 changed files with 161 additions and 2 deletions.
51 changes: 51 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
}
}

// Used in MatrixRankTolInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
    return phi::make_ddim({1});
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
}

} // namespace detail

void AllValueCompareInferMeta(const MetaTensor& x,
@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void MatrixRankTolInferMeta(const MetaTensor& x,
                            const MetaTensor& atol_tensor,
                            bool use_default_tol,
                            bool hermitian,
                            MetaTensor* out) {
  auto dim_x = x.dims();
  PADDLE_ENFORCE_GE(
      dim_x.size(),
      2,
      phi::errors::InvalidArgument(
          "The dims of input must be greater than or equal to 2"));

  if (hermitian) {
    int rows = dim_x[dim_x.size() - 2];
    int cols = dim_x[dim_x.size() - 1];
    PADDLE_ENFORCE_EQ(rows,
                      cols,
                      phi::errors::InvalidArgument(
                          "if hermitian == true, matrix should be n*n"));
  }
  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
  auto dim_tol = atol_tensor.dims();
  if (dim_x_batch == dim_tol) {
    out->set_dims(dim_x_batch);
  } else {
    int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
    int axis = std::abs(dim_x_batch.size() - dim_tol.size());
    std::vector<int> x_batch_dims_array(max_dim);
    std::vector<int> tol_dims_array(max_dim);
    std::vector<int> out_dims_array(max_dim);
    phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
                                       dim_tol,
                                       x_batch_dims_array.data(),
                                       tol_dims_array.data(),
                                       out_dims_array.data(),
                                       max_dim,
                                       axis);
    out->set_dims(phi::make_ddim(out_dims_array));
  }
  out->share_lod(x);
}

void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
auto dim_x = x.dims();
auto dim_vec = vec.dims();
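The new MatrixRankTolInferMeta above derives the output shape from the batch dimensions of x (everything except the trailing two matrix dimensions, or [1] for a single 2-D matrix, via CheckAndGetOutputDim) and, when they differ from the shape of atol_tensor, broadcasts the two together. A rough NumPy paraphrase of that rule (the helper name and the example shapes are illustrative, not part of the commit):

import numpy as np

def matrix_rank_tol_out_shape(x_shape, tol_shape):
    # Mirrors CheckAndGetOutputDim: drop the trailing matrix dims;
    # a plain 2-D matrix maps to the batch shape [1].
    assert len(x_shape) >= 2
    batch = list(x_shape[:-2]) or [1]
    if list(tol_shape) == batch:
        return tuple(batch)
    # Otherwise the batch dims and the tol dims are broadcast together,
    # as GetBroadcastDimsArrays does in the C++ code above.
    return np.broadcast_shapes(tuple(batch), tuple(tol_shape))

print(matrix_rank_tol_out_shape((3, 4, 5, 6), (4,)))  # (3, 4)
print(matrix_rank_tol_out_shape((5, 6), (1,)))        # (1,)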
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
int y_num_col_dims,
MetaTensor* out);

void MatrixRankTolInferMeta(const MetaTensor& x,
                            const MetaTensor& atol_tensor,
                            bool use_default_tol,
                            bool hermitian,
                            MetaTensor* out);

void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);

void PReluInferMeta(const MetaTensor& x,
35 changes: 35 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -31,6 +31,18 @@ limitations under the License. */

namespace phi {

namespace detail {
// Used in MatrixRankInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
    return phi::make_ddim({1});
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
}
} // namespace detail

void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
out->set_dtype(x.dtype());
}

void MatrixRankInferMeta(const MetaTensor& x,
                         bool use_default_tol,
                         bool hermitian,
                         MetaTensor* out) {
  auto dim_x = x.dims();
  PADDLE_ENFORCE_GE(
      dim_x.size(),
      2,
      phi::errors::InvalidArgument(
          "The dims of input must be greater than or equal to 2"));

  if (hermitian) {
    int rows = dim_x[dim_x.size() - 2];
    int cols = dim_x[dim_x.size() - 1];
    PADDLE_ENFORCE_EQ(rows,
                      cols,
                      phi::errors::InvalidArgument(
                          "if hermitian == true, matrix should be n*n"));
  }
  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
  out->set_dims(dim_x_batch);
  out->share_lod(x);
}

void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
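MatrixRankInferMeta handles the attribute-tol path: the output shape is simply the batch shape of x, and [1] for a single 2-D matrix. NumPy's matrix_rank uses the same convention for stacked inputs, which is what the new unit tests below compare against; a quick illustration (shapes are illustrative only):

import numpy as np

x = np.random.rand(3, 4, 5, 6).astype(np.float32)
print(np.linalg.matrix_rank(x).shape)  # (3, 4): one rank per 5x6 matrix in the batch

x2 = np.eye(3, dtype=np.float32)
print(np.linalg.matrix_rank(x2))       # 3 (NumPy returns a scalar; the Paddle meta uses shape [1])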
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,

void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);

void MatrixRankInferMeta(const MetaTensor& x,
                         bool use_default_tol,
                         bool hermitian,
                         MetaTensor* out);

void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
29 changes: 28 additions & 1 deletion python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
@@ -30,8 +30,13 @@
np.random.seed(SEED)


def matrix_rank_wraper(x, tol=None, use_default_tol=True, hermitian=False):
    return paddle.linalg.matrix_rank(x, tol, hermitian)


class TestMatrixRankOP(OpTest):
    def setUp(self):
        self.python_api = matrix_rank_wraper
        self.op_type = "matrix_rank"
        self.init_data()
        self.inputs = {'X': self.x}
@@ -44,7 +49,7 @@ def setUp(self):
        self.outputs = {'Out': self.out}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def init_data(self):
        self.x = np.eye(3, dtype=np.float32)
@@ -110,6 +115,28 @@ def init_data(self):
                                         self.hermitian)


class TestMatrixRankOP6(TestMatrixRankOP):
    def init_data(self):
        self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
        self.tol_tensor = None
        self.tol = None
        self.use_default_tol = False
        self.hermitian = False
        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
                                         self.hermitian)


class TestMatrixRankOP7(TestMatrixRankOP):
    def init_data(self):
        self.x = np.eye(200, dtype=np.float64)
        self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
        self.tol = None
        self.use_default_tol = True
        self.hermitian = True
        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
                                         self.hermitian)


class TestMatrixRankAPI(unittest.TestCase):
    def test_dygraph(self):
        paddle.disable_static()
20 changes: 19 additions & 1 deletion python/paddle/tensor/linalg.py
@@ -1284,8 +1284,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
            # [1, 1, 1, 1]]
    """
    if in_dygraph_mode():
        if isinstance(tol, Variable):
            if tol.dtype != x.dtype:
                tol_tensor = cast(tol, x.dtype)
            else:
                tol_tensor = tol
            use_default_tol = False
            return _C_ops.final_state_matrix_rank_tol(
                x, tol_tensor, use_default_tol, hermitian)

    if paddle.in_dynamic_mode():
        if tol is None:
            tol_attr = 0.0
            use_default_tol = True
        else:
            tol_attr = float(tol)
            use_default_tol = False
        return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
                                              hermitian)

    if _in_legacy_dygraph():
        if tol is None:
            tol_tensor = None
            tol_attr = 0.0
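At the Python API level, paddle.linalg.matrix_rank now routes eager-mode calls to the two new final-state ops: a Tensor tol goes to matrix_rank_tol (cast to x.dtype if needed), while None or a float tol goes to matrix_rank with use_default_tol set accordingly. A small usage sketch (shapes and tolerance values are illustrative):

import paddle

x = paddle.rand([3, 4, 5, 6])

r_default = paddle.linalg.matrix_rank(x)           # tol=None  -> matrix_rank, use_default_tol=True
r_scalar = paddle.linalg.matrix_rank(x, tol=0.01)  # float tol -> matrix_rank, use_default_tol=False
tol = paddle.full([4], 0.01)
r_tensor = paddle.linalg.matrix_rank(x, tol=tol)   # Tensor tol -> matrix_rank_tol

print(r_default.shape, r_scalar.shape, r_tensor.shape)  # [3, 4] in each case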
17 changes: 17 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
@@ -1157,6 +1157,23 @@
    func : matrix_power
  backward : matrix_power_grad

- api : matrix_rank
  args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
  output : Tensor(out)
  infer_meta :
    func : MatrixRankInferMeta
    param : [x, use_default_tol, hermitian]
  kernel :
    func : matrix_rank

- api : matrix_rank_tol
  args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
  output : Tensor(out)
  infer_meta :
    func : MatrixRankTolInferMeta
  kernel :
    func : matrix_rank_tol

- api : max
  args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
  output : Tensor(out)
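These two yaml entries are what the op code generation consumes: they declare the argument lists, the InferMeta functions added above (with param restricting which arguments MatrixRankInferMeta receives), and the PHI kernels to bind. The resulting eager entry points are the final_state_* calls already used in linalg.py; a hedged sketch of the correspondence (comments only, not generated code):

# matrix_rank:      _C_ops.final_state_matrix_rank(x, tol, use_default_tol, hermitian)
#                   shape via MatrixRankInferMeta(x, use_default_tol, hermitian); 'tol' is not forwarded
# matrix_rank_tol:  _C_ops.final_state_matrix_rank_tol(x, atol_tensor, use_default_tol, hermitian)
#                   shape via MatrixRankTolInferMeta(x, atol_tensor, use_default_tol, hermitian)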
