Skip to content

Commit

Permalink
[Cherry pick] Add some final state OPs (#41737)
Browse files Browse the repository at this point in the history
* Add yaml for matrix rank op (#41466)

* modify matrix_rank

* add matrix_rank shape

* add matrix_rank shape

* Add yaml for matrix_rank OP

* Add UT

Co-authored-by: zhoujianqian <[email protected]>

* Add yaml for eye OP (#41476)

* [cherry-pick] Add yaml config for matrix_rank, eye, deformable_conv and
deformable_conv_v1 OPs
* Add yaml for deformable_conv and deformable_conv_v1 OPs

* Add UT

* Add to skipped_phi_api list for infrt

Co-authored-by: zhoujianqian <[email protected]>
  • Loading branch information
From00 and Zjq9409 authored Apr 13, 2022
1 parent 4c11be8 commit 33583dd
Show file tree
Hide file tree
Showing 18 changed files with 322 additions and 20 deletions.
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,27 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
logits_grad->set_dtype(softmax.dtype());
}

void DeformableConvGradInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
paddle::optional<const MetaTensor&> mask,
const MetaTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
MetaTensor* dx,
MetaTensor* offset_grad,
MetaTensor* filter_grad,
MetaTensor* mask_grad) {
GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad);
if (mask) {
UnchangedInferMeta(mask.get(), mask_grad);
}
}

void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
MetaTensor* logits_grad,
MetaConfig config = MetaConfig());

void DeformableConvGradInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
paddle::optional<const MetaTensor&> mask,
const MetaTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
MetaTensor* dx,
MetaTensor* offset_grad,
MetaTensor* filter_grad,
MetaTensor* mask_grad);

void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
Expand Down
51 changes: 51 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
}
}

// Used in MatrixRankTolInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
}

} // namespace detail

void AllValueCompareInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
2,
phi::errors::InvalidArgument("The dims of input must be greater than 2"));

if (hermitian) {
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
PADDLE_ENFORCE_EQ(rows,
cols,
phi::errors::InvalidArgument(
"if hermitian == true, matrix should be n*n"));
}
DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
auto dim_tol = atol_tensor.dims();
if (dim_x_batch == dim_tol) {
out->set_dims(dim_x_batch);
} else {
int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
int axis = std::abs(dim_x_batch.size() - dim_tol.size());
std::vector<int> x_batch_dims_array(max_dim);
std::vector<int> tol_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
dim_tol,
x_batch_dims_array.data(),
tol_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out->set_dims(phi::make_ddim(out_dims_array));
}
out->share_lod(x);
}

void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
auto dim_x = x.dims();
auto dim_vec = vec.dims();
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
int y_num_col_dims,
MetaTensor* out);

void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
MetaTensor* out);

void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);

void PReluInferMeta(const MetaTensor& x,
Expand Down
35 changes: 35 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ limitations under the License. */

namespace phi {

namespace detail {
// Used in MatrixRankInferMeta
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
}
} // namespace detail

void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
Expand Down Expand Up @@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
out->set_dtype(x.dtype());
}

void MatrixRankInferMeta(const MetaTensor& x,
bool use_default_tol,
bool hermitian,
MetaTensor* out) {
auto dim_x = x.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
2,
phi::errors::InvalidArgument("The dims of input must be greater than 2"));

if (hermitian) {
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
PADDLE_ENFORCE_EQ(rows,
cols,
phi::errors::InvalidArgument(
"if hermitian == true, matrix should be n*n"));
}
DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
out->set_dims(dim_x_batch);
out->share_lod(x);
}

void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,

void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);

void MatrixRankInferMeta(const MetaTensor& x,
bool use_default_tol,
bool hermitian,
MetaTensor* out);

void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,10 +1752,12 @@ def eye(num_rows,
else:
num_columns = num_rows

if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
_current_expected_place())
elif _in_legacy_dygraph():
out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
num_columns)

else:
helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype',
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ set_tests_properties(test_lstm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_star_gan_with_gradient_penalty PROPERTIES TIMEOUT 120)

set_tests_properties(test_bicubic_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
Expand Down Expand Up @@ -1045,7 +1045,7 @@ set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT
set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_executor_transformer_auto_growth PROPERTIES TIMEOUT 120)
set_tests_properties(test_py_reader_using_executor PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_add_op PROPERTIES TIMEOUT 120)
Expand Down
9 changes: 9 additions & 0 deletions python/paddle/fluid/tests/unittests/test_deform_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import paddle.nn.initializer as I
import numpy as np
import unittest
from paddle.fluid.framework import _test_eager_guard
from unittest import TestCase


Expand Down Expand Up @@ -183,6 +184,10 @@ def test_identity(self):
self.place = paddle.CUDAPlace(0)
self._test_identity()

def test_identity_with_eager_guard(self):
with _test_eager_guard():
self.test_identity()


class TestDeformConv2DFunctional(TestCase):
batch_size = 4
Expand Down Expand Up @@ -418,6 +423,10 @@ def test_identity(self):
self.place = paddle.CUDAPlace(0)
self._test_identity()

def test_identity_with_eager_guard(self):
with _test_eager_guard():
self.test_identity()


# testcases for DeformConv2D
class TestDeformConv2DWithPadding(TestDeformConv2D):
Expand Down
35 changes: 31 additions & 4 deletions python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import print_function

import paddle
import unittest
import numpy as np

import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard

paddle.enable_static()


def dmc_bilinear(data_im, height, width, h, w):
Expand Down Expand Up @@ -108,8 +110,24 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
return out


def deform_conv2d_wrapper(x,
offset,
weight,
mask=None,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
groups=1,
im2col_step=1):
return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride,
padding, dilation, deformable_groups,
groups, mask)


class TestModulatedDeformableConvOp(OpTest):
def setUp(self):
self.python_api = deform_conv2d_wrapper
self.op_type = "deformable_conv"
self.init_type()
self.init_group()
Expand Down Expand Up @@ -148,13 +166,14 @@ def setUp(self):
self.outputs = {'Output': output}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(
{'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.05)
max_relative_error=0.05,
check_eager=True)

def init_test_case(self):
self.pad = [1, 1]
Expand Down Expand Up @@ -327,6 +346,10 @@ def test_invalid_filter():

self.assertRaises(ValueError, test_invalid_filter)

def test_error_with_eager_guard(self):
with _test_eager_guard():
self.test_error()


class TestDeformConv2DAPI(unittest.TestCase):
def test_api(self):
Expand Down Expand Up @@ -358,6 +381,10 @@ def test_deform_conv2d_v2():

test_deform_conv2d_v2()

def test_api_with_eager_guard(self):
with _test_eager_guard():
self.test_api()


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 33583dd

Please sign in to comment.