diff --git a/ivy/array/linear_algebra.py b/ivy/array/linear_algebra.py
index c65d4183bebd4..8ca5b248a1c7f 100644
--- a/ivy/array/linear_algebra.py
+++ b/ivy/array/linear_algebra.py
@@ -714,9 +714,10 @@ def solve(
         x2: Union[ivy.Array, ivy.NativeArray],
         /,
         *,
+        adjoint: bool = False,
         out: Optional[ivy.Array] = None,
     ) -> ivy.Array:
-        return ivy.solve(self._data, x2, out=out)
+        return ivy.solve(self._data, x2, adjoint=adjoint, out=out)
 
     def svd(
         self: ivy.Array,
diff --git a/ivy/container/linear_algebra.py b/ivy/container/linear_algebra.py
index 3abb14ec7c97b..1e2637f552eeb 100644
--- a/ivy/container/linear_algebra.py
+++ b/ivy/container/linear_algebra.py
@@ -2099,6 +2099,7 @@ def static_solve(
         x2: Union[ivy.Array, ivy.NativeArray, ivy.Container],
         /,
         *,
+        adjoint: bool = False,
         key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
         to_apply: bool = True,
         prune_unapplied: bool = False,
@@ -2109,6 +2110,7 @@ def static_solve(
             "solve",
             x1,
             x2,
+            adjoint=adjoint,
             key_chains=key_chains,
             to_apply=to_apply,
             prune_unapplied=prune_unapplied,
@@ -2121,6 +2123,7 @@ def solve(
         x2: Union[ivy.Container, ivy.Array, ivy.NativeArray],
         /,
         *,
+        adjoint: bool = False,
         key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
         to_apply: bool = True,
         prune_unapplied: bool = False,
@@ -2130,6 +2133,7 @@ def solve(
         return self.static_solve(
             self,
             x2,
+            adjoint=adjoint,
             key_chains=key_chains,
             to_apply=to_apply,
             prune_unapplied=prune_unapplied,
diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py
index 0fd93488e8b60..b7d91875ddc8d 100644
--- a/ivy/functional/backends/jax/linear_algebra.py
+++ b/ivy/functional/backends/jax/linear_algebra.py
@@ -172,13 +172,13 @@ def matmul(
     adjoint_b: bool = False,
     out: Optional[JaxArray] = None,
 ) -> JaxArray:
-    if transpose_a is True:
+    if transpose_a:
         x1 = jnp.transpose(x1)
-    if transpose_b is True:
+    if transpose_b:
         x2 = jnp.transpose(x2)
-    if adjoint_a is True:
+    if adjoint_a:
         x1 = jnp.transpose(jnp.conjugate(x1))
-    if adjoint_b is True:
+    if adjoint_b:
         x2 = jnp.transpose(jnp.conjugate(x2))
     return jnp.matmul(x1, x2)
 
@@ -346,7 +346,16 @@ def slogdet(
     {"0.3.14 and below": ("bfloat16", "float16", "complex")},
     backend_version,
 )
-def solve(x1: JaxArray, x2: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+def solve(
+    x1: JaxArray,
+    x2: JaxArray,
+    /,
+    *,
+    adjoint: bool = False,
+    out: Optional[JaxArray] = None
+) -> JaxArray:
+    if adjoint:
+        x1 = jnp.transpose(jnp.conjugate(x1))
     expanded_last = False
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     if len(x2.shape) <= 1:
diff --git a/ivy/functional/backends/numpy/linear_algebra.py b/ivy/functional/backends/numpy/linear_algebra.py
index b62147766dd56..7c4d25b4ac423 100644
--- a/ivy/functional/backends/numpy/linear_algebra.py
+++ b/ivy/functional/backends/numpy/linear_algebra.py
@@ -125,13 +125,13 @@ def matmul(
     adjoint_b: bool = False,
     out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
-    if transpose_a is True:
+    if transpose_a:
         x1 = np.transpose(x1)
-    if transpose_b is True:
+    if transpose_b:
         x2 = np.transpose(x2)
-    if adjoint_a is True:
+    if adjoint_a:
         x1 = np.transpose(np.conjugate(x1))
-    if adjoint_b is True:
+    if adjoint_b:
         x2 = np.transpose(np.conjugate(x2))
     ret = np.matmul(x1, x2, out=out)
     if len(x1.shape) == len(x2.shape) == 1:
@@ -302,8 +302,15 @@ def slogdet(
 
 @with_unsupported_dtypes({"1.23.0 and below": ("float16",)}, backend_version)
 def solve(
-    x1: np.ndarray, x2: np.ndarray, /, *, out: Optional[np.ndarray] = None
+    x1: np.ndarray,
+    x2: np.ndarray,
+    /,
+    *,
+    adjoint: bool = False,
+    out: Optional[np.ndarray] = None
 ) -> np.ndarray:
+    if adjoint:
+        x1 = np.transpose(np.conjugate(x1))
     expanded_last = False
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     if len(x2.shape) <= 1:
diff --git a/ivy/functional/backends/tensorflow/linear_algebra.py b/ivy/functional/backends/tensorflow/linear_algebra.py
index 99d7b270ea6d6..48fef7982dff2 100644
--- a/ivy/functional/backends/tensorflow/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/linear_algebra.py
@@ -221,14 +221,14 @@ def matmul(
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     dtype_from = tf.as_dtype(x1.dtype)
 
-    if transpose_a is True:
+    if transpose_a:
         x1 = tf.transpose(x1)
-    if transpose_b is True:
+    if transpose_b:
         x2 = tf.transpose(x2)
-    if adjoint_a is True:
+    if adjoint_a:
         x1 = tf.linalg.adjoint(x1)
-    if adjoint_b is True:
+    if adjoint_b:
         x2 = tf.linalg.adjoint(x2)
 
     if dtype_from.is_unsigned or dtype_from == tf.int8 or dtype_from == tf.int16:
@@ -554,8 +554,11 @@ def solve(
     x2: Union[tf.Tensor, tf.Variable],
     /,
     *,
+    adjoint: bool = False,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
+    if adjoint:
+        x1 = tf.linalg.adjoint(x1)
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     expanded_last = False
     if len(x2.shape) <= 1:
diff --git a/ivy/functional/backends/torch/linear_algebra.py b/ivy/functional/backends/torch/linear_algebra.py
index 6a6fa13982d37..7867863e38473 100644
--- a/ivy/functional/backends/torch/linear_algebra.py
+++ b/ivy/functional/backends/torch/linear_algebra.py
@@ -169,13 +169,13 @@ def matmul(
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 
-    if transpose_a is True:
+    if transpose_a:
         x1 = torch.t(x1)
-    if transpose_b is True:
+    if transpose_b:
         x2 = torch.t(x2)
-    if adjoint_a is True:
+    if adjoint_a:
         x1 = torch.adjoint(x1)
-    if adjoint_b is True:
+    if adjoint_b:
         x2 = torch.adjoint(x2)
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     return torch.matmul(x1, x2, out=out)
@@ -331,8 +331,11 @@ def solve(
     x2: torch.Tensor,
     /,
     *,
+    adjoint: bool = False,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if adjoint:
+        x1 = torch.adjoint(x1)
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
     expanded_last = False
     if len(x2.shape) <= 1:
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
index a393b5d7e0719..c08744a7237cb 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
@@ -669,7 +669,7 @@ def test_slogdet(
 
 # solve
 @st.composite
-def _get_first_matrix(draw):
+def _get_first_matrix(draw, adjoint=True):
     # batch_shape, random_size, shared
 
     # float16 causes a crash when filtering out matrices
@@ -685,7 +685,7 @@ def _get_first_matrix(draw):
     shared_size = draw(
         st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size")
     )
-    return input_dtype, draw(
+    matrix = draw(
         helpers.array_values(
             dtype=input_dtype,
             shape=tuple([shared_size, shared_size]),
@@ -693,6 +693,11 @@ def _get_first_matrix(draw):
             max_value=5,
         ).filter(lambda x: np.linalg.cond(x) < 1 / sys.float_info.epsilon)
     )
+    if adjoint:
+        adjoint = draw(st.booleans())
+        if adjoint:
+            matrix = np.transpose(np.conjugate(matrix))
+    return input_dtype, matrix, adjoint
 
 
 @st.composite
@@ -720,7 +725,7 @@ def _get_second_matrix(draw):
 
 @handle_test(
     fn_tree="functional.ivy.solve",
-    x=_get_first_matrix(),
+    x=_get_first_matrix(adjoint=True),
     y=_get_second_matrix(),
 )
 def test_solve(
@@ -733,7 +738,7 @@ def test_solve(
     on_device,
     ground_truth_backend,
 ):
-    input_dtype1, x1 = x
+    input_dtype1, x1, adjoint = x
     input_dtype2, x2 = y
     helpers.test_function(
         ground_truth_backend=ground_truth_backend,
@@ -746,6 +751,7 @@ def test_solve(
         atol_=1e-1,
         x1=x1,
         x2=x2,
+        adjoint=adjoint,
     )
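
Usage sketch (not part of the patch itself): assuming an ivy build that includes this change and the NumPy backend, the new adjoint keyword conjugate-transposes the first argument before solving, mirroring the backend implementations above. The values below are illustrative only.

import numpy as np
import ivy

ivy.set_backend("numpy")

# Illustrative inputs: a real square system A @ x == b.
A = ivy.array([[2.0, 1.0],
               [0.0, 3.0]])
b = ivy.array([[1.0],
               [2.0]])

# Default behaviour: solves A @ x == b.
x_plain = ivy.solve(A, b)

# adjoint=True: A is conjugate-transposed before solving, i.e. A^H @ x == b.
# For real inputs the adjoint reduces to the plain transpose.
x_adj = ivy.solve(A, b, adjoint=True)

# Same result as solving against the explicit conjugate transpose.
expected = np.linalg.solve(np.conj(ivy.to_numpy(A).T), ivy.to_numpy(b))
assert np.allclose(ivy.to_numpy(x_adj), expected)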