Skip to content

Commit

Permalink
Support gaps in labeled interval interpolation (#249)
Browse files Browse the repository at this point in the history
Fixes #248
  • Loading branch information
bmcfee authored and craffel committed Jun 22, 2018
1 parent 56b9360 commit 9ea97f7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 14 deletions.
33 changes: 19 additions & 14 deletions mir_eval/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1,
def interpolate_intervals(intervals, labels, time_points, fill_value=None):
"""Assign labels to a set of points in time given a set of intervals.
Note: Times outside of the known boundaries are mapped to None by default.
Time points that do not lie within an interval are mapped to `fill_value`.
Parameters
----------
intervals : np.ndarray, shape=(n, d)
intervals : np.ndarray, shape=(n, 2)
An array of time intervals, as returned by
:func:`mir_eval.io.load_intervals()`.
The ``i`` th interval spans time ``intervals[i, 0]`` to
Expand All @@ -145,7 +145,8 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None):
The annotation for each interval
time_points : array_like, shape=(m,)
Points in time to assign labels.
Points in time to assign labels. These must be in
non-decreasing order.
fill_value : type(labels[0])
Object to use for the label with out-of-range time points.
Expand All @@ -156,22 +157,26 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None):
aligned_labels : list
Labels corresponding to the given time points.
Raises
------
ValueError
If `time_points` is not in non-decreasing order.
"""

# Sort the intervals by start time
intervals, labels = sort_labeled_intervals(intervals, labels)
# Verify that time_points is sorted
time_points = np.asarray(time_points)

if np.any(time_points[1:] < time_points[:-1]):
raise ValueError('time_points must be in non-decreasing order')

aligned_labels = [fill_value] * len(time_points)

start, end = intervals.min(), intervals.max()
starts = np.searchsorted(time_points, intervals[:, 0], side='left')
ends = np.searchsorted(time_points, intervals[:, 1], side='right')

aligned_labels = []
for (start, end, lab) in zip(starts, ends, labels):
aligned_labels[start:end] = [lab] * (end - start)

for tpoint in time_points:
# This logic isn't correct if there's a gap in intervals
if start <= tpoint <= end:
index = np.argmax(intervals[:, 0] > tpoint) - 1
aligned_labels.append(labels[index])
else:
aligned_labels.append(fill_value)
return aligned_labels


Expand Down
21 changes: 21 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ def test_interpolate_intervals():
expected_ans)


def test_interpolate_intervals_gap():
"""Check that an interval set is interpolated properly, with gaps."""
labels = list('abc')
intervals = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]])
time_points = [0.0, 0.75, 1.25, 1.75, 2.25, 2.75, 3.5]
expected_ans = ['N', 'a', 'N', 'b', 'N', 'c', 'N']
assert (util.interpolate_intervals(intervals, labels, time_points, 'N') ==
expected_ans)


@nose.tools.raises(ValueError)
def test_interpolate_intervals_badtime():
"""Check that interpolate_intervals throws an exception if
input is unordered.
"""
labels = list('abc')
intervals = np.array([(n, n + 1.0) for n in range(len(labels))])
time_points = [-1.0, 0.1, 0.9, 0.8, 2.3, 4.0]
mir_eval.util.interpolate_intervals(intervals, labels, time_points)


def test_intervals_to_samples():
"""Check that an interval set is sampled properly, with boundaries
conditions and out-of-range values.
Expand Down

0 comments on commit 9ea97f7

Please sign in to comment.