diff --git a/paddle/phi/kernels/cpu/broadcast_kernel.cc b/paddle/phi/kernels/cpu/broadcast_kernel.cc
index baa12d1815edc7..0deb8d8bbc5627 100644
--- a/paddle/phi/kernels/cpu/broadcast_kernel.cc
+++ b/paddle/phi/kernels/cpu/broadcast_kernel.cc
@@ -62,4 +62,6 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc b/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
index 0656f681367ffc..8f73c5c5f5f6e8 100644
--- a/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
@@ -199,4 +199,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc b/paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc
index 3ad26164d7d8da..7d0e08655fc275 100644
--- a/paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc
+++ b/paddle/phi/kernels/cpu/broadcast_tensors_kernel.cc
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/unbind_kernel.cc b/paddle/phi/kernels/cpu/unbind_kernel.cc
index e8d0c01352c97c..255f73af1aca75 100644
--- a/paddle/phi/kernels/cpu/unbind_kernel.cc
+++ b/paddle/phi/kernels/cpu/unbind_kernel.cc
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/broadcast_kernel.cu b/paddle/phi/kernels/gpu/broadcast_kernel.cu
index 4b46e218c328e1..e4986f752b1aec 100644
--- a/paddle/phi/kernels/gpu/broadcast_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_kernel.cu
@@ -66,7 +66,9 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 #else
 PD_REGISTER_KERNEL(broadcast,
                    GPU,
@@ -79,5 +81,7 @@ PD_REGISTER_KERNEL(broadcast,
                    int8_t,
                    uint8_t,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 #endif
diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
index 40ea1f195069e9..1c56b93c7c1dce 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -111,4 +111,6 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
index 3d16797cb66c09..aae7d53aeb43ab 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
@@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(broadcast_tensors,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/unbind_kernel.cu b/paddle/phi/kernels/gpu/unbind_kernel.cu
index 37272cebdf1188..178191f048e30d 100644
--- a/paddle/phi/kernels/gpu/unbind_kernel.cu
+++ b/paddle/phi/kernels/gpu/unbind_kernel.cu
@@ -27,4 +27,6 @@ PD_REGISTER_KERNEL(unbind,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/python/paddle/distributed/communication/broadcast.py b/python/paddle/distributed/communication/broadcast.py
index 208158cd209182..9c87e0345db5f4 100644
--- a/python/paddle/distributed/communication/broadcast.py
+++ b/python/paddle/distributed/communication/broadcast.py
@@ -37,7 +37,7 @@ def broadcast(tensor, src, group=None, sync_op=True):
 
     Args:
         tensor (Tensor): The tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
         src (int): The source rank in global view.
         group (Group, optional): The group instance return by new_group or None for global default group.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 5bec599390fdb9..a0ed76a3b970da 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1349,7 +1349,7 @@ def broadcast_tensors(input, name=None):
 
     Args:
         input (list|tuple): ``input`` is a Tensor list or Tensor tuple which is with data type bool,
-            float16, float32, float64, int32, int64. All the Tensors in ``input`` must have same data type.
+            float16, float32, float64, int32, int64, complex64, complex128. All the Tensors in ``input`` must have same data type.
             Currently we only support tensors with rank no greater than 5.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1390,6 +1390,8 @@ def broadcast_tensors(input, name=None):
             'int32',
             'int64',
             'uint16',
+            'complex64',
+            'complex128',
         ],
         'broadcast_tensors',
     )
@@ -3037,7 +3039,7 @@ def unbind(input, axis=0):
     Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
 
     Args:
-        input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64.
+        input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32, int64, complex64 or complex128.
         axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
             If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
     Returns:
@@ -3094,6 +3096,8 @@ def unbind(input, axis=0):
             'float64',
             'int32',
             'int64',
+            'complex64',
+            'complex128',
         ],
         'unbind',
     )
diff --git a/test/cpp/phi/kernels/test_ternary_broadcast.cu b/test/cpp/phi/kernels/test_ternary_broadcast.cu
index 09598e637909aa..959b79725f07ae 100644
--- a/test/cpp/phi/kernels/test_ternary_broadcast.cu
+++ b/test/cpp/phi/kernels/test_ternary_broadcast.cu
@@ -122,6 +122,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_1());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_1<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_1<phi::dtype::complex<double>>());
   } while (0);
 
   do {
@@ -145,6 +161,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_2());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_2<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_2<phi::dtype::complex<double>>());
   } while (0);
 
   do {
@@ -168,6 +200,22 @@ TEST(Broadcast, add) {
         dim_out,
         times,
         AddTernary_3());
+    TestCase<phi::dtype::complex<float>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_3<phi::dtype::complex<float>>());
+    TestCase<phi::dtype::complex<double>>(
+        *dev_ctx,
+        dim1,
+        dim2,
+        dim3,
+        dim_out,
+        times,
+        AddTernary_3<phi::dtype::complex<double>>());
   } while (0);
 #endif
 }
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index 79a289f65890ed..759d76dabef9a9 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -671,6 +671,8 @@ def infer_dtype(numpy_dict, dtype_set):
         input_dtype_set = set()
         infer_dtype(inputs, input_dtype_set)
         dtype_list = [
+            np.dtype(np.complex128),
+            np.dtype(np.complex64),
             np.dtype(np.float64),
             np.dtype(np.float32),
             np.dtype(np.float16),
diff --git a/test/legacy_test/test_broadcast_error.py b/test/legacy_test/test_broadcast_error.py
index d42cc6d9b88401..e5defec467d9f7 100644
--- a/test/legacy_test/test_broadcast_error.py
+++ b/test/legacy_test/test_broadcast_error.py
@@ -23,7 +23,12 @@ class TestBroadcastOpCpu(OpTest):
     def setUp(self):
         self.op_type = "broadcast"
-        input = np.random.random((100, 2)).astype("float32")
+        self.init_dtype()
+        input = np.random.random((100, 2)).astype(self.dtype)
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            input = (
+                np.random.random((100, 2)) + 1j * np.random.random((100, 2))
+            ).astype(self.dtype)
         np_out = input[:]
         self.inputs = {"X": input}
         self.attrs = {"sync_mode": False, "root": 0}
@@ -35,6 +40,19 @@ def test_check_output_cpu(self):
         except:
             print("do not support cpu test, skip")
 
+    def init_dtype(self):
+        self.dtype = 'float32'
+
+
+class TestBroadcastOpCpu_complex64(TestBroadcastOpCpu):
+    def init_dtype(self):
+        self.dtype = 'complex64'
+
+
+class TestBroadcastOpCpu_complex128(TestBroadcastOpCpu):
+    def init_dtype(self):
+        self.dtype = 'complex128'
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/legacy_test/test_broadcast_tensors_op.py b/test/legacy_test/test_broadcast_tensors_op.py
index 9f5b7b76caacbe..d8de6e1bba8a8b 100644
--- a/test/legacy_test/test_broadcast_tensors_op.py
+++ b/test/legacy_test/test_broadcast_tensors_op.py
@@ -47,7 +47,10 @@ def find_output_shape(input_list):
 def make_inputs_outputs(input_shapes, dtype, is_bfloat16=False):
     """Automatically generate formatted inputs and outputs from input_shapes"""
     input_list = [
-        np.random.random(shape).astype(dtype) for shape in input_shapes
+        (np.random.random(shape) + 1j * np.random.random(shape)).astype(dtype)
+        if dtype == 'complex64' or dtype == 'complex128'
+        else np.random.random(shape).astype(dtype)
+        for shape in input_shapes
     ]
     output_shape = find_output_shape(input_list)
     output_list = [
@@ -98,8 +101,8 @@ class TestCPUBroadcastTensorsOp(OpTest):
     def set_place(self):
         self.place = core.CPUPlace()
 
-    def set_dtypes(self):
-        self.dtypes = ['float64']
+    def set_dtype(self):
+        self.dtype = 'float64'
 
     def setUp(self):
         self.op_type = "broadcast_tensors"
@@ -112,26 +115,24 @@ def setUp(self):
             gen_empty_tensors_test,
         ]
         self.set_place()
-        self.set_dtypes()
+        self.set_dtype()
         self.python_api = paddle.broadcast_tensors
 
     def run_dual_test(self, test_func, args):
-        for dtype in self.dtypes:
-            for gen_func in self.test_gen_func_list:
-                self.inputs, self.outputs = gen_func(dtype)
-                if len(self.outputs["Out"]) < 3:
-                    self.python_out_sig = [
-                        f"out{i}" for i in range(len(self.outputs["Out"]))
-                    ]
-                test_func(**args)
+        for gen_func in self.test_gen_func_list:
+            self.inputs, self.outputs = gen_func(self.dtype)
+            if len(self.outputs["Out"]) < 3:
+                self.python_out_sig = [
+                    f"out{i}" for i in range(len(self.outputs["Out"]))
+                ]
+            test_func(**args)
 
     def run_triple_in_test(self, test_func, args):
-        for dtype in self.dtypes:
-            self.inputs, self.outputs = self.test_gen_func_list[2](dtype)
-            self.python_out_sig = [
-                f"out{i}" for i in range(len(self.outputs["Out"]))
-            ]
-            test_func(**args)
+        self.inputs, self.outputs = self.test_gen_func_list[2](self.dtype)
+        self.python_out_sig = [
+            f"out{i}" for i in range(len(self.outputs["Out"]))
+        ]
+        test_func(**args)
 
     def test_check_output(self):
         self.run_dual_test(
@@ -160,6 +161,16 @@ def test_check_grad_normal(self):
         )
 
+
+class TestCPUBroadcastTensorsOp_complex64(TestCPUBroadcastTensorsOp):
+    def set_dtype(self):
+        self.dtype = 'complex64'
+
+
+class TestCPUBroadcastTensorsOp_complex128(TestCPUBroadcastTensorsOp):
+    def set_dtype(self):
+        self.dtype = 'complex128'
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
@@ -239,28 +250,44 @@ class TestBroadcastTensorsAPI(unittest.TestCase):
+    def setUp(self):
+        self.dtype = 'float32'
+
     def test_api(self):
         @test_with_pir_api
         def test_static():
-            inputs = [
-                paddle.static.data(
-                    shape=[-1, 4, 1, 4, 1], dtype='float32', name="x0"
-                ),
-                paddle.static.data(
-                    shape=[-1, 1, 4, 1, 4], dtype='float32', name="x1"
-                ),
-            ]
-            paddle.broadcast_tensors(inputs)
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            with paddle.static.program_guard(prog, startup_prog):
+                inputs = [
+                    paddle.static.data(
+                        shape=[-1, 4, 1, 4, 1], dtype=self.dtype, name="x0"
+                    ),
+                    paddle.static.data(
+                        shape=[-1, 1, 4, 1, 4], dtype=self.dtype, name="x1"
+                    ),
+                ]
+                paddle.broadcast_tensors(inputs)
 
         def test_dynamic():
             paddle.disable_static()
             try:
                 inputs = [
                     paddle.to_tensor(
-                        np.random.random([4, 1, 4, 1]).astype("float32")
+                        np.random.random([4, 1, 4, 1]).astype(self.dtype)
+                        if self.dtype == 'float32'
+                        else (
+                            np.random.random([4, 1, 4, 1])
+                            + 1j * np.random.random([4, 1, 4, 1])
+                        ).astype(self.dtype)
                     ),
                     paddle.to_tensor(
-                        np.random.random([1, 4, 1, 4]).astype("float32")
+                        np.random.random([1, 4, 1, 4]).astype(self.dtype)
+                        if self.dtype == 'float32'
+                        else (
+                            np.random.random([1, 4, 1, 4])
+                            + 1j * np.random.random([1, 4, 1, 4])
+                        ).astype(self.dtype)
                     ),
                 ]
                 paddle.broadcast_tensors(inputs)
@@ -271,6 +298,16 @@ def test_dynamic():
         test_dynamic()
 
+
+class TestBroadcastTensorsAPI_complex64(TestBroadcastTensorsAPI):
+    def setUp(self):
+        self.dtype = 'complex64'
+
+
+class TestBroadcastTensorsAPI_complex128(TestBroadcastTensorsAPI):
+    def setUp(self):
+        self.dtype = 'complex128'
+
+
 class TestRaiseBroadcastTensorsError(unittest.TestCase):
     def test_errors(self):
         def test_type():
@@ -306,9 +343,21 @@ def test_bcast_semantics():
             ]
             paddle.broadcast_tensors(inputs)
 
+        def test_bcast_semantics_complex64():
+            inputs = [
+                paddle.static.data(
+                    shape=[-1, 1, 3, 1, 1], dtype='complex64', name="x11"
+                ),
+                paddle.static.data(
+                    shape=[-1, 1, 8, 1, 1], dtype='complex64', name="x12"
+                ),
+            ]
+            paddle.broadcast_tensors(inputs)
+
         self.assertRaises(TypeError, test_type)
         self.assertRaises(TypeError, test_dtype)
         self.assertRaises(TypeError, test_bcast_semantics)
+        self.assertRaises(TypeError, test_bcast_semantics_complex64)
 
 
 class TestRaiseBroadcastTensorsErrorDyGraph(unittest.TestCase):
diff --git a/test/legacy_test/test_unbind_op.py b/test/legacy_test/test_unbind_op.py
index c01858c06ad5e8..833ffd824bb13c 100644
--- a/test/legacy_test/test_unbind_op.py
+++ b/test/legacy_test/test_unbind_op.py
@@ -24,28 +24,38 @@ class TestUnbind(unittest.TestCase):
+    def setUp(self):
+        self.init_dtype()
+        self.input_1 = np.random.random([2, 3]).astype(self.dtype)
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            self.input_1 = (
+                np.random.random([2, 3]) + 1j * np.random.random([2, 3])
+            ).astype(self.dtype)
+
+    def init_dtype(self):
+        self.dtype = 'float32'
+
     @test_with_pir_api
     def test_unbind(self):
         paddle.enable_static()
-
+        self.init_dtype()
         main_program = static.Program()
         startup_program = static.Program()
         with static.program_guard(
             main_program=main_program, startup_program=startup_program
         ):
-            x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1')
+            x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
             [out_0, out_1] = tensor.unbind(input=x_1, axis=0)
-            input_1 = np.random.random([2, 3]).astype("float32")
             axis = paddle.static.data(shape=[], dtype='int32', name='axis')
             exe = base.Executor(place=base.CPUPlace())
 
             [res_1, res_2] = exe.run(
-                feed={"x_1": input_1, "axis": 0},
+                feed={"x_1": self.input_1, "axis": 0},
                 fetch_list=[out_0, out_1],
             )
-            np.testing.assert_array_equal(res_1, input_1[0, 0:100])
-            np.testing.assert_array_equal(res_2, input_1[1, 0:100])
+            np.testing.assert_array_equal(res_1, self.input_1[0, 0:100])
+            np.testing.assert_array_equal(res_2, self.input_1[1, 0:100])
 
     @test_with_pir_api
     def test_unbind_static_fp16_gpu(self):
@@ -73,38 +83,74 @@ def test_unbind_static_fp16_gpu(self):
 
     def test_unbind_dygraph(self):
         with base.dygraph.guard():
-            np_x = np.random.random([2, 3]).astype("float32")
-            x = paddle.to_tensor(np_x)
+            x = paddle.to_tensor(self.input_1)
             x.stop_gradient = False
             [res_1, res_2] = paddle.unbind(x, 0)
-            np.testing.assert_array_equal(res_1, np_x[0, 0:100])
-            np.testing.assert_array_equal(res_2, np_x[1, 0:100])
+            np.testing.assert_array_equal(res_1, self.input_1[0, 0:100])
+            np.testing.assert_array_equal(res_2, self.input_1[1, 0:100])
 
             out = paddle.add_n([res_1, res_2])
-            np_grad = np.ones(x.shape, np.float32)
+            np_grad = np.ones(x.shape, self.dtype)
             out.backward()
             np.testing.assert_array_equal(x.grad.numpy(False), np_grad)
 
 
+class TestUnbind_complex64(TestUnbind):
+    def init_dtype(self):
+        self.dtype = 'complex64'
+
+    def test_unbind_static_fp16_gpu(self):
+        pass
+
+
+class TestUnbind_complex128(TestUnbind):
+    def init_dtype(self):
+        self.dtype = 'complex128'
+
+    def test_unbind_static_fp16_gpu(self):
+        pass
+
+
 class TestLayersUnbind(unittest.TestCase):
+    def setUp(self):
+        self.init_dtype()
+        self.input_1 = np.random.random([2, 3]).astype(self.dtype)
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            self.input_1 = (
+                np.random.random([2, 3]) + 1j * np.random.random([2, 3])
+            ).astype(self.dtype)
+
+    def init_dtype(self):
+        self.dtype = 'float32'
+
     @test_with_pir_api
     def test_layers_unbind(self):
         paddle.enable_static()
+        prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(prog, startup_prog):
+            x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
+            [out_0, out_1] = paddle.unbind(input=x_1, axis=0)
+            axis = paddle.static.data(shape=[], dtype='int32', name='axis')
+            exe = base.Executor(place=base.CPUPlace())
+            [res_1, res_2] = exe.run(
+                feed={"x_1": self.input_1, "axis": 0},
+                fetch_list=[out_0, out_1],
+            )
 
-        x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1')
-        [out_0, out_1] = paddle.unbind(input=x_1, axis=0)
-        input_1 = np.random.random([2, 3]).astype("float32")
-        axis = paddle.static.data(shape=[], dtype='int32', name='axis')
-        exe = base.Executor(place=base.CPUPlace())
+        np.testing.assert_array_equal(res_1, self.input_1[0, 0:100])
+        np.testing.assert_array_equal(res_2, self.input_1[1, 0:100])
 
-        [res_1, res_2] = exe.run(
-            feed={"x_1": input_1, "axis": 0},
-            fetch_list=[out_0, out_1],
-        )
-        np.testing.assert_array_equal(res_1, input_1[0, 0:100])
-        np.testing.assert_array_equal(res_2, input_1[1, 0:100])
 
+class TestLayersUnbind_complex64(TestLayersUnbind):
+    def init_dtype(self):
+        self.dtype = 'complex64'
+
+
+class TestLayersUnbind_complex128(TestLayersUnbind):
+    def init_dtype(self):
+        self.dtype = 'complex128'
 
 
 class TestUnbindOp(OpTest):
@@ -126,6 +172,11 @@ def setUp(self):
         self.num = 3
         self.initParameters()
         x = np.arange(12).reshape(3, 2, 2).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.arange(12).reshape(3, 2, 2)
+                + 1j * np.arange(12).reshape(3, 2, 2)
+            ).astype(self.dtype)
         self.out = np.split(x, self.num, self.axis)
         self.outReshape()
         self.inputs = {'X': x}
@@ -208,6 +259,46 @@ def outReshape(self):
         self.out[1] = self.out[1].reshape((3, 2))
 
 
+class TestUnbindOp1_Complex64(TestUnbindOp1):
+    def get_dtype(self):
+        return np.complex64
+
+
+class TestUnbindOp2_Complex64(TestUnbindOp2):
+    def get_dtype(self):
+        return np.complex64
+
+
+class TestUnbindOp3_Complex64(TestUnbindOp3):
+    def get_dtype(self):
+        return np.complex64
+
+
+class TestUnbindOp4_Complex64(TestUnbindOp4):
+    def get_dtype(self):
+        return np.complex64
+
+
+class TestUnbindOp1_Complex128(TestUnbindOp1):
+    def get_dtype(self):
+        return np.complex128
+
+
+class TestUnbindOp2_Complex128(TestUnbindOp2):
+    def get_dtype(self):
+        return np.complex128
+
+
+class TestUnbindOp3_Complex128(TestUnbindOp3):
+    def get_dtype(self):
+        return np.complex128
+
+
+class TestUnbindOp4_Complex128(TestUnbindOp4):
+    def get_dtype(self):
+        return np.complex128
+
+
 class TestUnbindFP16Op(OpTest):
     def setUp(self):
         paddle.disable_static()
@@ -278,10 +369,15 @@ def test_check_grad(self):
 
 
 class TestUnbindAxisError(unittest.TestCase):
+    def setUp(self):
+        self.dtype = 'float32'
+
     @test_with_pir_api
     def test_errors(self):
+        paddle.enable_static()
+
         with program_guard(Program(), Program()):
-            x = paddle.static.data(shape=[2, 3], dtype='float32', name='x')
+            x = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x')
 
             def test_table_Variable():
                 tensor.unbind(input=x, axis=2.0)
@@ -294,6 +390,16 @@ def test_invalid_axis():
 
         self.assertRaises(ValueError, test_invalid_axis)
 
+
+class TestUnbindAxisError_complex64(TestUnbindAxisError):
+    def setUp(self):
+        self.dtype = 'complex64'
+
+
+class TestUnbindAxisError_complex128(TestUnbindAxisError):
+    def setUp(self):
+        self.dtype = 'complex128'
+
+
 class TestUnbindBool(unittest.TestCase):
     def test_bool(self):
         x = paddle.to_tensor([[True, True], [False, False]])
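
Illustrative usage (not part of the patch): the snippet below is a minimal sketch of how the complex64/complex128 support registered above can be exercised through the public Python APIs in dynamic mode; the helper name and tensor shapes are assumptions chosen for the example.

import numpy as np
import paddle

def _rand_complex(shape, dtype='complex64'):
    # Build a complex-valued ndarray the same way the updated tests do.
    return (np.random.random(shape) + 1j * np.random.random(shape)).astype(dtype)

x = paddle.to_tensor(_rand_complex([2, 1, 3]))
y = paddle.to_tensor(_rand_complex([1, 4, 3]))

# broadcast_tensors now accepts complex inputs; both outputs are broadcast
# to shape [2, 4, 3] and keep the complex dtype.
out_x, out_y = paddle.broadcast_tensors([x, y])

# unbind along axis 0 splits the tensor while preserving the complex dtype.
parts = paddle.unbind(out_x, axis=0)
print(out_x.dtype, [tuple(p.shape) for p in parts])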