Skip to content

Commit

Permalink
add torch cosine similarity (#6045)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogbanugot authored Oct 22, 2022
1 parent bc4f0c8 commit 7b24edd
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
26 changes: 26 additions & 0 deletions ivy/functional/frontends/torch/nn/functional/distance_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import ivy
import ivy.functional.frontends.torch as torch_frontend
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back


@to_ivy_arrays_and_back
def cosine_similarity(x1, x2, *, dim=1, eps=1e-08):
x1, x2 = torch_frontend.promote_types_of_torch_inputs(x1, x2)

if len(x1.shape) == len(x2.shape) and len(x2.shape) >= 2:
numerator = ivy.sum(x1 * x2, axis=dim)
x1_squared_norm = ivy.sum(ivy.square(x1), axis=dim)
x2_squared_norm = ivy.sum(ivy.square(x2), axis=dim)
else:
numerator = ivy.sum(x1 * x2)
x1_squared_norm = ivy.sum(ivy.square(x1))
x2_squared_norm = ivy.sum(ivy.square(x2))

x1_norm = ivy.sqrt(x1_squared_norm)
x2_norm = ivy.sqrt(x2_squared_norm)
norm_mm = x1_norm * x2_norm
norm_mm, eps = torch_frontend.promote_types_of_torch_inputs(norm_mm, eps)
denominator = ivy.maximum(norm_mm, eps)

cosine = numerator / denominator
return cosine
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from hypothesis import assume, given, strategies as st

# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_cmd_line_args


def _filter_dtypes(input_dtype):
assume(("bfloat16" not in input_dtype) and ("float16" not in input_dtype))


# Cosine Similarity
@handle_cmd_line_args
@given(
d_type_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float", full=True),
min_value=2,
max_value=5,
min_dim_size=2,
shared_dtype=True,
num_arrays=2,
),
dim=st.integers(min_value=-1, max_value=0),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.torch.nn.functional.cosine_similarity"
),
)
def test_torch_cosine_similarity(
d_type_and_x,
dim,
with_out,
as_variable,
num_positional_args,
native_array,
):
dtype, x = d_type_and_x
_filter_dtypes(dtype)
helpers.test_frontend_function(
input_dtypes=dtype,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend="torch",
fn_tree="nn.functional.cosine_similarity",
rtol=1e-01,
x1=x[0],
x2=x[1],
dim=dim,
)

0 comments on commit 7b24edd

Please sign in to comment.