diff --git a/mir_eval/util.py b/mir_eval/util.py index 1df1142a..0ce91304 100644 --- a/mir_eval/util.py +++ b/mir_eval/util.py @@ -129,11 +129,11 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, def interpolate_intervals(intervals, labels, time_points, fill_value=None): """Assign labels to a set of points in time given a set of intervals. - Note: Times outside of the known boundaries are mapped to None by default. + Time points that do not lie within an interval are mapped to `fill_value`. Parameters ---------- - intervals : np.ndarray, shape=(n, d) + intervals : np.ndarray, shape=(n, 2) An array of time intervals, as returned by :func:`mir_eval.io.load_intervals()`. The ``i`` th interval spans time ``intervals[i, 0]`` to @@ -145,7 +145,8 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None): The annotation for each interval time_points : array_like, shape=(m,) - Points in time to assign labels. + Points in time to assign labels. These must be in + non-decreasing order. fill_value : type(labels[0]) Object to use for the label with out-of-range time points. @@ -156,22 +157,26 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None): aligned_labels : list Labels corresponding to the given time points. + Raises + ------ + ValueError + If `time_points` is not in non-decreasing order. """ - # Sort the intervals by start time - intervals, labels = sort_labeled_intervals(intervals, labels) + # Verify that time_points is sorted + time_points = np.asarray(time_points) + + if np.any(time_points[1:] < time_points[:-1]): + raise ValueError('time_points must be in non-decreasing order') + + aligned_labels = [fill_value] * len(time_points) - start, end = intervals.min(), intervals.max() + starts = np.searchsorted(time_points, intervals[:, 0], side='left') + ends = np.searchsorted(time_points, intervals[:, 1], side='right') - aligned_labels = [] + for (start, end, lab) in zip(starts, ends, labels): + aligned_labels[start:end] = [lab] * (end - start) - for tpoint in time_points: - # This logic isn't correct if there's a gap in intervals - if start <= tpoint <= end: - index = np.argmax(intervals[:, 0] > tpoint) - 1 - aligned_labels.append(labels[index]) - else: - aligned_labels.append(fill_value) return aligned_labels diff --git a/tests/test_util.py b/tests/test_util.py index b4103d83..913e529f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -21,6 +21,27 @@ def test_interpolate_intervals(): expected_ans) +def test_interpolate_intervals_gap(): + """Check that an interval set is interpolated properly, with gaps.""" + labels = list('abc') + intervals = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]) + time_points = [0.0, 0.75, 1.25, 1.75, 2.25, 2.75, 3.5] + expected_ans = ['N', 'a', 'N', 'b', 'N', 'c', 'N'] + assert (util.interpolate_intervals(intervals, labels, time_points, 'N') == + expected_ans) + + +@nose.tools.raises(ValueError) +def test_interpolate_intervals_badtime(): + """Check that interpolate_intervals throws an exception if + input is unordered. + """ + labels = list('abc') + intervals = np.array([(n, n + 1.0) for n in range(len(labels))]) + time_points = [-1.0, 0.1, 0.9, 0.8, 2.3, 4.0] + mir_eval.util.interpolate_intervals(intervals, labels, time_points) + + def test_intervals_to_samples(): """Check that an interval set is sampled properly, with boundaries conditions and out-of-range values.