diff --git a/ivy/functional/frontends/tensorflow/linalg.py b/ivy/functional/frontends/tensorflow/linalg.py index fd7862a6c79da..3dc34d2cabfbe 100644 --- a/ivy/functional/frontends/tensorflow/linalg.py +++ b/ivy/functional/frontends/tensorflow/linalg.py @@ -35,6 +35,23 @@ def eigvalsh(tensor, name=None): return ivy.eigvalsh(tensor) +@to_ivy_arrays_and_back +def matmul( + a, + b, + transpose_a=False, + transpose_b=False, + adjoint_a=False, + adjoint_b=False, + a_is_sparse=False, + b_is_sparse=False, + output_type=None, + name=None +): + # TODO : handle conjugate when ivy supports complex numbers + return ivy.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b) + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"2.9.0 and below": ("float16", "bfloat16")}, "tensorflow") def solve(matrix, rhs): diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py index cd4badbc2eb66..b662fa59f6b75 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py @@ -131,6 +131,50 @@ def test_matrix_rank( ) +@handle_frontend_test( + fn_tree="tensorflow.linalg.matmul", + dtype_x=helpers.dtype_and_values( + available_dtypes=[ + "float16", + "float32", + "float64", + "int32", + "int64", + ], + shape=(3, 3), + num_arrays=2, + shared_dtype=True, + min_value=-1e04, + max_value=1e04, + ), + transpose_a=st.booleans(), + transpose_b=st.booleans(), + test_with_out=st.just(False), +) +def test_matmul( + *, + dtype_x, + transpose_a, + transpose_b, + frontend, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + b=x[1], + transpose_a=transpose_a, + transpose_b=transpose_b, + ) + + @st.composite def _solve_get_dtype_and_data(draw): batch = draw(st.integers(min_value=1, max_value=5))