Skip to content

Commit

Permalink
Abstracted distance in match_events.
Browse files Browse the repository at this point in the history
  • Loading branch information
rabitt committed Mar 19, 2016
1 parent 6f3bee0 commit 17a9f3b
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 7 deletions.
3 changes: 2 additions & 1 deletion mir_eval/multipitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,11 @@ def compute_num_true_positives(ref_freqs, est_freqs, window=0.5, chroma=False):
# match chroma-wrapped frequency events
matching = util.match_events(
ref_frame, est_frame, window,
outer_distance=util._outer_distance_mod_n)
distance=util._outer_distance_mod_n)
else:
# match frequency events within tolerance window in semitones
matching = util.match_events(ref_frame, est_frame, window)

true_positives[i] = len(matching)

return true_positives
Expand Down
64 changes: 58 additions & 6 deletions mir_eval/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,25 +589,75 @@ def recurse(v):
recurse(v)


def match_events(ref, est, window):
def _outer_distance_mod_n(ref, est, modulus=12):
"""Compute the absolute outer distance modulo n.
Using this distance, d(11, 0) = 1 (modulo 12)
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference values.
est : np.ndarray, shape=(m,)
Array of estimated values.
modulus : int
The modulus.
12 by default for octave equivalence.
Returns
-------
outer_distance : np.ndarray, shape=(n, m)
The outer circular distance modulo n.
"""
ref_mod_n = np.mod(ref, modulus)
est_mod_n = np.mod(est, modulus)
abs_diff = np.abs(np.subtract.outer(ref_mod_n, est_mod_n))
return np.minimum(abs_diff, modulus - abs_diff)


def _outer_distance(ref, est):
"""Compute the absolute outer distance.
Computes |ref[i] - est[j]| for each i and j.
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference values.
est : np.ndarray, shape=(m,)
Array of estimated values.
Returns
-------
outer_distance : np.ndarray, shape=(n, m)
The outer 1d-euclidean distance.
"""
return np.abs(np.subtract.outer(ref, est))


def match_events(ref, est, window, distance=_outer_distance):
"""Compute a maximum matching between reference and estimated event times,
subject to a window constraint.
Given two lists of event times ``ref`` and ``est``, we seek the largest set
of correspondences ``(ref[i], est[j])`` such that ``|ref[i] - est[j]| <=
window``, and each ``ref[i]`` and ``est[j]`` is matched at most once.
of correspondences ``(ref[i], est[j])`` such that
``distance(ref[i], est[j]) <= window``, and each
``ref[i]`` and ``est[j]`` is matched at most once.
This is useful for computing precision/recall metrics in beat tracking,
onset detection, and segmentation.
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference event times
Array of reference values
est : np.ndarray, shape=(m,)
Array of estimated event times
Array of estimated values
window : float > 0
Size of the window.
distance : function
function that computes the outer distance of ref and est.
By default uses _outer_distance, |ref[i] - est[j]|
Returns
-------
Expand All @@ -616,9 +666,11 @@ def match_events(ref, est, window):
``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
"""
if distance is None:
distance = _outer_distance

# Compute the indices of feasible pairings
hits = np.where(np.abs(np.subtract.outer(ref, est)) <= window)
hits = np.where(distance(ref, est) <= window)

# Construct the graph input
G = {}
Expand Down
1 change: 1 addition & 0 deletions tests/test_multipitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_compute_num_true_positives():
actual = mir_eval.multipitch.compute_num_true_positives(
ref_freqs, est_freqs)
assert np.allclose(actual, expected, atol=A_TOL)

ref_freqs_chroma = [
np.array([0., 1.5]),
np.array([]),
Expand Down
46 changes: 46 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,52 @@ def test_bipartite_match():
assert v in G[k] or k in G[v]


def test_outer_distance_mod_n():
ref = [1., 2., 3.]
est = [1.1, 6., 1.9, 5., 10.]
expected = np.array([
[0.1, 5., 0.9, 4., 3.],
[0.9, 4., 0.1, 3., 4.],
[1.9, 3., 1.1, 2., 5.]])
actual = mir_eval.util._outer_distance_mod_n(ref, est)
assert np.allclose(actual, expected)

ref = [13., 14., 15.]
est = [1.1, 6., 1.9, 5., 10.]
expected = np.array([
[0.1, 5., 0.9, 4., 3.],
[0.9, 4., 0.1, 3., 4.],
[1.9, 3., 1.1, 2., 5.]])
actual = mir_eval.util._outer_distance_mod_n(ref, est)
assert np.allclose(actual, expected)


def test_outer_distance():
ref = [1., 2., 3.]
est = [1.1, 6., 1.9, 5., 10.]
expected = np.array([
[0.1, 5., 0.9, 4., 9.],
[0.9, 4., 0.1, 3., 8.],
[1.9, 3., 1.1, 2., 7.]])
actual = mir_eval.util._outer_distance(ref, est)
assert np.allclose(actual, expected)


def test_match_events():
ref = [1., 2., 3.]
est = [1.1, 6., 1.9, 5., 10.]
expected = [(0, 0), (1, 2)]
actual = mir_eval.util.match_events(ref, est, 0.5)
assert actual == expected

ref = [1., 2., 3., 11.9]
est = [1.1, 6., 1.9, 5., 10., 0.]
expected = [(0, 0), (1, 2), (3, 5)]
actual = mir_eval.util.match_events(
ref, est, 0.5, distance=mir_eval.util._outer_distance_mod_n)
assert actual == expected


def test_validate_intervals():
# Test for ValueError when interval shape is invalid
nose.tools.assert_raises(
Expand Down

0 comments on commit 17a9f3b

Please sign in to comment.