Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Plot style managment #176

Merged
merged 9 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
gen_issue_str, gen_width_warning_str)

from fooof.plts.fm import plot_fm
from fooof.plts.style import style_spectrum_plot
from fooof.utils.data import trim_spectrum
from fooof.utils.params import compute_gauss_std
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
Expand Down Expand Up @@ -617,12 +616,13 @@ def get_results(self):
@copy_doc_func_to_method(plot_fm)
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
add_legend=True, save_fig=False, file_name=None, file_path=None,
ax=None, plot_style=style_spectrum_plot,
data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
ax=None, data_kwargs=None, model_kwargs=None,
aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):

plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend,
save_fig, file_name, file_path, ax, plot_style,
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs)
plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
add_legend=add_legend, save_fig=save_fig, file_name=file_name,
file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs)


@copy_doc_func_to_method(save_report_fm)
Expand Down
4 changes: 2 additions & 2 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,9 @@ def get_params(self, name, col=None):


@copy_doc_func_to_method(plot_fg)
def plot(self, save_fig=False, file_name=None, file_path=None):
def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs):

plot_fg(self, save_fig, file_name, file_path)
plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs)


@copy_doc_func_to_method(save_report_fg)
Expand Down
41 changes: 18 additions & 23 deletions fooof/plts/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from fooof.core.funcs import gaussian_function
from fooof.core.modutils import safe_import, check_dependency
from fooof.sim.gen import gen_aperiodic
from fooof.plts.utils import check_ax
from fooof.plts.utils import check_ax, savefig
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
from fooof.plts.style import check_n_style, style_spectrum_plot
from fooof.plts.style import style_spectrum_plot
from fooof.analysis.periodic import get_band_peak_fm
from fooof.utils.params import compute_knee_frequency, compute_fwhm

Expand All @@ -20,16 +20,15 @@
###################################################################################################
###################################################################################################

@savefig
@check_dependency(plt, 'matplotlib')
def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
def plot_annotated_peak_search(fm):
"""Plot a series of plots illustrating the peak search from a flattened spectrum.

Parameters
----------
fm : FOOOF
FOOOF object, with model fit, data and settings available.
plot_style : callable, optional, default: style_spectrum_plot
A function to call to apply styling & aesthetics to the plots.
"""

# Recalculate the initial aperiodic fit and flattened spectrum that
Expand All @@ -46,14 +45,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
# This forces the creation of a new plotting axes per iteration
ax = check_ax(None, PLT_FIGSIZES['spectral'])

plot_spectrum(fm.freqs, flatspec, ax=ax, plot_style=None,
label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5)
plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs),
ax=ax, plot_style=None, label='Relative Threshold',
color='orange', linewidth=2.5, linestyle='dashed')
plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs),
ax=ax, plot_style=None, label='Absolute Threshold',
color='red', linewidth=2.5, linestyle='dashed')
plot_spectrum(fm.freqs, flatspec, ax=ax, linewidth=2.5,
label='Flattened Spectrum', color=PLT_COLORS['data'])
plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax,
label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed')
plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax,
label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed')

maxi = np.argmax(flatspec)
ax.plot(fm.freqs[maxi], flatspec[maxi], '.',
Expand All @@ -65,18 +62,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
if ind < fm.n_peaks_:

gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
plot_spectrum(fm.freqs, gauss, ax=ax, plot_style=None,
label='Gaussian Fit', color=PLT_COLORS['periodic'],
linestyle=':', linewidth=3.0)
plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0)

flatspec = flatspec - gauss

check_n_style(plot_style, ax, False, True)
style_spectrum_plot(ax, False, True)


@savefig
@check_dependency(plt, 'matplotlib')
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True,
ax=None, plot_style=style_spectrum_plot):
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True,
annotate_aperiodic=True, ax=None):
"""Plot a an annotated power spectrum and model, from a FOOOF object.

Parameters
Expand All @@ -87,8 +84,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
Whether to plot the frequency values in log10 spacing.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
plot_style : callable, optional, default: style_spectrum_plot
A function to call to apply styling & aesthetics to the plots.

Raises
------
Expand All @@ -108,7 +103,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio

# Create the baseline figure
ax = check_ax(ax, PLT_FIGSIZES['spectral'])
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None,
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax,
data_kwargs={'lw' : lw1, 'alpha' : 0.6},
aperiodic_kwargs={'lw' : lw1, 'zorder' : 10},
model_kwargs={'lw' : lw1, 'alpha' : 0.5},
Expand Down Expand Up @@ -215,7 +210,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
color=PLT_COLORS['aperiodic'], fontsize=fontsize)

# Apply style to plot & tune grid styling
check_n_style(plot_style, ax, plt_log, True)
style_spectrum_plot(ax, plt_log, True)
ax.grid(True, alpha=0.5)

# Add labels to plot in the legend
Expand Down
45 changes: 20 additions & 25 deletions fooof/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
from itertools import cycle

import numpy as np
import matplotlib.pyplot as plt

from fooof.sim.gen import gen_freqs, gen_aperiodic
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
from fooof.plts.style import check_n_style, style_param_plot
from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs
from fooof.plts.style import style_param_plot, style_plot
from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs

plt = safe_import('.pyplot', 'matplotlib')

###################################################################################################
###################################################################################################

@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_params(aps, colors=None, labels=None,
ax=None, plot_style=style_param_plot, **plot_kwargs):
def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot aperiodic parameters as dots representing offset and exponent value.

Parameters
Expand All @@ -30,38 +32,38 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
plot_style : callable, optional, default: style_param_plot
A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
Keyword arguments to pass into the plot call.
Keyword arguments to pass into the ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

if isinstance(aps, list):
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels,
plot_style=plot_style, **plot_kwargs)
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels)

else:

# Unpack data: offset as x; exponent as y
xs, ys = aps[:, 0], aps[:, -1]
sizes = plot_kwargs.pop('s', 150)

# Create the plot
plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)

# Add axis labels
ax.set_xlabel('Offset')
ax.set_ylabel('Exponent')

check_n_style(plot_style, ax)
style_param_plot(ax)


@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
log_freqs=False, colors=None, labels=None,
ax=None, plot_style=style_param_plot, **plot_kwargs):
ax=None, **plot_kwargs):
"""Plot reconstructions of model aperiodic fits.

Parameters
Expand All @@ -80,10 +82,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
plot_style : callable, optional, default: style_param_plot
A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
Keyword arguments to pass into the plot call.
Keyword arguments to pass into the ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
Expand All @@ -93,11 +93,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
if not colors:
colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax,
freq_range=tuple(freq_range),
control_offset=control_offset,
log_freqs=log_freqs, colors=colors, labels=labels,
plot_style=plot_style, **plot_kwargs)
recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range),
control_offset=control_offset, log_freqs=log_freqs, colors=colors,
labels=labels, **plot_kwargs)
else:

freqs = gen_freqs(freq_range, 0.1)
Expand All @@ -118,17 +116,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
# Recreate & plot the aperiodic component from parameters
ap_vals = gen_aperiodic(freqs, ap_params)

plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25})
ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs)
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)

# Collect a running average across components
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)

# Plot the average component
avg = avg_vals / aps.shape[0]
avg_color = 'black' if not colors else colors
ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3,
color=avg_color, label=labels)
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)

# Add axis labels
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
Expand All @@ -137,5 +133,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
# Set plot limit
ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range)

# Apply plot style
check_n_style(plot_style, ax)
style_param_plot(ax)
17 changes: 8 additions & 9 deletions fooof/plts/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES
from fooof.plts.style import check_n_style, style_spectrum_plot
from fooof.plts.utils import check_ax
from fooof.plts.style import style_spectrum_plot, style_plot
from fooof.plts.utils import check_ax, savefig

plt = safe_import('.pyplot', 'matplotlib')

###################################################################################################
###################################################################################################

@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **plot_kwargs):
"""Plot frequency by frequency error values.

Parameters
Expand All @@ -31,17 +32,15 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
Whether to plot the frequency axis in log spacing.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
plot_style : callable, optional, default: style_spectrum_plot
A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
Keyword arguments to be passed to `plot_spectra` or to the plot call.
Keyword arguments to pass into the ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))

plt_freqs = np.log10(freqs) if log_freqs else freqs

plot_spectrum(plt_freqs, error, plot_style=None, ax=ax, linewidth=3, **plot_kwargs)
plot_spectrum(plt_freqs, error, ax=ax, linewidth=3)

if np.any(shade):
ax.fill_between(plt_freqs, error-shade, error+shade, alpha=0.25)
Expand All @@ -51,5 +50,5 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
ax.set_ylim([0, ymax])
ax.set_xlim(plt_freqs.min(), plt_freqs.max())

check_n_style(plot_style, ax, log_freqs, True)
style_spectrum_plot(ax, log_freqs, True)
ax.set_ylabel('Absolute Error')
Loading