diff --git a/mir_eval/multipitch.py b/mir_eval/multipitch.py index 19c1fdc5..64236053 100644 --- a/mir_eval/multipitch.py +++ b/mir_eval/multipitch.py @@ -235,10 +235,11 @@ def compute_num_true_positives(ref_freqs, est_freqs, window=0.5, chroma=False): # match chroma-wrapped frequency events matching = util.match_events( ref_frame, est_frame, window, - outer_distance=util._outer_distance_mod_n) + distance=util._outer_distance_mod_n) else: # match frequency events within tolerance window in semitones matching = util.match_events(ref_frame, est_frame, window) + true_positives[i] = len(matching) return true_positives diff --git a/mir_eval/util.py b/mir_eval/util.py index ae8af57e..e0c2cc8e 100644 --- a/mir_eval/util.py +++ b/mir_eval/util.py @@ -589,13 +589,60 @@ def recurse(v): recurse(v) -def match_events(ref, est, window): +def _outer_distance_mod_n(ref, est, modulus=12): + """Compute the absolute outer distance modulo n. + Using this distance, d(11, 0) = 1 (modulo 12) + + Parameters + ---------- + ref : np.ndarray, shape=(n,) + Array of reference values. + est : np.ndarray, shape=(m,) + Array of estimated values. + modulus : int + The modulus. + 12 by default for octave equivalence. + + Returns + ------- + outer_distance : np.ndarray, shape=(n, m) + The outer circular distance modulo n. + + """ + ref_mod_n = np.mod(ref, modulus) + est_mod_n = np.mod(est, modulus) + abs_diff = np.abs(np.subtract.outer(ref_mod_n, est_mod_n)) + return np.minimum(abs_diff, modulus - abs_diff) + + +def _outer_distance(ref, est): + """Compute the absolute outer distance. + Computes |ref[i] - est[j]| for each i and j. + + Parameters + ---------- + ref : np.ndarray, shape=(n,) + Array of reference values. + est : np.ndarray, shape=(m,) + Array of estimated values. + + Returns + ------- + outer_distance : np.ndarray, shape=(n, m) + The outer 1d-euclidean distance. + + """ + return np.abs(np.subtract.outer(ref, est)) + + +def match_events(ref, est, window, distance=_outer_distance): """Compute a maximum matching between reference and estimated event times, subject to a window constraint. Given two lists of event times ``ref`` and ``est``, we seek the largest set - of correspondences ``(ref[i], est[j])`` such that ``|ref[i] - est[j]| <= - window``, and each ``ref[i]`` and ``est[j]`` is matched at most once. + of correspondences ``(ref[i], est[j])`` such that + ``distance(ref[i], est[j]) <= window``, and each + ``ref[i]`` and ``est[j]`` is matched at most once. This is useful for computing precision/recall metrics in beat tracking, onset detection, and segmentation. @@ -603,11 +650,14 @@ def match_events(ref, est, window): Parameters ---------- ref : np.ndarray, shape=(n,) - Array of reference event times + Array of reference values est : np.ndarray, shape=(m,) - Array of estimated event times + Array of estimated values window : float > 0 Size of the window. + distance : function + function that computes the outer distance of ref and est. + By default uses _outer_distance, |ref[i] - est[j]| Returns ------- @@ -616,9 +666,11 @@ def match_events(ref, est, window): ``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``. """ + if distance is None: + distance = _outer_distance # Compute the indices of feasible pairings - hits = np.where(np.abs(np.subtract.outer(ref, est)) <= window) + hits = np.where(distance(ref, est) <= window) # Construct the graph input G = {} diff --git a/tests/test_multipitch.py b/tests/test_multipitch.py index e5c380db..d493af14 100644 --- a/tests/test_multipitch.py +++ b/tests/test_multipitch.py @@ -148,6 +148,7 @@ def test_compute_num_true_positives(): actual = mir_eval.multipitch.compute_num_true_positives( ref_freqs, est_freqs) assert np.allclose(actual, expected, atol=A_TOL) + ref_freqs_chroma = [ np.array([0., 1.5]), np.array([]), diff --git a/tests/test_util.py b/tests/test_util.py index 646333a9..b0999c67 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -160,6 +160,52 @@ def test_bipartite_match(): assert v in G[k] or k in G[v] +def test_outer_distance_mod_n(): + ref = [1., 2., 3.] + est = [1.1, 6., 1.9, 5., 10.] + expected = np.array([ + [0.1, 5., 0.9, 4., 3.], + [0.9, 4., 0.1, 3., 4.], + [1.9, 3., 1.1, 2., 5.]]) + actual = mir_eval.util._outer_distance_mod_n(ref, est) + assert np.allclose(actual, expected) + + ref = [13., 14., 15.] + est = [1.1, 6., 1.9, 5., 10.] + expected = np.array([ + [0.1, 5., 0.9, 4., 3.], + [0.9, 4., 0.1, 3., 4.], + [1.9, 3., 1.1, 2., 5.]]) + actual = mir_eval.util._outer_distance_mod_n(ref, est) + assert np.allclose(actual, expected) + + +def test_outer_distance(): + ref = [1., 2., 3.] + est = [1.1, 6., 1.9, 5., 10.] + expected = np.array([ + [0.1, 5., 0.9, 4., 9.], + [0.9, 4., 0.1, 3., 8.], + [1.9, 3., 1.1, 2., 7.]]) + actual = mir_eval.util._outer_distance(ref, est) + assert np.allclose(actual, expected) + + +def test_match_events(): + ref = [1., 2., 3.] + est = [1.1, 6., 1.9, 5., 10.] + expected = [(0, 0), (1, 2)] + actual = mir_eval.util.match_events(ref, est, 0.5) + assert actual == expected + + ref = [1., 2., 3., 11.9] + est = [1.1, 6., 1.9, 5., 10., 0.] + expected = [(0, 0), (1, 2), (3, 5)] + actual = mir_eval.util.match_events( + ref, est, 0.5, distance=mir_eval.util._outer_distance_mod_n) + assert actual == expected + + def test_validate_intervals(): # Test for ValueError when interval shape is invalid nose.tools.assert_raises(