Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【complex op】 add complex support for unbind,broadcast,broadcast_tensors and broadcast_tensor_grad #59122

4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ PD_REGISTER_KERNEL(broadcast,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/unbind_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
8 changes: 6 additions & 2 deletions paddle/phi/kernels/gpu/broadcast_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ PD_REGISTER_KERNEL(broadcast,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#else
PD_REGISTER_KERNEL(broadcast,
GPU,
Expand All @@ -79,5 +81,7 @@ PD_REGISTER_KERNEL(broadcast,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/unbind_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 1 addition & 1 deletion python/paddle/distributed/communication/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def broadcast(tensor, src, group=None, sync_op=True):
Args:
tensor (Tensor): The tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
src (int): The source rank in global view.
group (Group, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,7 @@ def broadcast_tensors(input, name=None):

Args:
input (list|tuple): ``input`` is a Tensor list or Tensor tuple which is with data type bool,
float16, float32, float64, int32, int64. All the Tensors in ``input`` must have same data type.
float16, float32, float64, int32, int64, complex64, complex128. All the Tensors in ``input`` must have same data type.
Currently we only support tensors with rank no greater than 5.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -1371,6 +1371,8 @@ def broadcast_tensors(input, name=None):
'int32',
'int64',
'uint16',
'complex64',
'complex128',
],
'broadcast_tensors',
)
Expand Down Expand Up @@ -3018,7 +3020,7 @@ def unbind(input, axis=0):
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.

Args:
input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64.
input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32, int64, complex64 or complex128.
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
Returns:
Expand Down Expand Up @@ -3075,6 +3077,8 @@ def unbind(input, axis=0):
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'unbind',
)
Expand Down
48 changes: 48 additions & 0 deletions test/cpp/phi/kernels/test_ternary_broadcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ TEST(Broadcast, add) {
dim_out,
times,
AddTernary_1<phi::dtype::bfloat16>());
TestCase<phi::dtype::complex<float>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_1<phi::dtype::complex<float>>());
TestCase<phi::dtype::complex<double>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_1<phi::dtype::complex<double>>());
} while (0);

do {
Expand All @@ -145,6 +161,22 @@ TEST(Broadcast, add) {
dim_out,
times,
AddTernary_2<phi::dtype::bfloat16>());
TestCase<phi::dtype::complex<float>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_2<phi::dtype::complex<float>>());
TestCase<phi::dtype::complex<double>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_2<phi::dtype::complex<double>>());
} while (0);

do {
Expand All @@ -168,6 +200,22 @@ TEST(Broadcast, add) {
dim_out,
times,
AddTernary_3<phi::dtype::bfloat16>());
TestCase<phi::dtype::complex<float>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_3<phi::dtype::complex<float>>());
TestCase<phi::dtype::complex<double>>(
*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_3<phi::dtype::complex<double>>());
} while (0);
#endif
}
2 changes: 2 additions & 0 deletions test/legacy_test/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ def infer_dtype(numpy_dict, dtype_set):
input_dtype_set = set()
infer_dtype(inputs, input_dtype_set)
dtype_list = [
np.dtype(np.complex128),
np.dtype(np.complex64),
np.dtype(np.float64),
np.dtype(np.float32),
np.dtype(np.float16),
Expand Down
20 changes: 19 additions & 1 deletion test/legacy_test/test_broadcast_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
class TestBroadcastOpCpu(OpTest):
def setUp(self):
self.op_type = "broadcast"
input = np.random.random((100, 2)).astype("float32")
self.init_dtype()
input = np.random.random((100, 2)).astype(self.dtype)
if self.dtype == 'complex64' or self.dtype == 'complex128':
input = (
np.random.random((100, 2)) + 1j * np.random.random((100, 2))
).astype(self.dtype)
np_out = input[:]
self.inputs = {"X": input}
self.attrs = {"sync_mode": False, "root": 0}
Expand All @@ -35,6 +40,19 @@ def test_check_output_cpu(self):
except:
print("do not support cpu test, skip")

def init_dtype(self):
self.dtype = 'float32'


class TestBroadcastOpCpu_complex64(TestBroadcastOpCpu):
def init_dtype(self):
self.dtype = 'complex64'


class TestBroadcastOpCpu_complex128(TestBroadcastOpCpu):
def init_dtype(self):
self.dtype = 'complex128'


if __name__ == "__main__":
unittest.main()
Loading