【complex op】 add complex support for unbind, broadcast, broadcast_tensors and broadcast_tensors_grad (#59122)

* add complex support for unbind, broadcast, broadcast_tensors and broadcast_tensors_grad

* add test_dtype

* add complex support for unbind, broadcast, broadcast_tensors and broadcast_tensors_grad

* fix code style

* Resolve conflicts and generate complex data.
zyt1024 authored Nov 29, 2023
1 parent 87bf502 commit 5e89708
Showing 15 changed files with 310 additions and 65 deletions.
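What the change enables at the Python level, as a minimal forward-only sketch (assumes a Paddle build that includes this commit; shapes and data are illustrative):

```python
import numpy as np
import paddle

# Build complex inputs; real + 1j*real keeps the imaginary part non-zero.
x = paddle.to_tensor(
    np.random.random((2, 3)) + 1j * np.random.random((2, 3)), dtype='complex64'
)
y = paddle.to_tensor(
    np.random.random((1, 3)) + 1j * np.random.random((1, 3)), dtype='complex64'
)

# broadcast_tensors: y is broadcast from [1, 3] to [2, 3].
bx, by = paddle.broadcast_tensors(input=[x, y])

# unbind: split bx along axis 0 into two complex64 tensors of shape [3].
row0, row1 = paddle.unbind(bx, axis=0)
print(bx.shape, row0.dtype)  # [2, 3] paddle.complex64
```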
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_kernel.cc
@@ -62,4 +62,6 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
@@ -199,4 +199,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/unbind_kernel.cc
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
8 changes: 6 additions & 2 deletions paddle/phi/kernels/gpu/broadcast_kernel.cu
@@ -66,7 +66,9 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 #else
 PD_REGISTER_KERNEL(broadcast,
                    GPU,
@@ -79,5 +81,7 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 #endif
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -111,4 +111,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
@@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/unbind_kernel.cu
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
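The GPU registrations mirror the CPU ones, so the same complex call should run on either device. A small smoke-test sketch under that assumption:

```python
import numpy as np
import paddle

devices = ['cpu']
if paddle.device.is_compiled_with_cuda():
    devices.append('gpu:0')

for dev in devices:
    paddle.set_device(dev)
    x = paddle.to_tensor(
        np.random.random((2, 3)) + 1j * np.random.random((2, 3)),
        dtype='complex128',
    )
    # unbind dispatches to the CPU or GPU kernel registered above.
    parts = paddle.unbind(x, axis=0)
    print(dev, len(parts), parts[0].dtype)
```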
2 changes: 1 addition & 1 deletion python/paddle/distributed/communication/broadcast.py
@@ -37,7 +37,7 @@ def broadcast(tensor, src, group=None, sync_op=True):
     Args:
         tensor (Tensor): The tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
         src (int): The source rank in global view.
         group (Group, optional): The group instance return by new_group or None for global default group.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
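For context, a hedged two-process sketch of the documented behavior with the new complex dtypes (assumes the script is started with `python -m paddle.distributed.launch --nproc_per_node=2`, and that `paddle.zeros` accepts complex dtypes in this build):

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([[1 + 2j, 3 + 4j]], dtype='complex64')
else:
    data = paddle.zeros([1, 2], dtype='complex64')

dist.broadcast(data, src=0)  # every rank now holds rank 0's values
print(dist.get_rank(), data.numpy())
```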
8 changes: 6 additions & 2 deletions python/paddle/tensor/manipulation.py
@@ -1349,7 +1349,7 @@ def broadcast_tensors(input, name=None):
     Args:
         input (list|tuple): ``input`` is a Tensor list or Tensor tuple which is with data type bool,
-            float16, float32, float64, int32, int64. All the Tensors in ``input`` must have same data type.
+            float16, float32, float64, int32, int64, complex64, complex128. All the Tensors in ``input`` must have same data type.
             Currently we only support tensors with rank no greater than 5.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1390,6 +1390,8 @@ def broadcast_tensors(input, name=None):
             'int32',
             'int64',
             'uint16',
+            'complex64',
+            'complex128',
         ],
         'broadcast_tensors',
     )
@@ -3037,7 +3039,7 @@ def unbind(input, axis=0):
     Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
     Args:
-        input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64.
+        input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32, int64, complex64 or complex128.
         axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
             If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
     Returns:
@@ -3094,6 +3096,8 @@ def unbind(input, axis=0):
             'float64',
             'int32',
             'int64',
+            'complex64',
+            'complex128',
         ],
         'unbind',
     )
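Two details from the docstrings above are easy to miss, sketched below: every input to `broadcast_tensors` must share one dtype (complex cannot mix with float), and a negative `unbind` axis resolves to `rank(input) + axis`. The broad `except` is deliberate, since the exact error type raised on a dtype mismatch is not pinned down here:

```python
import numpy as np
import paddle

c = paddle.to_tensor(np.ones((2, 3)) + 1j * np.ones((2, 3)), dtype='complex64')
f = paddle.ones([2, 3], dtype='float32')

try:
    paddle.broadcast_tensors(input=[c, f])  # mixed dtypes are rejected
except Exception as err:
    print('rejected:', type(err).__name__)

x = paddle.to_tensor(np.random.random((2, 3, 4)) + 0j, dtype='complex128')
parts = paddle.unbind(x, axis=-1)  # axis -1 -> axis 2: four [2, 3] tensors
print(len(parts), parts[0].shape)
```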
48 changes: 48 additions & 0 deletions test/cpp/phi/kernels/test_ternary_broadcast.cu
@@ -122,6 +122,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_1<phi::dtype::bfloat16>());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_1<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_1<phi::dtype::complex<double>>());
   } while (0);
 
   do {
@@ -145,6 +161,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_2<phi::dtype::bfloat16>());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_2<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_2<phi::dtype::complex<double>>());
   } while (0);
 
   do {
@@ -168,6 +200,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_3<phi::dtype::bfloat16>());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_3<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_3<phi::dtype::complex<double>>());
   } while (0);
 #endif
 }
2 changes: 2 additions & 0 deletions test/legacy_test/op_test.py
@@ -671,6 +671,8 @@ def infer_dtype(numpy_dict, dtype_set):
         input_dtype_set = set()
         infer_dtype(inputs, input_dtype_set)
         dtype_list = [
+            np.dtype(np.complex128),
+            np.dtype(np.complex64),
             np.dtype(np.float64),
             np.dtype(np.float32),
             np.dtype(np.float16),
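The ordering of `dtype_list` matters: the harness selects the first listed dtype that occurs among the test inputs, so placing `complex128` and `complex64` at the head makes complex win over any real dtype. A standalone sketch of that selection rule (the loop is an assumption about how `op_test` consumes the list, inferred from the priority ordering):

```python
import numpy as np

# Priority list as in the diff above: first match among input dtypes wins.
dtype_list = [
    np.dtype(np.complex128),
    np.dtype(np.complex64),
    np.dtype(np.float64),
    np.dtype(np.float32),
    np.dtype(np.float16),
]

input_dtype_set = {np.dtype(np.float32), np.dtype(np.complex64)}
chosen = next((d for d in dtype_list if d in input_dtype_set), None)
print(chosen)  # complex64: the complex input outranks float32
```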
20 changes: 19 additions & 1 deletion test/legacy_test/test_broadcast_error.py
@@ -23,7 +23,12 @@
 class TestBroadcastOpCpu(OpTest):
     def setUp(self):
         self.op_type = "broadcast"
-        input = np.random.random((100, 2)).astype("float32")
+        self.init_dtype()
+        input = np.random.random((100, 2)).astype(self.dtype)
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            input = (
+                np.random.random((100, 2)) + 1j * np.random.random((100, 2))
+            ).astype(self.dtype)
         np_out = input[:]
         self.inputs = {"X": input}
         self.attrs = {"sync_mode": False, "root": 0}
@@ -35,6 +40,19 @@ def test_check_output_cpu(self):
         except:
             print("do not support cpu test, skip")
 
+    def init_dtype(self):
+        self.dtype = 'float32'
+
+
+class TestBroadcastOpCpu_complex64(TestBroadcastOpCpu):
+    def init_dtype(self):
+        self.dtype = 'complex64'
+
+
+class TestBroadcastOpCpu_complex128(TestBroadcastOpCpu):
+    def init_dtype(self):
+        self.dtype = 'complex128'
+
 
 if __name__ == "__main__":
     unittest.main()
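The `real + 1j * real` construction in `setUp` is deliberate: casting a real array straight to a complex dtype leaves the imaginary part all zeros, so the imaginary path of the kernel would never be exercised. A quick NumPy illustration:

```python
import numpy as np

a = np.random.random((100, 2)).astype('complex64')
b = (np.random.random((100, 2)) + 1j * np.random.random((100, 2))).astype('complex64')

print(a.imag.any())  # False: no imaginary data to test against
print(b.imag.any())  # True: imaginary part carries random data
```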