Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Complex op】add complex support for index_select and index_sample #56457

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ void GradNodeBase::HandleComplexGradToRealGrad(
for (size_t slot_id = 0; slot_id < out_grads->size(); slot_id++) {
const std::vector<paddle::Tensor>& slot_out_grads = (*out_grads)[slot_id];
for (size_t rank_id = 0; rank_id < slot_out_grads.size(); rank_id++) {
if (bwd_out_meta_[slot_id].size() == 0) continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里增加判断的理由是什么呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h 中生成了 IndexSelectGradNode 类(集成自 egr::GradNodeBase
  2. paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.cc 中的 index_select_ad_func 函数里对 IndexSelectGradNode 类进行了初始化
    grad_node = std::shared_ptr<IndexSelectGradNode>(new IndexSelectGradNode(1, 2));
    其中 2 是根据 ops.yaml 中 index_select 里 args 输入参数的数据类型确定的,即数据类型为 Tensor 的个数。
  3. paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.cc 中的 IndexSelectGradNode::operator() 方法最后有:
      if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);
  4. paddle/fluid/eager/grad_node_info.cc 中的 GradNodeBase::HandleComplexGradToRealGrad 函数里的 const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id]; 会报错
    原因在于 bwd_out_meta_ 初始化时的 size 为 2,但当 bwd_out_meta_[1] 所维持的 vector size 为 0,因此此处在执行 bwd_out_meta_[1][0] 时会造成 Segmentation faults.

所以在此处加了特判。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那这里是不是不是根本问题?正常情况下,这里如果有out_grads, 那么对应的out_grad_meta应该也是记录好的才对,应该要看看为什么这里的out_grad_meta没有值

const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id];

PADDLE_ENFORCE(
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@
func : index_sample_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index

- backward_op : index_select_grad
forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
Expand All @@ -1132,6 +1134,8 @@
func : index_select_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index

- backward_op : index_select_strided_grad
forward : index_select_strided(Tensor x, int64_t index, int axis) -> Tensor(out)
Expand Down
25 changes: 18 additions & 7 deletions paddle/phi/api/yaml/generator/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,23 @@

import collections
import re
from typing import List

PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'


def parse_plain_list(s: str, sep=",") -> List[str]:
    """Split *s* on *sep* and strip whitespace from every item.

    Copied from ``paddle/fluid/operators/generator/parse_utils.py``.
    With the default comma separator, commas nested inside a ``{...}``
    pair are left alone, so defaults such as ``"int[] a={1,2}"`` survive
    the split as a single item.
    """
    if sep != ",":
        return [piece.strip() for piece in s.strip().split(sep)]
    # Only split on commas that are NOT inside a brace pair: the negative
    # lookahead rejects a comma whose next closing '}' comes before any '{'.
    top_level_comma = re.compile(r',(?![^{]*\})')
    return [piece.strip() for piece in top_level_comma.split(s.strip())]


class BaseAPI:
def __init__(self, api_item_yaml):
self.api = self.get_api_name(api_item_yaml)
def parse_data_transform(self, api_item_yaml):
    """Parse the optional ``data_transform`` section of an op YAML entry.

    Returns a dict with keys ``skip_transform`` and ``support_trans_dtype``,
    each a list of parameter names (empty when the section or key is
    absent).  The YAML values are comma-separated strings, so they are
    converted to real lists via ``parse_plain_list`` instead of being
    stored as raw strings.
    """
    data_transform = {'skip_transform': [], 'support_trans_dtype': []}
    if 'data_transform' in api_item_yaml:
        if 'skip_transform' in api_item_yaml['data_transform']:
            data_transform['skip_transform'] = parse_plain_list(
                api_item_yaml['data_transform']['skip_transform']
            )
        if 'support_trans_dtype' in api_item_yaml['data_transform']:
            data_transform['support_trans_dtype'] = parse_plain_list(
                api_item_yaml['data_transform']['support_trans_dtype']
            )
    return data_transform

# Override by child class
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,8 @@
func : index_sample
data_type : x
backward : index_sample_grad
data_transform :
skip_transform : index

- op : index_select
args : (Tensor x, Tensor index, int axis = 0)
Expand All @@ -1242,6 +1244,8 @@
func : index_select
data_type : x
backward : index_select_grad
data_transform :
skip_transform : index

- op : index_select_strided
args : (Tensor x, int64_t index, int axis = 0)
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/index_sample_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/index_select_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ PD_REGISTER_KERNEL(index_select_grad,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/index_select_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,7 @@ PD_REGISTER_KERNEL(index_select,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/index_sample_kernel.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,7 @@ PD_REGISTER_KERNEL(index_select_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_select_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,7 @@ PD_REGISTER_KERNEL(index_select,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
26 changes: 22 additions & 4 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def index_select(x, index, axis=0, name=None):
size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor.

Args:
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64.
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64, complex64 and complex128.
index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
axis (int, optional): The dimension in which we index. Default: if None, the ``axis`` is 0.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Expand Down Expand Up @@ -352,7 +352,16 @@ def index_select(x, index, axis=0, name=None):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_select',
)
check_variable_and_dtype(
Expand Down Expand Up @@ -733,7 +742,7 @@ def index_sample(x, index):

Args:
x (Tensor): The source input tensor with 2-D shape. Supported data type is
int32, int64, bfloat16, float16, float32, float64.
int32, int64, bfloat16, float16, float32, float64, complex64, complex128.
index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64.

Expand Down Expand Up @@ -788,7 +797,16 @@ def index_sample(x, index):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_sample',
)
check_variable_and_dtype(
Expand Down
27 changes: 27 additions & 0 deletions test/legacy_test/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def setUp(self):
self.python_api = paddle.index_sample
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
if self.x_type == np.complex64 or self.x_type == np.complex128:
xnp = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
indexnp = np.random.randint(
low=0, high=self.x_shape[1], size=self.index_shape
).astype(self.index_type)
Expand Down Expand Up @@ -122,6 +127,28 @@ def config(self):
self.index_type = "int64"


class TestIndexSampleComplex64(TestIndexSampleOp):
    """index_sample test with a complex64 source tensor."""

    def config(self):
        # Same geometry as the base op test, but with a complex64 payload.
        self.x_shape = (10, 128)
        self.index_shape = (10, 64)
        self.x_type = np.complex64
        self.index_type = "int64"


class TestIndexSampleComplex128(TestIndexSampleOp):
    def config(self):
        """
        For complex128 x type
        """
        self.x_shape = (10, 128)
        self.x_type = np.complex128
        self.index_shape = (10, 64)
        self.index_type = "int64"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
Expand Down
33 changes: 31 additions & 2 deletions test/legacy_test/test_index_select_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def setUp(self):
low=0, high=self.x_shape[self.dim], size=self.index_size
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x_np = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
self.inputs = {'X': x_np, 'Index': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[: self.dim])
Expand All @@ -60,10 +65,16 @@ def init_dtype_type(self):
self.index_size = 100

def test_check_output(self):
    """Check forward output; skip the prim check for complex dtypes."""
    # NOTE(review): prim checking is disabled for complex inputs —
    # presumably the prim ops lack complex support; confirm upstream.
    is_complex = self.x_type in (np.complex64, np.complex128)
    self.check_output(check_prim=not is_complex)

def test_check_grad_normal(self):
    """Check the X gradient; skip the prim check for complex dtypes."""
    # NOTE(review): prim gradient checking is bypassed for complex
    # inputs — presumably unsupported by prim ops; confirm upstream.
    is_complex = self.x_type in (np.complex64, np.complex128)
    self.check_grad(['X'], 'Out', check_prim=not is_complex)


class TestIndexSelectOpCase2(TestIndexSelectOp):
Expand Down Expand Up @@ -146,6 +157,24 @@ def test_check_grad_normal(self):
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)


class TestIndexSelectComplex64(TestIndexSelectOp):
    # index_select test with a complex64 input, selecting along dim -2.
    def init_dtype_type(self):
        """Configure a complex64 x tensor indexed by int32 indices."""
        self.x_type = np.complex64
        self.index_type = np.int32
        self.dim = -2
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10


class TestIndexSelectComplex128(TestIndexSelectOp):
    # index_select test with a complex128 input, selecting along dim -2.
    def init_dtype_type(self):
        """Configure a complex128 x tensor indexed by int32 indices."""
        self.x_type = np.complex128
        self.index_type = np.int32
        self.dim = -2
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10


class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
Expand Down