diff --git a/dkist/dataset/tiled_dataset.py b/dkist/dataset/tiled_dataset.py index b3770f36..638b4721 100644 --- a/dkist/dataset/tiled_dataset.py +++ b/dkist/dataset/tiled_dataset.py @@ -10,10 +10,12 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib.axes import Axes from matplotlib.gridspec import GridSpec import astropy from astropy.table import vstack +from astropy.wcs.wcsapi import HighLevelWCSWrapper from dkist.io.file_manager import FileManager, StripedExternalArray from dkist.io.loaders import AstropyFITSLoader @@ -168,7 +170,26 @@ def _get_axislabels(ax): ylabel = coord.get_axislabel() or coord._get_default_axislabel() return (xlabel, ylabel) - def plot(self, slice_index, share_zscale=False, fig=None, **kwargs): + @staticmethod + def _ensure_wcs_ordered_axis_lims(wcs: HighLevelWCSWrapper, ax: Axes) -> None: + """ + Adjust axis limits so the WCS values go from smaller numbers in the bottom left to higher numbers in the upper right. + """ + xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + + world_xmin, world_ymin = wcs.low_level_wcs.pixel_to_world_values(xmin, ymin) + world_xmax, world_ymax = wcs.low_level_wcs.pixel_to_world_values(xmax, ymax) + + new_xmin, new_ymin = wcs.low_level_wcs.world_to_pixel_values(min(world_xmin, world_xmax), min(world_ymin, world_ymax)) + new_xmax, new_ymax = wcs.low_level_wcs.world_to_pixel_values(max(world_xmin, world_xmax), max(world_ymin, world_ymax)) + + ax.set_xlim(new_xmin, new_xmax) + ax.set_ylim(new_ymin, new_ymax) + + return + + def plot(self, slice_index, share_zscale=False, fig=None, limits_from_wcs=True, **kwargs): """ Plot a slice of each tile in the TiledDataset @@ -185,6 +206,9 @@ def plot(self, slice_index, share_zscale=False, fig=None, **kwargs): fig : `matplotlib.figure.Figure` A figure to use for the plot. If not specified the current pyplot figure will be used, or a new one created. + limits_from_wcs : `bool` + If ``True`` the plots will be adjusted so smaller WCS values in the lower left, increasing to the upper + right. If ``False`` then the raw orientation of the data is used. """ if isinstance(slice_index, int): slice_index = (slice_index,) @@ -202,6 +226,10 @@ def plot(self, slice_index, share_zscale=False, fig=None, **kwargs): ax_gridspec = gridspec[dataset_nrows - row - 1, col] ax = fig.add_subplot(ax_gridspec, projection=tile.wcs) tile.plot(axes=ax, **kwargs) + + if limits_from_wcs: + self._ensure_wcs_ordered_axis_lims(tile.wcs, ax) + if col == row == 0: xlabel, ylabel = self._get_axislabels(ax) fig.supxlabel(xlabel, y=0.05)