Skip to content

Commit

Permalink
Split, simplify, and handle warnings (#66)
Browse files Browse the repository at this point in the history
- Split translate.py into basic translations, staying in translate.py, and more complicated tools in tools.py
- Add helper methods to Axes and AxesImage classes to simplify the main methods; e.g. method for interpreting data keyword to scatter()
- Avoid some warnings and deal with new deprecation warnings in recent pyqtgraph versions
- Simplify several functions in different files
- Expand test for dealias() and fix a typo in its keyword/alias dictionary
- kwargs are passed through Figure.gca() to add_subplot
- Add test for detecting and handling axes that have been deleted by Qt
- Deduplicate ugly try/except for removing stuff from Qt layout
  • Loading branch information
eldond authored Oct 6, 2021
1 parent c61cf3f commit 924cc78
Show file tree
Hide file tree
Showing 10 changed files with 395 additions and 283 deletions.
146 changes: 92 additions & 54 deletions pgmpl/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
# pgmpl
# noinspection PyUnresolvedReferences
import pgmpl.__init__ # __init__ does setup stuff like making sure a QApp exists
from pgmpl.translate import plotkw_translator, color_translator, setup_pen_kw, color_map_translator, dealias
from pgmpl.translate import color_translator, color_map_translator
from pgmpl.tools import setup_pen_kw, plotkw_translator
from pgmpl.legend import Legend
from pgmpl.util import printd, tolist, is_numeric
from pgmpl.tools import dealias
from pgmpl.text import Text
from pgmpl.contour import QuadContourSet

Expand Down Expand Up @@ -146,6 +148,34 @@ def _make_custom_verts(verts):
Symbols[key] = pg.arrayToQPath(verts_x, verts_y, connect='all')
return key

@staticmethod
def _interpret_xy_scatter_data(*args, **kwargs):
x, y = (list(args) + [None] * (2 - len(args)))[:3]
data = kwargs.pop('data', None)
if data is not None:
x = data.get('x')
y = data.get('y')
kwargs['s'] = data.get('s', None)
kwargs['c'] = data.get('c', None)
kwargs['edgecolors'] = data.get('edgecolors', None)
kwargs['linewidths'] = data.get('linewidths', None)
# The following keywords are apparently valid within `data`,
# but they'd conflict with `c`, so they've been neglected: color facecolor facecolors
return x, y, kwargs

@staticmethod
def _setup_scatter_symbol_pen(brush_edges, linewidths):
"""Sets up the pen for drawing symbols on scatter plot"""
sympen_kw = [{'color': cc} for cc in brush_edges]
if linewidths is not None:
n = len(brush_edges)
if (len(tolist(linewidths)) == 1) and (n > 1):
# Make list of lw the same length as x for cases where only one setting value was provided
linewidths = tolist(linewidths) * n
for i in range(n):
sympen_kw[i]['width'] = linewidths[i]
return [pg.mkPen(**spkw) for spkw in sympen_kw]

def scatter(self, x=None, y=None, **kwargs):
"""
Translates arguments and keywords for matplotlib.axes.Axes.scatter() method so they can be passed to pyqtgraph.
Expand Down Expand Up @@ -198,41 +228,24 @@ def scatter(self, x=None, y=None, **kwargs):
:return: plotItem instance created by plot()
"""
data = kwargs.pop('data', None)
linewidths = kwargs.pop('linewidths', None)
if data is not None:
x = data.get('x')
y = data.get('y')
kwargs['s'] = data.get('s', None)
kwargs['c'] = data.get('c', None)
kwargs['edgecolors'] = data.get('edgecolors', None)
linewidths = data.get('linewidths', None)
# The following keywords are apparently valid within `data`,
# but they'd conflict with `c`, so they've been neglected: color facecolor facecolors
x, y, kwargs = self._interpret_xy_scatter_data(x, y, **kwargs)
n = len(x)

linewidths = kwargs.pop('linewidths', None)
brush_colors, brush_edges = self._prep_scatter_colors(n, **kwargs)

for popit in ['cmap', 'norm', 'vmin', 'vmax', 'alpha', 'edgecolors', 'c']:
kwargs.pop(popit, None) # Make sure all the color keywords are gone now that they've been used.

# Make the lists of symbol settings the same length as x for cases where only one setting value was provided
if linewidths is not None and (len(tolist(linewidths)) == 1) and (n > 1):
linewidths = tolist(linewidths) * n

# Catch & translate other keywords
kwargs['markersize'] = kwargs.pop('s', 10)
kwargs.setdefault('marker', 'o')
plotkw = plotkw_translator(**kwargs)

# Fill in keywords we already prepared
sympen_kw = [{'color': cc} for cc in brush_edges]
if linewidths is not None:
for i in range(n):
sympen_kw[i]['width'] = linewidths[i]
plotkw['symbolPen'] = self._setup_scatter_symbol_pen(brush_edges, linewidths)
plotkw['pen'] = None
plotkw['symbolBrush'] = [pg.mkBrush(color=cc) for cc in brush_colors]
plotkw['symbolPen'] = [pg.mkPen(**spkw) for spkw in sympen_kw]

plotkw['symbol'] = plotkw.get('symbol', None) or self._make_custom_verts(kwargs.pop('verts', None))
return super(Axes, self).plot(x=x, y=y, **plotkw)

Expand Down Expand Up @@ -430,7 +443,7 @@ def _draw_errbar_caps(self, x, y, **capkw):
self._errbar_ycap_mark(x, y, yerr, **capkw)

@staticmethod
def _sanitize_errbar_data(x, y=None, xerr=None, yerr=None, mask=None):
def _sanitize_errbar_data(*args, mask=None):
"""
Helper function for errorbar. Does not map to a matplotlib method.
Expand All @@ -448,6 +461,9 @@ def _sanitize_errbar_data(x, y=None, xerr=None, yerr=None, mask=None):
:return: tuple of sanitized x, y, xerr, yerr
"""

x, y, xerr, yerr = (list(args) + [None] * (4 - len(args)))[:5]


def prep(v):
"""
Prepares a value so it has the appropriate dimensions with proper filtering to respect errorevery keyword
Expand All @@ -468,32 +484,48 @@ def prep(v):

return prep(x), prep(y), prep(xerr), prep(yerr)

def errorbar(self, x=None, y=None, yerr=None, xerr=None, **kwargs):
def _interpret_xy_errorbar_data(self, *args, **kwargs):
"""
Imitates matplotlib.axes.Axes.errorbar
https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.errorbar.html
Interprets x, and y arguments and xerr, yerr, and data keywords
:return: pyqtgraph.ErrorBarItem instance
Does not include the line through nominal values as would be included in matplotlib's errorbar; this is
drawn, but it is a separate object.
:return: tuple containing
x, y, xerr, yerr
"""
kwargs = dealias(**kwargs)
data = kwargs.pop('data', None)
x, y, xerr, yerr = (list(args) + [None] * (4 - len(args)))[:5]

if data is not None:
x = data.get('x', None)
y = data.get('y', None)
xerr = data.get('xerr', None)
yerr = data.get('yerr', None)
return x, y, xerr, yerr

# Separate keywords into those that affect a line through the data and those that affect the errorbars
def _process_errorbar_keywords(self, **kwargs):
"""Separate keywords affecting error bars from those affecting nominal values & translate to pyqtgraph"""
ekwargs = copy.deepcopy(kwargs)
if kwargs.get('ecolor', None) is not None:
ekwargs['color'] = kwargs.pop('ecolor')
if kwargs.get('elinewidth', None) is not None:
ekwargs['linewidth'] = kwargs.pop('elinewidth')
epgkw = plotkw_translator(**ekwargs)
return epgkw

def errorbar(self, x=None, y=None, yerr=None, xerr=None, **kwargs):
"""
Imitates matplotlib.axes.Axes.errorbar
https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.errorbar.html
:return: pyqtgraph.ErrorBarItem instance
Does not include the line through nominal values as would be included in matplotlib's errorbar; this is
drawn, but it is a separate object.
"""
kwargs = dealias(**kwargs)
x, y, xerr, yerr = self._interpret_xy_errorbar_data(x, y, xerr, yerr, **kwargs)

# Separate keywords into those that affect a line through the data and those that affect the errorbars
epgkw = self._process_errorbar_keywords(**kwargs)
w = np.array([True if i % int(round(kwargs.pop('errorevery', 1))) == 0 else False
for i in range(len(np.atleast_1d(x)))])

Expand All @@ -502,7 +534,7 @@ def errorbar(self, x=None, y=None, yerr=None, xerr=None, **kwargs):
self.plot(x, y, **kwargs)

# Draw the errorbars
xp, yp, xerrp, yerrp = self._sanitize_errbar_data(x, y, xerr, yerr, w)
xp, yp, xerrp, yerrp = self._sanitize_errbar_data(x, y, xerr, yerr, mask=w)

errb = pg.ErrorBarItem(
x=xp, y=yp, height=0 if yerr is None else yerrp*2, width=0 if xerr is None else xerrp*2, **epgkw
Expand Down Expand Up @@ -707,41 +739,47 @@ def __init__(self, x=None, **kwargs):
self.cmap = kwargs.pop('cmap', None)
self.norm = kwargs.pop('norm', None)
self.alpha = kwargs.pop('alpha', None)
vmin = kwargs.pop('vmin', None)
vmax = kwargs.pop('vmax', None)
origin = kwargs.pop('origin', None)

if data is not None:
x = data['x']
if len(data.keys()) > 1:
warnings.warn('Axes.imshow does not extract keywords from data yet (just x).')

xs = copy.copy(x)

self.vmin = kwargs.pop('vmin', x.min())
self.vmax = kwargs.pop('vmax', x.max())
self.check_inputs(**kwargs)
self._set_up_imange_extent(x=copy.copy(x), **kwargs)

def _set_up_imange_extent(self, x, **kwargs):
"""
Handles setup of image extent, translate, and scale
"""
origin = kwargs.pop('origin', None)

if origin in ['upper', None]:
xs = xs[::-1]
x = x[::-1]
extent = kwargs.pop('extent', None) or (-0.5, x.shape[1]-0.5, -(x.shape[0]-0.5), -(0-0.5))
else:
extent = kwargs.pop('extent', None) or (-0.5, x.shape[1]-0.5, -0.5, x.shape[0]-0.5)

if len(np.shape(xs)) == 3:
xs = np.transpose(xs, (2, 0, 1))
if len(np.shape(x)) == 3:
x = np.transpose(x, (2, 0, 1))
else:
xs = np.array(color_map_translator(
xs.flatten(), cmap=self.cmap, norm=self.norm, vmin=vmin, vmax=vmax, clip=kwargs.pop('clip', False),
ncol=kwargs.pop('N', 256), alpha=self.alpha,
)).T.reshape([4] + tolist(xs.shape))

super(AxesImage, self).__init__(np.transpose(xs))
if extent is not None:
self.resetTransform()
self.translate(extent[0], extent[2])
self.scale((extent[1] - extent[0]) / self.width(), (extent[3] - extent[2]) / self.height())

self.vmin = vmin or x.min()
self.vmax = vmax or x.max()
x = np.array(color_map_translator(
x.flatten(),
cmap=self.cmap,
norm=self.norm,
vmin=self.vmin,
vmax=self.vmax,
clip=kwargs.pop('clip', False),
ncol=kwargs.pop('N', 256),
alpha=self.alpha,
)).T.reshape([4] + tolist(x.shape))

super(AxesImage, self).__init__(np.transpose(x))
self.resetTransform()
self.translate(extent[0], extent[2])
self.scale(int(round((extent[1] - extent[0]) / self.width())), int(round((extent[3] - extent[2]) / self.height())))

@staticmethod
def check_inputs(**kw):
Expand Down
5 changes: 3 additions & 2 deletions pgmpl/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
# pgmpl
# noinspection PyUnresolvedReferences
import pgmpl.__init__ # __init__ does setup stuff like making sure a QApp exists
from pgmpl.translate import setup_pen_kw, color_map_translator
from pgmpl.translate import color_map_translator
from pgmpl.tools import setup_pen_kw
from pgmpl.util import printd, tolist


Expand Down Expand Up @@ -120,7 +121,7 @@ def draw_unfilled(self):
x0, y0, x1, y1 = self.x.min(), self.y.min(), self.x.max(), self.y.max()
for contour in contours:
contour.translate(x0, y0) # https://stackoverflow.com/a/51109935/6605826
contour.scale((x1 - x0) / np.shape(self.z)[0], (y1 - y0) / np.shape(self.z)[1])
contour.scale(int(round((x1 - x0) / np.shape(self.z)[0])), int(round((y1 - y0) / np.shape(self.z)[1])))
self.ax.addItem(contour)


Expand Down
46 changes: 25 additions & 21 deletions pgmpl/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, **kw):
self.resizeEvent = self.resize_event
dpi = rcParams['figure.dpi'] if dpi is None else dpi
figsize = rcParams['figure.figsize'] if figsize is None else figsize
self.width, self.height = np.array(figsize)*dpi
self.width, self.height = (np.array(figsize)*dpi).astype(int)
self.resize(self.width, self.height)
for init_to_none in ['axes', 'suptitle_label']:
setattr(self, init_to_none, None)
Expand Down Expand Up @@ -92,18 +92,18 @@ def resize_event(self, event):
def set_subplotpars(self, pars):
"""
Sets margins and spacing between Axes. Not a direct matplotlib imitation.
:param pars: SubplotParams instance
The subplotpars keyword to __init__ goes straight to here.
"""
if pars is None or self.layout is None:
# Either no pars were provided or the layout has already been set to None because the figure is closing.
# Don't do any margin adjustments.
return
if pars is not None:
self.margins = {
'left': pars.left, 'top': pars.top, 'right': pars.right, 'bottom': pars.bottom,
'hspace': pars.hspace, 'wspace': pars.wspace,
}
self.margins = {
'left': pars.left, 'top': pars.top, 'right': pars.right, 'bottom': pars.bottom,
'hspace': pars.hspace, 'wspace': pars.wspace,
}
if self.margins is not None:
if self.tight:
self.layout.setContentsMargins(
Expand Down Expand Up @@ -148,9 +148,21 @@ def add_subplot(self, nrows, ncols, index, **kwargs):
self.refresh_suptitle()
return ax

def _try_remove_from_layout(self, obj):
"""
Try to remove an item from the layout, catching the naughty `Exception`
:param obj: object
"""
if obj is not None:
# noinspection PyBroadException
try:
self.layout.removeItem(obj)
except Exception: # pyqtgraph raises this type, so we can't be narrower
pass

def colorbar(self, mappable, cax=None, ax=None, **kwargs):
if ax is None:
ax = self.add_subplot(1, 1, 1) if self.axes is None else np.atleast_1d(self.axes).flatten()[-1]
ax = ax or self.gca()
if cax is None:
orientation = kwargs.get('orientation', 'vertical')
row = int(np.floor((ax.index - 1) / ax.ncols))
Expand All @@ -165,11 +177,7 @@ def colorbar(self, mappable, cax=None, ax=None, **kwargs):
else:
sub_layout.layout.setColumnFixedWidth(1, 50) # https://stackoverflow.com/a/36897295/6605826

# noinspection PyBroadException
try:
self.layout.removeItem(ax)
except Exception:
pass
self._try_remove_from_layout(ax)
self.layout.addItem(sub_layout, row + 1, col)
return Colorbar(cax, mappable, **kwargs)

Expand All @@ -180,12 +188,7 @@ def suptitle(self, t, **kwargs):
self.refresh_suptitle()

def refresh_suptitle(self):
if self.suptitle_label is not None:
# noinspection PyBroadException
try:
self.layout.removeItem(self.suptitle_label)
except Exception: # pyqtgraph raises this type, so we can't be narrower
pass
self._try_remove_from_layout(self.suptitle_label)
self.suptitle_label = self.layout.addLabel(self.suptitle_text, 0, 0, 1, self.fig_colspan)

def closeEvent(self, event):
Expand All @@ -199,16 +202,17 @@ def closeEvent(self, event):
event.accept()
return

def gca(self):
def gca(self, **kwargs):
"""
Imitation of matplotlib gca()
:return: Current axes for this figure, creating them if necessary
"""
self._deleted_axes_protection('gca')
if self.axes is not None:
ax = list(flatten(np.atleast_1d(self.axes)))[-1]
if self.axes is None:
ax = self.add_subplot(1, 1, 1)
ax = self.add_subplot(1, 1, 1, **kwargs)
return ax

def close(self):
Expand Down
3 changes: 2 additions & 1 deletion pgmpl/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

# Plotting imports
import pyqtgraph as pg
from pgmpl.translate import color_translator, dealias
from pgmpl.translate import color_translator
from pgmpl.tools import dealias


class Text(pg.TextItem):
Expand Down
Loading

0 comments on commit 924cc78

Please sign in to comment.