diff --git a/.gitignore b/.gitignore index b878b2c7..53e602e3 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,7 @@ Thumbs.db .idea/* # docs -docs/_build/* \ No newline at end of file +docs/_build/* + +# matplotlib tsets +tests/result_images/* diff --git a/.travis.yml b/.travis.yml index 93c48210..2f4042e9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,17 +20,17 @@ before_install: - conda config --set always_yes yes --set changeps1 no - conda update -q conda - conda info -a - - deps='pip atlas numpy scipy sphinx nose six future pep8' + - deps='pip atlas numpy scipy sphinx nose six future pep8 matplotlib decorator' - conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps - source activate test-environment - pip install python-coveralls - pip install numpydoc install: - - pip install -e . + - pip install -e .[display] before_script: - - pep8 mir_eval tests evaluators + - pep8 mir_eval evaluators tests script: - nosetests -v --with-coverage --cover-package=mir_eval diff --git a/mir_eval/display.py b/mir_eval/display.py new file mode 100644 index 00000000..4edcadad --- /dev/null +++ b/mir_eval/display.py @@ -0,0 +1,845 @@ +# -*- encoding: utf-8 -*- +'''Display functions''' + +from collections import defaultdict + +import numpy as np +from scipy.signal import spectrogram + +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle, Patch +from matplotlib.ticker import FuncFormatter, MultipleLocator +from matplotlib.colors import LinearSegmentedColormap, LogNorm, ColorConverter +from matplotlib.collections import BrokenBarHCollection + +from .melody import freq_to_voicing +from .util import midi_to_hz, hz_to_midi + + +def __expand_limits(ax, limits, which='x'): + '''Helper function to expand axis limits''' + + if which == 'x': + getter, setter = ax.get_xlim, ax.set_xlim + elif which == 'y': + getter, setter = ax.get_ylim, ax.set_ylim + else: + raise ValueError('invalid axis: {}'.format(which)) + + old_lims = getter() + new_lims = list(limits) + + # infinite limits occur on new axis objects with no data + if np.isfinite(old_lims[0]): + new_lims[0] = min(old_lims[0], limits[0]) + + if np.isfinite(old_lims[1]): + new_lims[1] = max(old_lims[1], limits[1]) + + setter(new_lims) + + +def __get_axes(ax=None, fig=None): + '''Get or construct the target axes object for a new plot. + + Parameters + ---------- + ax : matplotlib.pyplot.axes, optional + If provided, return this axes object directly. + + fig : matplotlib.figure.Figure, optional + The figure to query for axes. + By default, uses the current figure `plt.gcf()`. + + Returns + ------- + ax : matplotlib.pyplot.axes + An axis handle on which to draw the segmentation. + If none is provided, a new set of axes is created. + + new_axes : bool + If `True`, the axis object was newly constructed. + If `False`, the axis object already existed. + + ''' + + new_axes = False + + if ax is not None: + return ax, new_axes + + if fig is None: + fig = plt.gcf() + + if not fig.get_axes(): + new_axes = True + + return fig.gca(), new_axes + + +def segments(intervals, labels, base=None, height=None, text=False, + text_kw=None, ax=None, **kwargs): + '''Plot a segmentation as a set of disjoint rectangles. + + Parameters + ---------- + intervals : np.ndarray, shape=(n, 2) + segment intervals, in the format returned by + :func:`mir_eval.io.load_intervals` or + :func:`mir_eval.io.load_labeled_intervals`. + + labels : list, shape=(n,) + reference segment labels, in the format returned by + :func:`mir_eval.io.load_labeled_intervals`. + + base : number + The vertical position of the base of the rectangles. + By default, this will be the bottom of the plot. + + height : number + The height of the rectangles. + By default, this will be the top of the plot (minus `base`). + + text : bool + If true, each segment's label is displayed in its + upper-left corner + + text_kw : dict + If `text==True`, the properties of the text + object can be specified here. + See `matplotlib.pyplot.Text` for valid parameters + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the segmentation. + If none is provided, a new set of axes is created. + + kwargs + Additional keyword arguments to pass to + `matplotlib.patches.Rectangle`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + if text_kw is None: + text_kw = dict() + text_kw.setdefault('va', 'top') + text_kw.setdefault('clip_on', True) + text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white')) + + # Make sure we have a numpy array + intervals = np.atleast_2d(intervals) + + seg_def_style = dict(linewidth=1) + + ax, new_axes = __get_axes(ax=ax) + + if new_axes: + ax.set_ylim([0, 1]) + + # Infer height + if base is None: + base = ax.get_ylim()[0] + + if height is None: + height = ax.get_ylim()[1] + + cycler = ax._get_patches_for_fill.prop_cycler + + seg_map = dict() + + for lab in labels: + if lab in seg_map: + continue + + style = next(cycler) + seg_map[lab] = seg_def_style.copy() + seg_map[lab].update(style) + # Swap color -> facecolor here so we preserve edgecolor on rects + seg_map[lab]['facecolor'] = seg_map[lab].pop('color') + seg_map[lab].update(kwargs) + seg_map[lab]['label'] = lab + + for ival, lab in zip(intervals, labels): + rect = Rectangle((ival[0], base), ival[1] - ival[0], height, + **seg_map[lab]) + ax.add_patch(rect) + seg_map[lab].pop('label', None) + + if text: + ann = ax.annotate(lab, + xy=(ival[0], height), xycoords='data', + xytext=(8, -10), textcoords='offset points', + **text_kw) + ann.set_clip_path(rect) + + if new_axes: + ax.set_yticks([]) + + # Only expand if we have data + if intervals.size: + __expand_limits(ax, [intervals.min(), intervals.max()], which='x') + + return ax + + +def labeled_intervals(intervals, labels, label_set=None, + base=None, height=None, extend_labels=True, + ax=None, tick=True, **kwargs): + '''Plot labeled intervals with each label on its own row. + + Parameters + ---------- + intervals : np.ndarray, shape=(n, 2) + segment intervals, in the format returned by + :func:`mir_eval.io.load_intervals` or + :func:`mir_eval.io.load_labeled_intervals`. + + labels : list, shape=(n,) + reference segment labels, in the format returned by + :func:`mir_eval.io.load_labeled_intervals`. + + label_set : list + An (ordered) list of labels to determine the plotting order. + If not provided, the labels will be inferred from + `ax.get_yticklabels()`. + If no `yticklabels` exist, then the sorted set of unique values + in `labels` is taken as the label set. + + base : np.ndarray, shape=(n,), optional + Vertical positions of each label. + By default, labels are positioned at integers `np.arange(len(labels))`. + + height : scalar or np.ndarray, shape=(n,), optional + Height for each label. + If scalar, the same value is applied to all labels. + By default, each label has `height=1`. + + extend_labels : bool + If `False`, only values of `labels` that also exist in `label_set` + will be shown. + + If `True`, all labels are shown, with those in `labels` but + not in `label_set` appended to the top of the plot. + A horizontal line is drawn to indicate the separation between + values in or out of `label_set`. + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the intervals. + If none is provided, a new set of axes is created. + + tick : bool + If `True`, sets tick positions and labels on the y-axis. + + kwargs + Additional keyword arguments to pass to + `matplotlib.collection.BrokenBarHCollection`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + + # Get the axes handle + ax, _ = __get_axes(ax=ax) + + # Make sure we have a numpy array + intervals = np.atleast_2d(intervals) + + if label_set is None: + # If we have non-empty pre-existing tick labels, use them + label_set = [_.get_text() for _ in ax.get_yticklabels()] + # If none of the label strings have content, treat it as empty + if not any(label_set): + label_set = [] + else: + label_set = list(label_set) + + # Put additional labels at the end, in order + if extend_labels: + ticks = label_set + sorted(set(labels) - set(label_set)) + elif label_set: + ticks = label_set + else: + ticks = sorted(set(labels)) + + style = dict(linewidth=1) + + style.update(next(ax._get_patches_for_fill.prop_cycler)) + # Swap color -> facecolor here so we preserve edgecolor on rects + style['facecolor'] = style.pop('color') + style.update(kwargs) + + if base is None: + base = np.arange(len(ticks)) + + if height is None: + height = 1 + + if np.isscalar(height): + height = height * np.ones_like(base) + + seg_y = dict() + for ybase, yheight, lab in zip(base, height, ticks): + seg_y[lab] = (ybase, yheight) + + xvals = defaultdict(list) + for ival, lab in zip(intervals, labels): + if lab not in seg_y: + continue + xvals[lab].append((ival[0], ival[1] - ival[0])) + + for lab in seg_y: + ax.add_collection(BrokenBarHCollection(xvals[lab], seg_y[lab], + **style)) + # Pop the label after the first time we see it, so we only get + # one legend entry + style.pop('label', None) + + # Draw a line separating the new labels from pre-existing labels + if label_set != ticks: + ax.axhline(len(label_set), color='k', alpha=0.5) + + if tick: + ax.set_yticks([]) + ax.grid('on', axis='y') + ax.set_yticks(base) + ax.set_yticklabels(ticks, va='bottom') + + if base.size: + __expand_limits(ax, [base.min(), (base + height).max()], which='y') + if intervals.size: + __expand_limits(ax, [intervals.min(), intervals.max()], which='x') + + return ax + + +def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): + '''Plot a hierarchical segmentation + + Parameters + ---------- + intervals_hier : list of np.ndarray + A list of segmentation intervals. Each element should be + an n-by-2 array of segment intervals, in the format returned by + :func:`mir_eval.io.load_intervals` or + :func:`mir_eval.io.load_labeled_intervals`. + Segmentations should be ordered by increasing specificity. + + labels_hier : list of list-like + A list of segmentation labels. Each element should + be a list of labels for the corresponding element in + `intervals_hier`. + + levels : list of string + Each element `levels[i]` is a label for the `i`th segmentation. + This is used in the legend to denote the levels in a segment hierarchy. + + kwargs : + Additional keyword arguments to `labeled_intervals`. + + Returns + ------- + ax + A handle to the matplotlib axes + ''' + + # This will break if a segment label exists in multiple levels + if levels is None: + levels = list(range(len(intervals_hier))) + + # Get the axes handle + ax, _ = __get_axes(ax=ax) + + # Count the pre-existing patches + n_patches = len(ax.patches) + + for ints, labs, key in zip(intervals_hier[::-1], + labels_hier[::-1], + levels[::-1]): + labeled_intervals(ints, labs, label=key, ax=ax, **kwargs) + + # Reverse the patch ordering for anything we've added. + # This way, intervals are listed in the legend from top to bottom + ax.patches[n_patches:] = ax.patches[n_patches:][::-1] + return ax + + +def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, + **kwargs): + '''Plot event times as a set of vertical lines + + Parameters + ---------- + times : np.ndarray, shape=(n,) + event times, in the format returned by + :func:`mir_eval.io.load_events` or + :func:`mir_eval.io.load_labeled_events`. + + labels : list, shape=(n,), optional + event labels, in the format returned by + :func:`mir_eval.io.load_labeled_events`. + + base : number + The vertical position of the base of the line. + By default, this will be the bottom of the plot. + + height : number + The height of the lines. + By default, this will be the top of the plot (minus `base`). + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the segmentation. + If none is provided, a new set of axes is created. + + text_kw : dict + If `labels` is provided, the properties of the text + objects can be specified here. + See `matplotlib.pyplot.Text` for valid parameters + + kwargs + Additional keyword arguments to pass to + `matplotlib.pyplot.vlines`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + if text_kw is None: + text_kw = dict() + text_kw.setdefault('va', 'top') + text_kw.setdefault('clip_on', True) + text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white')) + + # make sure we have an array for times + times = np.asarray(times) + + # Get the axes handle + ax, new_axes = __get_axes(ax=ax) + + # If we have fresh axes, set the limits + + if new_axes: + # Infer base and height + if base is None: + base = 0 + if height is None: + height = 1 + + ax.set_ylim([base, height]) + else: + if base is None: + base = ax.get_ylim()[0] + + if height is None: + height = ax.get_ylim()[1] + + cycler = ax._get_patches_for_fill.prop_cycler + + style = next(cycler).copy() + style.update(kwargs) + # If the user provided 'colors', don't override it with 'color' + if 'colors' in style: + style.pop('color', None) + + lines = ax.vlines(times, base, base + height, **style) + + if labels: + for path, lab in zip(lines.get_paths(), labels): + ax.annotate(lab, + xy=(path.vertices[0][0], height), + xycoords='data', + xytext=(8, -10), textcoords='offset points', + **text_kw) + + if new_axes: + ax.set_yticks([]) + + __expand_limits(ax, [base, base + height], which='y') + + if times.size: + __expand_limits(ax, [times.min(), times.max()], which='x') + + return ax + + +def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): + '''Visualize pitch contours + + Parameters + ---------- + times : np.ndarray, shape=(n,) + Sample times of frequencies + + frequencies : np.ndarray, shape=(n,) + frequencies (in Hz) of the pitch contours. + Voicing is indicated by sign (positive for voiced, + non-positive for non-voiced). + + midi : bool + If `True`, plot on a MIDI-numbered vertical axis. + Otherwise, plot on a linear frequency axis. + + unvoiced : bool + If `True`, unvoiced pitch contours are plotted and indicated + by transparency. + + Otherwise, unvoiced pitch contours are omitted from the display. + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the pitch contours. + If none is provided, a new set of axes is created. + + kwargs : + Additional keyword arguments to `matplotlib.pyplot.plot`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + + ax, _ = __get_axes(ax=ax) + + times = np.asarray(times) + + # First, segment into contiguously voiced contours + frequencies, voicings = freq_to_voicing(np.asarray(frequencies, + dtype=np.float)) + + # Here are all the change-points + v_changes = 1 + np.flatnonzero(voicings[1:] != voicings[:-1]) + v_changes = np.unique(np.concatenate([[0], v_changes, [len(voicings)]])) + + # Set up arrays of slices for voiced and unvoiced regions + v_slices, u_slices = [], [] + for start, end in zip(v_changes, v_changes[1:]): + idx = slice(start, end) + # A region is voiced if its starting sample is voiced + # It's unvoiced if none of the samples in the region are voiced. + if voicings[start]: + v_slices.append(idx) + elif frequencies[idx].all(): + u_slices.append(idx) + + # Now we just need to plot the contour + style = dict() + style.update(next(ax._get_lines.prop_cycler)) + style.update(kwargs) + + if midi: + idx = frequencies > 0 + frequencies[idx] = hz_to_midi(frequencies[idx]) + + # Tick at integer midi notes + ax.yaxis.set_minor_locator(MultipleLocator(1)) + + for idx in v_slices: + ax.plot(times[idx], frequencies[idx], **style) + style.pop('label', None) + + # Plot the unvoiced portions + if unvoiced: + style['alpha'] = style.get('alpha', 1.0) * 0.5 + for idx in u_slices: + ax.plot(times[idx], frequencies[idx], **style) + + return ax + + +def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, + **kwargs): + '''Visualize multiple f0 measurements + + Parameters + ---------- + times : np.ndarray, shape=(n,) + Sample times of frequencies + + frequencies : list of np.ndarray + frequencies (in Hz) of the pitch measurements. + Voicing is indicated by sign (positive for voiced, + non-positive for non-voiced). + + `times` and `frequencies` should be in the format produced by + :func:`mir_eval.io.load_ragged_time_series` + + midi : bool + If `True`, plot on a MIDI-numbered vertical axis. + Otherwise, plot on a linear frequency axis. + + unvoiced : bool + If `True`, unvoiced pitches are plotted and indicated + by transparency. + + Otherwise, unvoiced pitches are omitted from the display. + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the pitch contours. + If none is provided, a new set of axes is created. + + kwargs : + Additional keyword arguments to `plt.scatter`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + + # Get the axes handle + ax, _ = __get_axes(ax=ax) + + # Set up a style for the plot + style_voiced = dict() + style_voiced.update(next(ax._get_lines.prop_cycler)) + style_voiced.update(kwargs) + + style_unvoiced = style_voiced.copy() + style_unvoiced.pop('label', None) + style_unvoiced['alpha'] = style_unvoiced.get('alpha', 1.0) * 0.5 + + # We'll collect all times and frequencies first, then plot them + voiced_times = [] + voiced_freqs = [] + + unvoiced_times = [] + unvoiced_freqs = [] + + for t, freqs in zip(times, frequencies): + if not len(freqs): + continue + + freqs, voicings = freq_to_voicing(np.asarray(freqs, dtype=np.float)) + + # Discard all 0-frequency measurements + idx = freqs > 0 + freqs = freqs[idx] + voicings = voicings[idx] + + if midi: + freqs = hz_to_midi(freqs) + + n_voiced = sum(voicings) + voiced_times.extend([t] * n_voiced) + voiced_freqs.extend(freqs[voicings]) + unvoiced_times.extend([t] * (len(freqs) - n_voiced)) + unvoiced_freqs.extend(freqs[~voicings]) + + # Plot the voiced frequencies + ax.scatter(voiced_times, voiced_freqs, **style_voiced) + + # Plot the unvoiced frequencies + if unvoiced: + ax.scatter(unvoiced_times, unvoiced_freqs, **style_unvoiced) + + # Tick at integer midi notes + if midi: + ax.yaxis.set_minor_locator(MultipleLocator(1)) + return ax + + +def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): + '''Plot a quantized piano roll as intervals + + Parameters + ---------- + intervals : np.ndarray, shape=(n, 2) + timing intervals for notes + + pitches : np.ndarray, shape=(n,), optional + pitches of notes (in Hz). + + midi : np.ndarray, shape=(n,), optional + pitches of notes (in MIDI numbers). + + At least one of `pitches` or `midi` must be provided. + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the intervals. + If none is provided, a new set of axes is created. + + kwargs : + Additional keyword arguments to `labeled_intervals`. + + Returns + ------- + ax : matplotlib.pyplot.axes._subplots.AxesSubplot + A handle to the (possibly constructed) plot axes + ''' + + if midi is None: + midi = hz_to_midi(pitches) + + scale = np.arange(128) + ax = labeled_intervals(intervals, np.round(midi).astype(int), + label_set=scale, + tick=False, + ax=ax, + **kwargs) + + # Minor tick at each semitone + ax.yaxis.set_minor_locator(MultipleLocator(1)) + + ax.axis('auto') + return ax + + +def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): + '''Source-separation visualization + + Parameters + ---------- + sources : np.ndarray, shape=(nsrc, nsampl) + A list of waveform buffers corresponding to each source + + fs : number > 0 + The sampling rate + + labels : list of strings + An optional list of descriptors corresponding to each source + + alpha : float in [0, 1] + Maximum alpha (opacity) of spectrogram values. + + ax : matplotlib.pyplot.axes + An axis handle on which to draw the spectrograms. + If none is provided, a new set of axes is created. + + kwargs : + Additional keyword arguments to `scipy.signal.spectrogram` + + Returns + ------- + ax + The axis handle for this plot + ''' + + # Get the axes handle + ax, new_axes = __get_axes(ax=ax) + + # Make sure we have at least two dimensions + sources = np.atleast_2d(sources) + + if labels is None: + labels = ['Source {:d}'.format(_) for _ in range(len(sources))] + + kwargs.setdefault('scaling', 'spectrum') + + # The cumulative spectrogram across sources + # is used to establish the reference power + # for each individual source + cumspec = None + specs = [] + for i, src in enumerate(sources): + freqs, times, spec = spectrogram(src, fs=fs, **kwargs) + + specs.append(spec) + if cumspec is None: + cumspec = spec.copy() + else: + cumspec += spec + + ref_max = cumspec.max() + ref_min = ref_max * 1e-6 + + color_conv = ColorConverter() + + for i, spec in enumerate(specs): + + # For each source, grab a new color from the cycler + # Then construct a colormap that interpolates from + # [transparent white -> new color] + color = next(ax._get_lines.prop_cycler)['color'] + color = color_conv.to_rgba(color, alpha=alpha) + cmap = LinearSegmentedColormap.from_list(labels[i], + [(1.0, 1.0, 1.0, 0.0), + color]) + + ax.pcolormesh(times, freqs, spec, + cmap=cmap, + norm=LogNorm(vmin=ref_min, vmax=ref_max), + shading='gouraud', + label=labels[i]) + + # Attach a 0x0 rect to the axis with the corresponding label + # This way, it will show up in the legend + ax.add_patch(Rectangle((0, 0), 0, 0, color=color, label=labels[i])) + + if new_axes: + ax.axis('tight') + + return ax + + +def __ticker_midi_note(x, pos): + '''A ticker function for midi notes. + + Inputs x are interpreted as midi numbers, and converted + to [NOTE][OCTAVE]+[cents]. + ''' + + NOTES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] + + cents = float(np.mod(x, 1.0)) + if cents >= 0.5: + cents = cents - 1.0 + x = x + 0.5 + + idx = int(x % 12) + + octave = int(x / 12) - 1 + + if cents == 0: + return '{:s}{:2d}'.format(NOTES[idx], octave) + return '{:s}{:2d}{:+02d}'.format(NOTES[idx], octave, int(cents * 100)) + + +def __ticker_midi_hz(x, pos): + '''A ticker function for midi pitches. + + Inputs x are interpreted as midi numbers, and converted + to Hz. + ''' + + return '{:g}'.format(midi_to_hz(x)) + + +def ticker_notes(ax=None): + '''Set the y-axis of the given axes to MIDI notes + + Parameters + ---------- + ax : matplotlib.pyplot.axes + The axes handle to apply the ticker. + By default, uses the current axes handle. + + ''' + ax, _ = __get_axes(ax=ax) + + ax.yaxis.set_major_formatter(FMT_MIDI_NOTE) + # Get the tick labels and reset the vertical alignment + for tick in ax.yaxis.get_ticklabels(): + tick.set_verticalalignment('baseline') + + +def ticker_pitch(ax=None): + '''Set the y-axis of the given axes to MIDI frequencies + + Parameters + ---------- + ax : matplotlib.pyplot.axes + The axes handle to apply the ticker. + By default, uses the current axes handle. + ''' + ax, _ = __get_axes(ax=ax) + + ax.yaxis.set_major_formatter(FMT_MIDI_HZ) + + +# Instantiate ticker objects; we don't need more than one of each +FMT_MIDI_NOTE = FuncFormatter(__ticker_midi_note) +FMT_MIDI_HZ = FuncFormatter(__ticker_midi_hz) diff --git a/mir_eval/util.py b/mir_eval/util.py index 7b145f52..dacc3c63 100644 --- a/mir_eval/util.py +++ b/mir_eval/util.py @@ -891,3 +891,36 @@ def intervals_to_durations(intervals): """ validate_intervals(intervals) return np.abs(np.diff(intervals, axis=-1)).flatten() + + +def hz_to_midi(freqs): + '''Convert Hz to MIDI numbers + + Parameters + ---------- + freqs : number or ndarray + Frequency/frequencies in Hz + + Returns + ------- + midi : number or ndarray + MIDI note numbers corresponding to input frequencies. + Note that these may be fractional. + ''' + return 12.0 * (np.log2(freqs) - np.log2(440.0)) + 69.0 + + +def midi_to_hz(midi): + '''Convert MIDI numbers to Hz + + Parameters + ---------- + midi : number or ndarray + MIDI notes + + Returns + ------- + freqs : number or ndarray + Frequency/frequencies in Hz corresponding to `midi` + ''' + return 440.0 * (2.0 ** ((midi - 69.0)/12.0)) diff --git a/setup.py b/setup.py index 181ef595..55ace61d 100644 --- a/setup.py +++ b/setup.py @@ -29,4 +29,8 @@ 'future', 'six' ], + extras_require={ + 'display': ['matplotlib>=1.5.0', + 'scipy>=0.16.0'] + } ) diff --git a/tests/baseline_images/tests.test_display/events.png b/tests/baseline_images/tests.test_display/events.png new file mode 100644 index 00000000..6d30b18e Binary files /dev/null and b/tests/baseline_images/tests.test_display/events.png differ diff --git a/tests/baseline_images/tests.test_display/hierarchy_label.png b/tests/baseline_images/tests.test_display/hierarchy_label.png new file mode 100644 index 00000000..360286b6 Binary files /dev/null and b/tests/baseline_images/tests.test_display/hierarchy_label.png differ diff --git a/tests/baseline_images/tests.test_display/hierarchy_nolabel.png b/tests/baseline_images/tests.test_display/hierarchy_nolabel.png new file mode 100644 index 00000000..c2f0e213 Binary files /dev/null and b/tests/baseline_images/tests.test_display/hierarchy_nolabel.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_events.png b/tests/baseline_images/tests.test_display/labeled_events.png new file mode 100644 index 00000000..a5bdc35e Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_events.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_intervals.png b/tests/baseline_images/tests.test_display/labeled_intervals.png new file mode 100644 index 00000000..766fd351 Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_intervals.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_intervals_compare.png b/tests/baseline_images/tests.test_display/labeled_intervals_compare.png new file mode 100644 index 00000000..0d9d54be Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_intervals_compare.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_intervals_compare_common.png b/tests/baseline_images/tests.test_display/labeled_intervals_compare_common.png new file mode 100644 index 00000000..15d2e2d0 Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_intervals_compare_common.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_intervals_compare_noextend.png b/tests/baseline_images/tests.test_display/labeled_intervals_compare_noextend.png new file mode 100644 index 00000000..df2a9374 Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_intervals_compare_noextend.png differ diff --git a/tests/baseline_images/tests.test_display/labeled_intervals_noextend.png b/tests/baseline_images/tests.test_display/labeled_intervals_noextend.png new file mode 100644 index 00000000..766fd351 Binary files /dev/null and b/tests/baseline_images/tests.test_display/labeled_intervals_noextend.png differ diff --git a/tests/baseline_images/tests.test_display/multipitch_hz_unvoiced.png b/tests/baseline_images/tests.test_display/multipitch_hz_unvoiced.png new file mode 100644 index 00000000..0af2edde Binary files /dev/null and b/tests/baseline_images/tests.test_display/multipitch_hz_unvoiced.png differ diff --git a/tests/baseline_images/tests.test_display/multipitch_hz_voiced.png b/tests/baseline_images/tests.test_display/multipitch_hz_voiced.png new file mode 100644 index 00000000..0af2edde Binary files /dev/null and b/tests/baseline_images/tests.test_display/multipitch_hz_voiced.png differ diff --git a/tests/baseline_images/tests.test_display/multipitch_midi.png b/tests/baseline_images/tests.test_display/multipitch_midi.png new file mode 100644 index 00000000..6ead7319 Binary files /dev/null and b/tests/baseline_images/tests.test_display/multipitch_midi.png differ diff --git a/tests/baseline_images/tests.test_display/piano_roll.png b/tests/baseline_images/tests.test_display/piano_roll.png new file mode 100644 index 00000000..483ae6ae Binary files /dev/null and b/tests/baseline_images/tests.test_display/piano_roll.png differ diff --git a/tests/baseline_images/tests.test_display/piano_roll_midi.png b/tests/baseline_images/tests.test_display/piano_roll_midi.png new file mode 100644 index 00000000..483ae6ae Binary files /dev/null and b/tests/baseline_images/tests.test_display/piano_roll_midi.png differ diff --git a/tests/baseline_images/tests.test_display/pitch_hz.png b/tests/baseline_images/tests.test_display/pitch_hz.png new file mode 100644 index 00000000..1c97b007 Binary files /dev/null and b/tests/baseline_images/tests.test_display/pitch_hz.png differ diff --git a/tests/baseline_images/tests.test_display/pitch_midi.png b/tests/baseline_images/tests.test_display/pitch_midi.png new file mode 100644 index 00000000..fdad939b Binary files /dev/null and b/tests/baseline_images/tests.test_display/pitch_midi.png differ diff --git a/tests/baseline_images/tests.test_display/pitch_midi_hz.png b/tests/baseline_images/tests.test_display/pitch_midi_hz.png new file mode 100644 index 00000000..c96e3cab Binary files /dev/null and b/tests/baseline_images/tests.test_display/pitch_midi_hz.png differ diff --git a/tests/baseline_images/tests.test_display/segment.png b/tests/baseline_images/tests.test_display/segment.png new file mode 100644 index 00000000..bf577ed8 Binary files /dev/null and b/tests/baseline_images/tests.test_display/segment.png differ diff --git a/tests/baseline_images/tests.test_display/segment_text.png b/tests/baseline_images/tests.test_display/segment_text.png new file mode 100644 index 00000000..51c286e5 Binary files /dev/null and b/tests/baseline_images/tests.test_display/segment_text.png differ diff --git a/tests/baseline_images/tests.test_display/separation.png b/tests/baseline_images/tests.test_display/separation.png new file mode 100644 index 00000000..80c3f39e Binary files /dev/null and b/tests/baseline_images/tests.test_display/separation.png differ diff --git a/tests/baseline_images/tests.test_display/separation_label.png b/tests/baseline_images/tests.test_display/separation_label.png new file mode 100644 index 00000000..f0297cab Binary files /dev/null and b/tests/baseline_images/tests.test_display/separation_label.png differ diff --git a/tests/baseline_images/tests.test_display/ticker_midi_zoom.png b/tests/baseline_images/tests.test_display/ticker_midi_zoom.png new file mode 100644 index 00000000..29efe7dc Binary files /dev/null and b/tests/baseline_images/tests.test_display/ticker_midi_zoom.png differ diff --git a/tests/mpl_ic.py b/tests/mpl_ic.py new file mode 100644 index 00000000..17617972 --- /dev/null +++ b/tests/mpl_ic.py @@ -0,0 +1,349 @@ +# CREATED:2015-02-17 14:41:28 by Brian McFee +# this function is lifted wholesale from matploblib v1.4.2, +# and modified so that images are stored explicitly under the tests path + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import six + +import functools +import gc +import os +import sys +import shutil +import warnings +import unittest + +import nose +import numpy as np + +import matplotlib.units +from matplotlib import cbook +from matplotlib import ticker +from matplotlib import pyplot as plt +from matplotlib import ft2font +from matplotlib.testing.noseclasses import KnownFailureTest, \ + KnownFailureDidNotFailTest, ImageComparisonFailure +from matplotlib.testing.compare import comparable_formats, compare_images, \ + make_test_filename + + +def knownfailureif(fail_condition, msg=None, known_exception_class=None): + """ + + Assume a will fail if *fail_condition* is True. *fail_condition* + may also be False or the string 'indeterminate'. + + *msg* is the error message displayed for the test. + + If *known_exception_class* is not None, the failure is only known + if the exception is an instance of this class. (Default = None) + + """ + # based on numpy.testing.dec.knownfailureif + if msg is None: + msg = 'Test known to fail' + + def known_fail_decorator(f): + # Local import to avoid a hard nose dependency and only incur the + # import time overhead at actual test-time. + import nose + + def failer(*args, **kwargs): + try: + # Always run the test (to generate images). + result = f(*args, **kwargs) + except Exception as err: + if fail_condition: + if known_exception_class is not None: + if not isinstance(err, known_exception_class): + # This is not the expected exception + raise + # (Keep the next ultra-long comment so in shows in + # console.) + # An error here when running nose means that you don't have + # the matplotlib.testing.noseclasses:KnownFailure plugin in + # use. + raise KnownFailureTest(msg) + else: + raise + if fail_condition and fail_condition != 'indeterminate': + raise KnownFailureDidNotFailTest(msg) + return result + return nose.tools.make_decorator(f)(failer) + return known_fail_decorator + + +def _do_cleanup(original_units_registry): + plt.close('all') + gc.collect() + + import matplotlib.testing + matplotlib.testing.setup() + + matplotlib.units.registry.clear() + matplotlib.units.registry.update(original_units_registry) + warnings.resetwarnings() # reset any warning filters set in tests + + +class CleanupTest(object): + @classmethod + def setup_class(cls): + cls.original_units_registry = matplotlib.units.registry.copy() + + @classmethod + def teardown_class(cls): + _do_cleanup(cls.original_units_registry) + + def test(self): + self._func() + + +class CleanupTestCase(unittest.TestCase): + '''A wrapper for unittest.TestCase that includes cleanup operations''' + @classmethod + def setUpClass(cls): + import matplotlib.units + cls.original_units_registry = matplotlib.units.registry.copy() + + @classmethod + def tearDownClass(cls): + _do_cleanup(cls.original_units_registry) + + +def cleanup(func): + @functools.wraps(func) + def wrapped_function(*args, **kwargs): + original_units_registry = matplotlib.units.registry.copy() + try: + func(*args, **kwargs) + finally: + _do_cleanup(original_units_registry) + + return wrapped_function + + +def check_freetype_version(ver): + if ver is None: + return True + + from distutils import version + if isinstance(ver, six.string_types): + ver = (ver, ver) + ver = [version.StrictVersion(x) for x in ver] + found = version.StrictVersion(ft2font.__freetype_version__) + + return found >= ver[0] and found <= ver[1] + + +class ImageComparisonTest(CleanupTest): + @classmethod + def setup_class(cls): + CleanupTest.setup_class() + + cls._func() + + @staticmethod + def remove_text(figure): + figure.suptitle("") + for ax in figure.get_axes(): + ax.set_title("") + ax.xaxis.set_major_formatter(ticker.NullFormatter()) + ax.xaxis.set_minor_formatter(ticker.NullFormatter()) + ax.yaxis.set_major_formatter(ticker.NullFormatter()) + ax.yaxis.set_minor_formatter(ticker.NullFormatter()) + try: + ax.zaxis.set_major_formatter(ticker.NullFormatter()) + ax.zaxis.set_minor_formatter(ticker.NullFormatter()) + except AttributeError: + pass + + def test(self): + baseline_dir, result_dir = _image_directories(self._func) + + for fignum, baseline in zip(plt.get_fignums(), self._baseline_images): + for extension in self._extensions: + will_fail = extension not in comparable_formats() + if will_fail: + fail_msg = ('Cannot compare %s files on this system' % + extension) + else: + fail_msg = 'No failure expected' + + orig_expected_fname = ( + os.path.join(baseline_dir, baseline) + '.' + extension) + if (extension == 'eps' and + not os.path.exists(orig_expected_fname)): + orig_expected_fname = ( + os.path.join(baseline_dir, baseline) + '.pdf') + expected_fname = make_test_filename(os.path.join( + result_dir, + os.path.basename(orig_expected_fname)), 'expected') + actual_fname = ( + os.path.join(result_dir, baseline) + '.' + extension) + if os.path.exists(orig_expected_fname): + shutil.copyfile(orig_expected_fname, expected_fname) + else: + will_fail = True + fail_msg = 'Do not have baseline image %s' % expected_fname + + @knownfailureif( + will_fail, fail_msg, + known_exception_class=ImageComparisonFailure) + def do_test(): + figure = plt.figure(fignum) + + if self._remove_text: + self.remove_text(figure) + + figure.savefig(actual_fname, **self._savefig_kwarg) + + plt.close(figure) + + err = compare_images(expected_fname, actual_fname, + self._tol, in_decorator=True) + + try: + if not os.path.exists(expected_fname): + raise ImageComparisonFailure( + 'image does not exist: %s' % expected_fname) + + if err: + raise ImageComparisonFailure( + 'images not close: %(actual)s vs. %(expected)s' + ' (RMS %(rms).3f)' % err) + except ImageComparisonFailure: + if not check_freetype_version(self._freetype_version): + raise KnownFailureTest( + "Mismatched version of freetype. Test " + "requires '%s', you have '%s'" % + (self._freetype_version, + ft2font.__freetype_version__)) + raise + + yield (do_test,) + + +def image_comparison(baseline_images=None, extensions=None, tol=13, + freetype_version=None, remove_text=False, + savefig_kwarg=None): + """ + call signature:: + + image_comparison(baseline_images=['my_figure'], extensions=None) + + Compare images generated by the test with those specified in + *baseline_images*, which must correspond else an + ImageComparisonFailure exception will be raised. + + Keyword arguments: + + *baseline_images*: list + A list of strings specifying the names of the images generated + by calls to :meth:`matplotlib.figure.savefig`. + + *extensions*: [ None | list ] + + If *None*, default to all supported extensions. + + Otherwise, a list of extensions to test. For example ['png','pdf']. + + *tol*: (default 13) + The RMS threshold above which the test is considered failed. + + *freetype_version*: str or tuple + The expected freetype version or range of versions for this + test to pass. + + *remove_text*: bool + Remove the title and tick text from the figure before + comparison. This does not remove other, more deliberate, + text, such as legends and annotations. + + *savefig_kwarg*: dict + Optional arguments that are passed to the savefig method. + + """ + + if baseline_images is None: + raise ValueError('baseline_images must be specified') + + if extensions is None: + # default extensions to test + extensions = ['png', 'pdf', 'svg'] + + if savefig_kwarg is None: + # default no kwargs to savefig + savefig_kwarg = dict() + + def compare_images_decorator(func): + # We want to run the setup function (the actual test function + # that generates the figure objects) only once for each type + # of output file. The only way to achieve this with nose + # appears to be to create a test class with "setup_class" and + # "teardown_class" methods. Creating a class instance doesn't + # work, so we use type() to actually create a class and fill + # it with the appropriate methods. + name = func.__name__ + # For nose 1.0, we need to rename the test function to + # something without the word "test", or it will be run as + # well, outside of the context of our image comparison test + # generator. + func = staticmethod(func) + func.__get__(1).__name__ = str('_private') + new_class = type( + name, + (ImageComparisonTest,), + {'_func': func, + '_baseline_images': baseline_images, + '_extensions': extensions, + '_tol': tol, + '_freetype_version': freetype_version, + '_remove_text': remove_text, + '_savefig_kwarg': savefig_kwarg}) + + return new_class + return compare_images_decorator + + +def _image_directories(func): + """ + Compute the baseline and result image directories for testing *func*. + Create the result directory if it doesn't exist. + """ + module_name = func.__module__ + # mods = module_name.split('.') + # mods.pop(0) # <- will be the name of the package being tested (in + # most cases "matplotlib") + # assert mods.pop(0) == 'tests' + # subdir = os.path.join(*mods) + subdir = module_name + + import imp + + def find_dotted_module(module_name, path=None): + """A version of imp which can handle dots in the module name""" + res = None + for sub_mod in module_name.split('.'): + try: + res = file, path, _ = imp.find_module(sub_mod, path) + path = [path] + if file is not None: + file.close() + except ImportError: + # assume namespace package + path = sys.modules[sub_mod].__path__ + res = None, path, None + return res + + mod_file = find_dotted_module(func.__module__)[1] + basedir = os.path.dirname(mod_file) + + baseline_dir = os.path.join(basedir, 'baseline_images', subdir) + result_dir = os.path.abspath(os.path.join('result_images', subdir)) + + if not os.path.exists(result_dir): + cbook.mkdirs(result_dir) + + return baseline_dir, result_dir diff --git a/tests/test_display.py b/tests/test_display.py new file mode 100644 index 00000000..76d189ae --- /dev/null +++ b/tests/test_display.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +'''Unit tests for the display module''' + +# For testing purposes, clobber the rcfile +import matplotlib + +matplotlib.use('Agg') # nopep8 + +import matplotlib.pyplot as plt +from matplotlib.style import context as style_context +import numpy as np + +# Import the hacked image comparison module +from tests.mpl_ic import image_comparison + +# We'll make a decorator to handle style contexts +from decorator import decorator + +import mir_eval +import mir_eval.display +from mir_eval.io import load_labeled_intervals +from mir_eval.io import load_valued_intervals +from mir_eval.io import load_labeled_events +from mir_eval.io import load_ragged_time_series +from mir_eval.io import load_wav + + +# Test fixtures will use seaborn-muted style +GLOBAL_STYLE = 'seaborn-muted' + + +@decorator +def styled(f, *args, **kwargs): + matplotlib.rcParams.update(matplotlib.rcParamsDefault) + with style_context(GLOBAL_STYLE): + return f(*args, **kwargs) + + +@image_comparison(baseline_images=['segment'], extensions=['png']) +@styled +def test_display_segment(): + + plt.figure() + + # Load some segment data + intervals, labels = load_labeled_intervals('tests/data/segment/ref00.lab') + + # Plot the segments with no labels + mir_eval.display.segments(intervals, labels, text=False) + + # Draw a legend + plt.legend() + + +@image_comparison(baseline_images=['segment_text'], extensions=['png']) +@styled +def test_display_segment_text(): + plt.figure() + + # Load some segment data + intervals, labels = load_labeled_intervals('tests/data/segment/ref00.lab') + + # Plot the segments with no labels + mir_eval.display.segments(intervals, labels, text=True) + + +@image_comparison(baseline_images=['labeled_intervals'], extensions=['png']) +@styled +def test_display_labeled_intervals(): + + plt.figure() + + # Load some chord data + intervals, labels = load_labeled_intervals('tests/data/chord/ref01.lab') + + # Plot the chords with nothing fancy + mir_eval.display.labeled_intervals(intervals, labels) + + +@image_comparison(baseline_images=['labeled_intervals_noextend'], + extensions=['png']) +@styled +def test_display_labeled_intervals_noextend(): + + plt.figure() + + # Load some chord data + intervals, labels = load_labeled_intervals('tests/data/chord/ref01.lab') + + # Plot the chords with nothing fancy + ax = plt.axes() + ax.set_yticklabels([]) + mir_eval.display.labeled_intervals(intervals, labels, + label_set=[], + extend_labels=False, + ax=ax) + + +@image_comparison(baseline_images=['labeled_intervals_compare'], + extensions=['png']) +@styled +def test_display_labeled_intervals_compare(): + + plt.figure() + + # Load some chord data + ref_int, ref_labels = load_labeled_intervals('tests/data/chord/ref01.lab') + est_int, est_labels = load_labeled_intervals('tests/data/chord/est01.lab') + + # Plot reference and estimates using label set extension + mir_eval.display.labeled_intervals(ref_int, ref_labels, + alpha=0.5, label='Reference') + mir_eval.display.labeled_intervals(est_int, est_labels, + alpha=0.5, label='Estimate') + + plt.legend() + + +@image_comparison(baseline_images=['labeled_intervals_compare_noextend'], + extensions=['png']) +@styled +def test_display_labeled_intervals_compare_noextend(): + + plt.figure() + + # Load some chord data + ref_int, ref_labels = load_labeled_intervals('tests/data/chord/ref01.lab') + est_int, est_labels = load_labeled_intervals('tests/data/chord/est01.lab') + + # Plot reference and estimate, but only use the reference labels + mir_eval.display.labeled_intervals(ref_int, ref_labels, + alpha=0.5, label='Reference') + mir_eval.display.labeled_intervals(est_int, est_labels, + extend_labels=False, + alpha=0.5, label='Estimate') + + plt.legend() + + +@image_comparison(baseline_images=['labeled_intervals_compare_common'], + extensions=['png']) +@styled +def test_display_labeled_intervals_compare_common(): + + plt.figure() + + # Load some chord data + ref_int, ref_labels = load_labeled_intervals('tests/data/chord/ref01.lab') + est_int, est_labels = load_labeled_intervals('tests/data/chord/est01.lab') + + label_set = list(sorted(set(ref_labels) | set(est_labels))) + + # Plot reference and estimate with a common label set + mir_eval.display.labeled_intervals(ref_int, ref_labels, + label_set=label_set, + alpha=0.5, label='Reference') + mir_eval.display.labeled_intervals(est_int, est_labels, + label_set=label_set, + alpha=0.5, label='Estimate') + + plt.legend() + + +@image_comparison(baseline_images=['hierarchy_nolabel'], extensions=['png']) +@styled +def test_display_hierarchy_nolabel(): + + plt.figure() + + # Load some chord data + int0, lab0 = load_labeled_intervals('tests/data/hierarchy/ref00.lab') + int1, lab1 = load_labeled_intervals('tests/data/hierarchy/ref01.lab') + + # Plot reference and estimate with a common label set + mir_eval.display.hierarchy([int0, int1], + [lab0, lab1]) + + plt.legend() + + +@image_comparison(baseline_images=['hierarchy_label'], extensions=['png']) +@styled +def test_display_hierarchy_label(): + + plt.figure() + + # Load some chord data + int0, lab0 = load_labeled_intervals('tests/data/hierarchy/ref00.lab') + int1, lab1 = load_labeled_intervals('tests/data/hierarchy/ref01.lab') + + # Plot reference and estimate with a common label set + mir_eval.display.hierarchy([int0, int1], + [lab0, lab1], + levels=['Large', 'Small']) + + plt.legend() + + +@image_comparison(baseline_images=['pitch_hz'], extensions=['png']) +@styled +def test_pitch_hz(): + plt.figure() + + ref_times, ref_freqs = load_labeled_events('tests/data/melody/ref00.txt') + est_times, est_freqs = load_labeled_events('tests/data/melody/est00.txt') + + # Plot pitches on a Hz scale + mir_eval.display.pitch(ref_times, ref_freqs, unvoiced=True, + label='Reference') + mir_eval.display.pitch(est_times, est_freqs, unvoiced=True, + label='Estimate') + plt.legend() + + +@image_comparison(baseline_images=['pitch_midi'], extensions=['png']) +@styled +def test_pitch_midi(): + plt.figure() + + times, freqs = load_labeled_events('tests/data/melody/ref00.txt') + + # Plot pitches on a midi scale with note tickers + mir_eval.display.pitch(times, freqs, midi=True) + mir_eval.display.ticker_notes() + + +@image_comparison(baseline_images=['pitch_midi_hz'], extensions=['png']) +@styled +def test_pitch_midi_hz(): + plt.figure() + + times, freqs = load_labeled_events('tests/data/melody/ref00.txt') + + # Plot pitches on a midi scale with note tickers + mir_eval.display.pitch(times, freqs, midi=True) + mir_eval.display.ticker_pitch() + + +@image_comparison(baseline_images=['multipitch_hz_unvoiced'], + extensions=['png']) +@styled +def test_multipitch_hz_unvoiced(): + plt.figure() + + times, pitches = load_ragged_time_series('tests/data/multipitch/est01.txt') + + # Plot pitches on a midi scale with note tickers + mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=True) + + +@image_comparison(baseline_images=['multipitch_hz_voiced'], extensions=['png']) +@styled +def test_multipitch_hz_voiced(): + plt.figure() + + times, pitches = load_ragged_time_series('tests/data/multipitch/est01.txt') + + mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=False) + + +@image_comparison(baseline_images=['multipitch_midi'], extensions=['png']) +@styled +def test_multipitch_midi(): + plt.figure() + + ref_t, ref_p = load_ragged_time_series('tests/data/multipitch/ref01.txt') + est_t, est_p = load_ragged_time_series('tests/data/multipitch/est01.txt') + + # Plot pitches on a midi scale with note tickers + mir_eval.display.multipitch(ref_t, ref_p, midi=True, + alpha=0.5, label='Reference') + mir_eval.display.multipitch(est_t, est_p, midi=True, + alpha=0.5, label='Estimate') + + plt.legend() + + +@image_comparison(baseline_images=['piano_roll'], extensions=['png']) +@styled +def test_pianoroll(): + plt.figure() + + ref_t, ref_p = load_valued_intervals('tests/data/transcription/ref04.txt') + est_t, est_p = load_valued_intervals('tests/data/transcription/est04.txt') + + mir_eval.display.piano_roll(ref_t, ref_p, + label='Reference', alpha=0.5) + mir_eval.display.piano_roll(est_t, est_p, + label='Estimate', alpha=0.5, facecolor='r') + + plt.legend() + + +@image_comparison(baseline_images=['piano_roll_midi'], extensions=['png']) +@styled +def test_pianoroll_midi(): + plt.figure() + + ref_t, ref_p = load_valued_intervals('tests/data/transcription/ref04.txt') + est_t, est_p = load_valued_intervals('tests/data/transcription/est04.txt') + + ref_midi = mir_eval.util.hz_to_midi(ref_p) + est_midi = mir_eval.util.hz_to_midi(est_p) + mir_eval.display.piano_roll(ref_t, midi=ref_midi, + label='Reference', alpha=0.5) + mir_eval.display.piano_roll(est_t, midi=est_midi, + label='Estimate', alpha=0.5, facecolor='r') + + plt.legend() + + +@image_comparison(baseline_images=['ticker_midi_zoom'], extensions=['png']) +@styled +def test_ticker_midi_zoom(): + + plt.figure() + + plt.plot(np.arange(3)) + mir_eval.display.ticker_notes() + + +@image_comparison(baseline_images=['separation'], extensions=['png']) +@styled +def test_separation(): + plt.figure() + + x0, fs = load_wav('tests/data/separation/ref05/0.wav') + x1, fs = load_wav('tests/data/separation/ref05/1.wav') + x2, fs = load_wav('tests/data/separation/ref05/2.wav') + + mir_eval.display.separation([x0, x1, x2], fs=fs) + + +@image_comparison(baseline_images=['separation_label'], extensions=['png']) +@styled +def test_separation_label(): + plt.figure() + + x0, fs = load_wav('tests/data/separation/ref05/0.wav') + x1, fs = load_wav('tests/data/separation/ref05/1.wav') + x2, fs = load_wav('tests/data/separation/ref05/2.wav') + + mir_eval.display.separation([x0, x1, x2], fs=fs, + labels=['Alice', 'Bob', 'Carol']) + + plt.legend() + + +@image_comparison(baseline_images=['events'], extensions=['png']) +@styled +def test_events(): + plt.figure() + + # Load some event data + beats_ref = mir_eval.io.load_events('tests/data/beat/ref00.txt')[:30] + beats_est = mir_eval.io.load_events('tests/data/beat/est00.txt')[:30] + + # Plot both with labels + mir_eval.display.events(beats_ref, label='reference') + mir_eval.display.events(beats_est, label='estimate') + plt.legend() + + +@image_comparison(baseline_images=['labeled_events'], extensions=['png']) +@styled +def test_labeled_events(): + plt.figure() + + # Load some event data + beats_ref = mir_eval.io.load_events('tests/data/beat/ref00.txt')[:10] + + labels = list('abcdefghijklmnop') + # Plot both with labels + mir_eval.display.events(beats_ref, labels)