Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #248, support gaps in labeled interval interpolation #249

Merged
merged 4 commits into from
Apr 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions mir_eval/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1,
def interpolate_intervals(intervals, labels, time_points, fill_value=None):
"""Assign labels to a set of points in time given a set of intervals.

Note: Times outside of the known boundaries are mapped to None by default.
Time points that do not lie within an interval are mapped to `fill_value`.

Parameters
----------
intervals : np.ndarray, shape=(n, d)
intervals : np.ndarray, shape=(n, 2)
An array of time intervals, as returned by
:func:`mir_eval.io.load_intervals()`.
The ``i`` th interval spans time ``intervals[i, 0]`` to
Expand All @@ -145,7 +145,8 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None):
The annotation for each interval

time_points : array_like, shape=(m,)
Points in time to assign labels.
Points in time to assign labels. These must be in
non-decreasing order.

fill_value : type(labels[0])
Object to use for the label with out-of-range time points.
Expand All @@ -156,22 +157,26 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None):
aligned_labels : list
Labels corresponding to the given time points.

Raises
------
ValueError
If `time_points` is not in non-decreasing order.
"""

# Sort the intervals by start time
intervals, labels = sort_labeled_intervals(intervals, labels)
# Verify that time_points is sorted
time_points = np.asarray(time_points)

if np.any(time_points[1:] < time_points[:-1]):
raise ValueError('time_points must be in non-decreasing order')

aligned_labels = [fill_value] * len(time_points)

start, end = intervals.min(), intervals.max()
starts = np.searchsorted(time_points, intervals[:, 0], side='left')
ends = np.searchsorted(time_points, intervals[:, 1], side='right')

aligned_labels = []
for (start, end, lab) in zip(starts, ends, labels):
aligned_labels[start:end] = [lab] * (end - start)

for tpoint in time_points:
# This logic isn't correct if there's a gap in intervals
if start <= tpoint <= end:
index = np.argmax(intervals[:, 0] > tpoint) - 1
aligned_labels.append(labels[index])
else:
aligned_labels.append(fill_value)
return aligned_labels


Expand Down
21 changes: 21 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ def test_interpolate_intervals():
expected_ans)


def test_interpolate_intervals_gap():
"""Check that an interval set is interpolated properly, with gaps."""
labels = list('abc')
intervals = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]])
time_points = [0.0, 0.75, 1.25, 1.75, 2.25, 2.75, 3.5]
expected_ans = ['N', 'a', 'N', 'b', 'N', 'c', 'N']
assert (util.interpolate_intervals(intervals, labels, time_points, 'N') ==
expected_ans)


@nose.tools.raises(ValueError)
def test_interpolate_intervals_badtime():
"""Check that interpolate_intervals throws an exception if
input is unordered.
"""
labels = list('abc')
intervals = np.array([(n, n + 1.0) for n in range(len(labels))])
time_points = [-1.0, 0.1, 0.9, 0.8, 2.3, 4.0]
mir_eval.util.interpolate_intervals(intervals, labels, time_points)


def test_intervals_to_samples():
"""Check that an interval set is sampled properly, with boundaries
conditions and out-of-range values.
Expand Down