diff --git a/seaborn/matrix.py b/seaborn/matrix.py index 6b99c118b6..4e96132b52 100644 --- a/seaborn/matrix.py +++ b/seaborn/matrix.py @@ -69,7 +69,14 @@ def _matrix_mask(data, mask): if mask is None: mask = np.zeros(data.shape, bool) - if isinstance(mask, np.ndarray): + if isinstance(mask, pd.DataFrame): + # For DataFrame masks, ensure that semantic labels match data + if not mask.index.equals(data.index) \ + and mask.columns.equals(data.columns): + err = "Mask must have the same index and columns as data." + raise ValueError(err) + elif hasattr(mask, "__array__"): + mask = np.asarray(mask) # For array masks, ensure that shape matches data then convert if mask.shape != data.shape: raise ValueError("Mask must have the same shape as data.") @@ -79,13 +86,6 @@ def _matrix_mask(data, mask): columns=data.columns, dtype=bool) - elif isinstance(mask, pd.DataFrame): - # For DataFrame masks, ensure that semantic labels match data - if not mask.index.equals(data.index) \ - and mask.columns.equals(data.columns): - err = "Mask must have the same index and columns as data." - raise ValueError(err) - # Add any cells with missing data to the mask # This works around an issue where `plt.pcolormesh` doesn't represent # missing data properly diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 889e5da461..9416d2e3d0 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -56,6 +56,24 @@ def test_ndarray_input(self): assert p.xlabel == "" assert p.ylabel == "" + def test_array_like_input(self): + class ArrayLike: + def __init__(self, data): + self.data = data + + def __array__(self, dtype=None, copy=None): + return np.asarray(self.data, dtype=dtype, copy=copy) + + p = mat._HeatMapper(ArrayLike(self.x_norm), **self.default_kws) + npt.assert_array_equal(p.plot_data, self.x_norm) + pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm)) + + npt.assert_array_equal(p.xticklabels, np.arange(8)) + npt.assert_array_equal(p.yticklabels, np.arange(4)) + + assert p.xlabel == "" + assert p.ylabel == "" + def test_df_input(self): p = mat._HeatMapper(self.df_norm, **self.default_kws)