Skip to content

Commit

Permalink
Merge branch 'master' into fix-bucket-pad
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyang2057 authored Sep 22, 2023
2 parents b3f90c0 + 4df2cd9 commit 74ef18d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/compare_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def cosine(gt: np.ndarray, pred: np.ndarray, *args):


def euclidean(gt: np.ndarray, pred: np.ndarray, *args):
return np.linalg.norm(gt - pred, 2)**2
return np.linalg.norm(gt.reshape(-1) - pred.reshape(-1))


def allclose(gt: np.ndarray, pred: np.ndarray, thresh: float):
Expand Down Expand Up @@ -71,7 +71,7 @@ def compare_binfile(result_path: Tuple[str, str],
np.savetxt(str(p.parent / (p.stem + '_hist.csv')),
np.stack((x[:-1], y)).T, fmt='%f', delimiter=',')
similarity_info = f"\n{similarity_name} similarity = {similarity}, threshold = {threshold}\n"
if similarity_name in ['cosine', 'euclidean', 'segment']:
if similarity_name in ['cosine', 'segment']:
compare_op = lt
else:
compare_op = gt
Expand All @@ -97,7 +97,7 @@ def compare_ndarray(expected: np.ndarray,
np.savetxt(dump_file, np.stack((x[:-1], y)).T, fmt='%f', delimiter=',')
similarity_info = f"{similarity_name} similarity = {similarity}, threshold = {threshold}\n"

if similarity_name in ['cosine', 'euclidean', 'segment']:
if similarity_name in ['cosine', 'segment']:
compare_op = lt
else:
compare_op = gt
Expand Down

0 comments on commit 74ef18d

Please sign in to comment.