Skip to content

Commit

Permalink
Adding superset adjoint to solve (ivy-llc#10611)
Browse files Browse the repository at this point in the history
  • Loading branch information
zaeemansari70 authored Feb 17, 2023
1 parent 788ab22 commit 63b6919
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 23 deletions.
3 changes: 2 additions & 1 deletion ivy/array/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,10 @@ def solve(
x2: Union[ivy.Array, ivy.NativeArray],
/,
*,
adjoint: bool = False,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
return ivy.solve(self._data, x2, out=out)
return ivy.solve(self._data, x2, adjoint=adjoint, out=out)

def svd(
self: ivy.Array,
Expand Down
4 changes: 4 additions & 0 deletions ivy/container/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,7 @@ def static_solve(
x2: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
adjoint: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand All @@ -2109,6 +2110,7 @@ def static_solve(
"solve",
x1,
x2,
adjoint=adjoint,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand All @@ -2121,6 +2123,7 @@ def solve(
x2: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
adjoint: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand All @@ -2130,6 +2133,7 @@ def solve(
return self.static_solve(
self,
x2,
adjoint=adjoint,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand Down
19 changes: 14 additions & 5 deletions ivy/functional/backends/jax/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,13 @@ def matmul(
adjoint_b: bool = False,
out: Optional[JaxArray] = None,
) -> JaxArray:
if transpose_a is True:
if transpose_a:
x1 = jnp.transpose(x1)
if transpose_b is True:
if transpose_b:
x2 = jnp.transpose(x2)
if adjoint_a is True:
if adjoint_a:
x1 = jnp.transpose(jnp.conjugate(x1))
if adjoint_b is True:
if adjoint_b:
x2 = jnp.transpose(jnp.conjugate(x2))
return jnp.matmul(x1, x2)

Expand Down Expand Up @@ -346,7 +346,16 @@ def slogdet(
{"0.3.14 and below": ("bfloat16", "float16", "complex")},
backend_version,
)
def solve(x1: JaxArray, x2: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def solve(
x1: JaxArray,
x2: JaxArray,
/,
*,
adjoint: bool = False,
out: Optional[JaxArray] = None
) -> JaxArray:
if adjoint:
x1 = jnp.transpose(jnp.conjugate(x1))
expanded_last = False
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
if len(x2.shape) <= 1:
Expand Down
17 changes: 12 additions & 5 deletions ivy/functional/backends/numpy/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def matmul(
adjoint_b: bool = False,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if transpose_a is True:
if transpose_a:
x1 = np.transpose(x1)
if transpose_b is True:
if transpose_b:
x2 = np.transpose(x2)
if adjoint_a is True:
if adjoint_a:
x1 = np.transpose(np.conjugate(x1))
if adjoint_b is True:
if adjoint_b:
x2 = np.transpose(np.conjugate(x2))
ret = np.matmul(x1, x2, out=out)
if len(x1.shape) == len(x2.shape) == 1:
Expand Down Expand Up @@ -302,8 +302,15 @@ def slogdet(

@with_unsupported_dtypes({"1.23.0 and below": ("float16",)}, backend_version)
def solve(
x1: np.ndarray, x2: np.ndarray, /, *, out: Optional[np.ndarray] = None
x1: np.ndarray,
x2: np.ndarray,
/,
*,
adjoint: bool = False,
out: Optional[np.ndarray] = None
) -> np.ndarray:
if adjoint:
x1 = np.transpose(np.conjugate(x1))
expanded_last = False
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
if len(x2.shape) <= 1:
Expand Down
11 changes: 7 additions & 4 deletions ivy/functional/backends/tensorflow/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,14 @@ def matmul(
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
dtype_from = tf.as_dtype(x1.dtype)

if transpose_a is True:
if transpose_a:
x1 = tf.transpose(x1)
if transpose_b is True:
if transpose_b:
x2 = tf.transpose(x2)

if adjoint_a is True:
if adjoint_a:
x1 = tf.linalg.adjoint(x1)
if adjoint_b is True:
if adjoint_b:
x2 = tf.linalg.adjoint(x2)

if dtype_from.is_unsigned or dtype_from == tf.int8 or dtype_from == tf.int16:
Expand Down Expand Up @@ -554,8 +554,11 @@ def solve(
x2: Union[tf.Tensor, tf.Variable],
/,
*,
adjoint: bool = False,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if adjoint:
x1 = tf.linalg.adjoint(x1)
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
expanded_last = False
if len(x2.shape) <= 1:
Expand Down
11 changes: 7 additions & 4 deletions ivy/functional/backends/torch/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ def matmul(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if transpose_a is True:
if transpose_a:
x1 = torch.t(x1)
if transpose_b is True:
if transpose_b:
x2 = torch.t(x2)
if adjoint_a is True:
if adjoint_a:
x1 = torch.adjoint(x1)
if adjoint_b is True:
if adjoint_b:
x2 = torch.adjoint(x2)
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
return torch.matmul(x1, x2, out=out)
Expand Down Expand Up @@ -331,8 +331,11 @@ def solve(
x2: torch.Tensor,
/,
*,
adjoint: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if adjoint:
x1 = torch.adjoint(x1)
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
expanded_last = False
if len(x2.shape) <= 1:
Expand Down
14 changes: 10 additions & 4 deletions ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def test_slogdet(

# solve
@st.composite
def _get_first_matrix(draw):
def _get_first_matrix(draw, adjoint=True):
# batch_shape, random_size, shared

# float16 causes a crash when filtering out matrices
Expand All @@ -685,14 +685,19 @@ def _get_first_matrix(draw):
shared_size = draw(
st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size")
)
return input_dtype, draw(
matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=tuple([shared_size, shared_size]),
min_value=2,
max_value=5,
).filter(lambda x: np.linalg.cond(x) < 1 / sys.float_info.epsilon)
)
if adjoint:
adjoint = draw(st.booleans())
if adjoint:
matrix = np.transpose(np.conjugate(matrix))
return input_dtype, matrix, adjoint


@st.composite
Expand Down Expand Up @@ -720,7 +725,7 @@ def _get_second_matrix(draw):

@handle_test(
fn_tree="functional.ivy.solve",
x=_get_first_matrix(),
x=_get_first_matrix(adjoint=True),
y=_get_second_matrix(),
)
def test_solve(
Expand All @@ -733,7 +738,7 @@ def test_solve(
on_device,
ground_truth_backend,
):
input_dtype1, x1 = x
input_dtype1, x1, adjoint = x
input_dtype2, x2 = y
helpers.test_function(
ground_truth_backend=ground_truth_backend,
Expand All @@ -746,6 +751,7 @@ def test_solve(
atol_=1e-1,
x1=x1,
x2=x2,
adjoint=adjoint,
)


Expand Down

0 comments on commit 63b6919

Please sign in to comment.