diff --git a/changelog/71.breaking.rst b/changelog/71.breaking.rst new file mode 100644 index 00000000..079b6ba5 --- /dev/null +++ b/changelog/71.breaking.rst @@ -0,0 +1 @@ +Move the `dkist.asdf_maker` package to `dkist.io.asdf.generator` while also refactoring its internal structure to hopefully make it a little easier to follow. diff --git a/dkist/asdf_maker/__init__.py b/dkist/asdf_maker/__init__.py deleted file mode 100644 index 9f84f22b..00000000 --- a/dkist/asdf_maker/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .generator import * -from .helpers import * diff --git a/dkist/asdf_maker/generator.py b/dkist/asdf_maker/generator.py deleted file mode 100644 index 92534910..00000000 --- a/dkist/asdf_maker/generator.py +++ /dev/null @@ -1,527 +0,0 @@ -import random -import string -import pathlib - -import numpy as np - -import asdf -import astropy.units as u -import gwcs -import gwcs.coordinate_frames as cf -from astropy.io import fits -from astropy.table import Table -from astropy.time import Time -from sunpy.coordinates import Helioprojective -from sunpy.time import parse_time - -from dkist.asdf_maker.helpers import (generate_lookup_table, linear_spectral_model, - references_from_filenames, spatial_model_from_header, - spectral_model_from_framewave, time_model_from_date_obs) -from dkist.dataset import Dataset -from dkist.io import AstropyFITSLoader -from dkist.io.array_containers import DaskFITSArrayContainer - -try: - from importlib import resources # >= py 3.7 -except ImportError: - import importlib_resources as resources - -__all__ = ['generate_datset_inventory_from_headers', 'dataset_from_fits', - 'asdf_tree_from_filenames', 'gwcs_from_headers', 'TransformBuilder', - 'build_pixel_frame', 'validate_headers', 'table_from_headers', - 'headers_from_filenames'] - - -def headers_from_filenames(filenames, hdu=0): - """ - A generator to get the headers from filenames. - """ - return [dict(fits.getheader(fname, ext=hdu)) for fname in filenames] - - -def table_from_headers(headers): - return Table(rows=headers, names=list(headers[0].keys())) - - -def validate_headers(table_headers): - """ - Given a bunch of headers, validate that they form a coherent set. This - function also adds the headers to a list as they are read from the file. - - Parameters - ---------- - - headers : iterator - An iterator of headers. - - Returns - ------- - out_headers : `list` - A list of headers. - """ - t = table_headers - - """ - Let's do roughly the minimal amount of verification here. - """ - - # For some keys all the values must be the same - same_keys = ['NAXIS', 'DNAXIS'] - naxis_same_keys = ['NAXISn', 'CTYPEn', 'CRVALn'] # 'CRPIXn' - dnaxis_same_keys = ['DNAXISn', 'DTYPEn', 'DPNAMEn', 'DWNAMEn'] - # Expand n in NAXIS keys - for nsk in naxis_same_keys: - for naxis in range(1, t['NAXIS'][0] + 1): - same_keys.append(nsk.replace('n', str(naxis))) - # Expand n in DNAXIS keys - for dsk in dnaxis_same_keys: - for dnaxis in range(1, t['DNAXIS'][0] + 1): - same_keys.append(dsk.replace('n', str(dnaxis))) - - validate_t = t[same_keys] - - for col in validate_t.columns.values(): - if not all(col == col[0]): - raise ValueError(f"The {col.name} values did not all match:\n {col}") - - return table_headers - - -def build_pixel_frame(header): - """ - Given a header, build the input - `gwcs.coordinate_frames.CoordinateFrame` object describing the pixel frame. - - Parameters - ---------- - - header : `dict` - A fits header. - - Returns - ------- - - pixel_frame : `gwcs.coordinate_frames.CoordinateFrame` - The pixel frame. 
- """ - axes_types = [header[f'DTYPE{n}'] for n in range(1, header['DNAXIS'] + 1)] - - return cf.CoordinateFrame(naxes=header['DNAXIS'], - axes_type=axes_types, - axes_order=range(header['DNAXIS']), - unit=[u.pixel]*header['DNAXIS'], - axes_names=[header[f'DPNAME{n}'] for n in range(1, header['DNAXIS'] + 1)], - name='pixel') - - -class TransformBuilder: - """ - This class builds compound models and frames in order when given axes types. - """ - - def __init__(self, headers): - self.header = headers[0] - - # Reshape the headers to match the Dataset shape, so we can extract headers along various axes. - shape = tuple((self.header[f'DNAXIS{n}'] for n in range(self.header['DNAXIS'], - self.header['DAAXES'], -1))) - arr_headers = np.empty(shape, dtype=object) - for i in range(arr_headers.size): - arr_headers.flat[i] = headers[i] - - self.headers = arr_headers - self.reset() - self._build() - - @property - def frames(self): - """ - The coordinate frames, in Python order. - """ - return self._frames - - @property - def transform(self): - """ - Return the compound model. - """ - tf = self._transforms[0] - - for i in range(1, len(self._transforms)): - tf = tf & self._transforms[i] - - return tf - - """ - Internal Stuff - """ - - def _build(self): - """ - Build the state of the thing. - """ - type_map = {'STOKES': self.make_stokes, - 'TEMPORAL': self.make_temporal, - 'SPECTRAL': self.make_spectral, - 'SPATIAL': self.make_spatial} - - xx = 0 - while self._i < self.header['DNAXIS']: # < because FITS is i+1 - atype = self.axes_types[self._i] - type_map[atype]() - xx += 1 - if xx > 100: - raise ValueError("Infinite loop in header parsing") # pragma: no cover - - @property - def axes_types(self): - """ - The list of DTYPEn for the first header. - """ - return [self.header[f'DTYPE{n}'] for n in range(1, self.header['DNAXIS'] + 1)] - - def reset(self): - """ - Reset the builder. - """ - self._i = 0 - self._frames = [] - self._transforms = [] - - @property - def n(self): - return self._n(self._i) - - def _n(self, i): - """ - Convert a Python index ``i`` to a FITS order index for keywords ``n``. - """ - # return range(self.header['DNAXIS'], 0, -1)[i] - return i + 1 - - @property - def slice_for_n(self): - i = self._i - self.header['DAAXES'] - naxes = self.header['DEAXES'] - ss = [0] * naxes - ss[i] = slice(None) - return ss[::-1] - - @property - def slice_headers(self): - return self.headers[self.slice_for_n] - - def get_units(self, *iargs): - """ - Get zee units - """ - u = [self.header.get(f'DUNIT{self._n(i)}', None) for i in iargs] - - return u - - def make_stokes(self): - """ - Add a stokes axes to the builder. - """ - name = self.header[f'DWNAME{self.n}'] - self._frames.append(cf.StokesFrame(axes_order=(self._i,), name=name)) - self._transforms.append(generate_lookup_table([0, 1, 2, 3] * u.one, interpolation='nearest')) - self._i += 1 - - def make_temporal(self): - """ - Add a temporal axes to the builder. - """ - - name = self.header[f'DWNAME{self.n}'] - self._frames.append(cf.TemporalFrame(axes_order=(self._i,), - name=name, - axes_names=(name,), - unit=self.get_units(self._i), - reference_frame=Time(self.header['DATE-BGN']))) - self._transforms.append(time_model_from_date_obs([e['DATE-OBS'] for e in self.slice_headers], - self.header['DATE-BGN'])) - - self._i += 1 - - def make_spatial(self): - """ - Add a helioprojective spatial pair to the builder. - - .. note:: - This increments the counter by two. 
- - """ - i = self._i - name = self.header[f'DWNAME{self.n}'] - name = name.split(' ')[0] - axes_names = [(self.header[f'DWNAME{nn}'].rsplit(' ')[1]) for nn in (self.n, self._n(i+1))] - - obstime = Time(self.header['DATE-BGN']) - axes_types = ["lat" if "LT" in self.axes_types[i] else "lon", "lon" if "LN" in self.axes_types[i] else "lat"] - self._frames.append(cf.CelestialFrame(axes_order=(i, i+1), name=name, - reference_frame=Helioprojective(obstime=obstime), - axes_names=axes_names, - unit=self.get_units(self._i, self._i+1), - axis_physical_types=(f"custom:pos.helioprojective.{axes_types[0]}", - f"custom:pos.helioprojective.{axes_types[1]}"))) - - self._transforms.append(spatial_model_from_header(self.header)) - - self._i += 2 - - def make_spectral(self): - """ - Decide how to make a spectral axes. - """ - name = self.header[f'DWNAME{self.n}'] - self._frames.append(cf.SpectralFrame(axes_order=(self._i,), - axes_names=(name,), - unit=self.get_units(self._i), - name=name)) - - if "WAVE" in self.header.get(f'CTYPE{self.n}', ''): - transform = self.make_spectral_from_wcs() - elif "FRAMEWAV" in self.header.keys(): - transform = self.make_spectral_from_dataset() - else: - raise ValueError("Could not parse spectral WCS information from this header.") # pragma: no cover - - self._transforms.append(transform) - - self._i += 1 - - def make_spectral_from_dataset(self): - """ - Make a spectral axes from (VTF) dataset info. - """ - framewave = [h['FRAMEWAV'] for h in self.slice_headers[:self.header[f'DNAXIS{self.n}']]] - return spectral_model_from_framewave(framewave) - - def make_spectral_from_wcs(self): - """ - Add a spectral axes from the FITS-WCS keywords. - """ - return linear_spectral_model(self.header[f'CDELT{self.n}']*u.nm, - self.header[f'CRVAL{self.n}']*u.nm) - - -def gwcs_from_headers(headers): - """ - Given a list of headers build a gwcs. - - Parameters - ---------- - - headers : `list` - A list of headers. These are expected to have already been validated. - """ - header = headers[0] - - pixel_frame = build_pixel_frame(header) - - builder = TransformBuilder(headers) - world_frame = cf.CompositeFrame(builder.frames) - - return gwcs.WCS(forward_transform=builder.transform, - input_frame=pixel_frame, - output_frame=world_frame) - - -def make_sorted_table(headers, filenames): - """ - Return an `astropy.table.Table` instance where the rows are correctly sorted. - """ - theaders = table_from_headers(headers) - theaders['filenames'] = filenames - theaders['headers'] = headers - dataset_axes = headers[0]['DNAXIS'] - array_axes = headers[0]['DAAXES'] - keys = [f'DINDEX{k}' for k in range(dataset_axes, array_axes, -1)] - t = np.array(theaders[keys]) - return theaders[np.argsort(t, order=keys)] - - -def _preprocess_headers(headers, filenames): - table_headers = make_sorted_table(headers, filenames) - - validate_headers(table_headers) - - # Sort the filenames into DS order. - sorted_filenames = np.array(table_headers['filenames']) - sorted_headers = np.array(table_headers['headers']) - - table_headers.remove_columns(["headers", "filenames"]) - - return table_headers, sorted_filenames, sorted_headers - - - -def asdf_tree_from_filenames(filenames, asdf_filename, inventory=None, hdu=0, - relative_to=None, extra_inventory=None): - """ - Build a DKIST asdf tree from a list of (unsorted) filenames. - - Parameters - ---------- - - filenames : `list` - The filenames to process into a DKIST asdf dataset. - - hdu : `int` - The HDU to read from the FITS files. 
- """ - # In case filenames is a generator we cast to list. - filenames = list(filenames) - - # headers is an iterator - headers = headers_from_filenames(filenames, hdu=hdu) - - table_headers, sorted_filenames, sorted_headers = _preprocess_headers(headers, filenames) - - if not inventory: - inventory = generate_datset_inventory_from_headers(table_headers, asdf_filename) - if extra_inventory: - inventory.update(extra_inventory) - - # Get the array shape - shape = tuple((headers[0][f'DNAXIS{n}'] for n in range(headers[0]['DNAXIS'], - headers[0]['DAAXES'], -1))) - # References from filenames - array_container = references_from_filenames(sorted_filenames, sorted_headers, array_shape=shape, - hdu_index=hdu, relative_to=relative_to) - - ds = Dataset(array_container.array, gwcs_from_headers(sorted_headers), meta=inventory, headers=table_headers) - - ds._array_container = array_container - - tree = {'dataset': ds} - - return tree - - -def dataset_from_fits(path, asdf_filename, inventory=None, hdu=0, relative_to=None, **kwargs): - """ - Given a path containing FITS files write an asdf file in the same path. - - Parameters - ---------- - path : `pathlib.Path` or `str` - The path to read the FITS files (with a `.fits` file extension) from - and save the asdf file. - - asdf_filename : `str` - The filename to save the asdf with in the path. - - inventory : `dict`, optional - The dataset inventory for this collection of FITS. If `None` a random one will be generated. - - hdu : `int`, optional - The HDU to read from the FITS files. - - relative_to: `pathlib.Path` or `str`, optional - The base path to use in the asdf references. - - kwargs - Additional kwargs are passed to `asdf.AsdfFile.write_to`. - - """ - path = pathlib.Path(path) - - files = path.glob("*fits") - - tree = asdf_tree_from_filenames(list(files), asdf_filename, inventory=inventory, - hdu=hdu, relative_to=relative_to) - - with resources.path("dkist.io", "level_1_dataset_schema.yaml") as schema_path: - with asdf.AsdfFile(tree, custom_schema=schema_path.as_posix()) as afile: - afile.write_to(path / asdf_filename, **kwargs) - - -def _gen_type(gen_type, max_int=1e6, max_float=1e6, len_str=30): - if gen_type is bool: - return bool(random.randint(0, 1)) - elif gen_type is int: - return random.randint(0, max_int) - elif gen_type is float: - return random.random() * max_float - elif gen_type is list: - return [_gen_type(str)] - elif gen_type is Time: - return parse_time("now") - elif gen_type is str: - return ''.join( - random.choice(string.ascii_uppercase + string.digits) for _ in range(len_str)) - else: - raise ValueError("Type {} is not supported".format(gen_type)) # pragma: no cover - - -def generate_datset_inventory_from_headers(headers, asdf_name): - """ - Generate a dummy dataset inventory from headers. - - .. note:: - This is just for test data, it should not be used on real data. 
- - Parameters - ---------- - - headers: `astropy.table.Table` - asdf_name: `str` - - """ - - schema = [ - ('asdf_object_key', str), - ('browse_movie_object_key', str), - ('browse_movie_url', str), - ('bucket', str), - ('contributing_experiment_ids', list), - ('contributing_proposal_ids', list), - ('dataset_id', str), - ('dataset_inventory_id', int), - ('dataset_size', int), - ('end_time', Time), - ('exposure_time', float), - ('filter_wavelengths', list), - ('frame_count', int), - ('has_all_stokes', bool), - ('instrument_name', str), - ('observables', list), - ('original_frame_count', int), - ('primary_experiment_id', str), - ('primary_proposal_id', str), - ('quality_average_fried_parameter', float), - ('quality_average_polarimetric_accuracy', float), - ('recipe_id', int), - ('recipe_instance_id', int), - ('recipe_run_id', int), - ('start_time', Time), - # ('stokes_parameters', str), - ('target_type', str), - ('wavelength_max', float), - ('wavelength_min', float) - ] - - header_mapping = { - 'start_time': 'DATE-BGN', - 'end_time': 'DATE-END', - 'filter_wavelengths': 'WAVELNGTH'} - - constants = { - 'frame_count': len(headers), - 'bucket': 'data', - 'asdf_object_key': str(asdf_name) - } - - output = {} - - for key, ktype in schema: - if key in header_mapping: - hdict = dict(zip(headers.colnames, headers[0])) - output[key] = ktype(hdict.get(header_mapping[key], _gen_type(ktype))) - else: - output[key] = _gen_type(ktype) - - output.update(constants) - return output diff --git a/dkist/asdf_maker/helpers.py b/dkist/asdf_maker/helpers.py deleted file mode 100644 index a773ed3c..00000000 --- a/dkist/asdf_maker/helpers.py +++ /dev/null @@ -1,227 +0,0 @@ -import os - -import numpy as np - -import asdf -import astropy.units as u -from astropy.io.fits.hdu.base import BITPIX2DTYPE -from astropy.modeling.models import (AffineTransformation2D, Linear1D, Multiply, - Pix2Sky_TAN, RotateNative2Celestial, Shift, Tabular1D) -from astropy.time import Time - -from dkist.io import AstropyFITSLoader, DaskFITSArrayContainer - -__all__ = ['make_asdf', 'time_model_from_date_obs', 'linear_time_model', 'linear_spectral_model', - 'spatial_model_from_quantity', 'spatial_model_from_header', 'references_from_filenames'] - - -def references_from_filenames(filenames, headers, array_shape, hdu_index=0, relative_to=None): - """ - Given an array of paths to FITS files create a `dkist.io.DaskFITSArrayContainer`. - - Parameters - ---------- - filenames : `numpy.ndarray` - An array of filenames, in numpy order for the output array (i.e. ``.flat``) - - headers : `list` - A list of headers for files - - array_shape : `tuple` - The desired output shape of the reference array. (i.e the shape of the - data minus the HDU dimensions.) - - hdu_index : `int` (optional, default 0) - The index of the HDU to reference. (Zero indexed) - - relative_to : `str` (optional) - If set convert the filenames to be relative to this path. 
- """ - - filenames = np.asanyarray(filenames) - filepaths = np.empty(array_shape, dtype=object) - if filenames.size != filepaths.size: - raise ValueError(f"An incorrect number of filenames ({filenames.size})" - f" supplied for array_shape ({array_shape})") - - dtypes = [] - shapes = [] - for i, (filepath, head) in enumerate(zip(filenames.flat, headers.flat)): - dtypes.append(BITPIX2DTYPE[head['BITPIX']]) - shapes.append(tuple([int(head[f"NAXIS{a}"]) for a in range(head["NAXIS"], 0, -1)])) - - # Convert paths to relative paths - relative_path = filepath - if relative_to: - relative_path = os.path.relpath(filepath, str(relative_to)) - - filepaths.flat[i] = str(relative_path) - - # Validate all shapes and dtypes are consistent. - dtype = set(dtypes) - if len(dtype) != 1: - raise ValueError("Not all the dtypes of these files are the same.") - dtype = list(dtype)[0] - - shape = set(shapes) - if len(shape) != 1: - raise ValueError("Not all the shapes of these files are the same") - shape = list(shape)[0] - - return DaskFITSArrayContainer(filepaths.tolist(), hdu_index, dtype, shape, loader=AstropyFITSLoader) - - -def spatial_model_from_header(header): - """ - Given a FITS compliant header with CTYPEx,y as HPLN, HPLT return a - `~astropy.modeling.CompositeModel` for the transform. - - This function finds the HPLN and HPLT keys in the header and returns a - model in Lon, Lat order. - """ - latind = None - lonind = None - for k, v in header.items(): - if isinstance(v, str) and "HPLN" in v: - lonind = int(k[5:]) - if isinstance(v, str) and "HPLT" in v: - latind = int(k[5:]) - - if latind is None or lonind is None: - raise ValueError("Could not extract HPLN and HPLT from the header.") - - latproj = header[f'CTYPE{latind}'][5:] - lonproj = header[f'CTYPE{lonind}'][5:] - - if latproj != lonproj: - raise ValueError("The projection of the two spatial axes did not match.") # pragma: no cover - - cunit1, cunit2 = u.Unit(header[f'CUNIT{lonind}']), u.Unit(header[f'CUNIT{latind}']) - crpix1, crpix2 = header[f'CRPIX{lonind}'] * u.pix, header[f'CRPIX{latind}'] * u.pix - crval1, crval2 = (header[f'CRVAL{lonind}'] * cunit1, header[f'CRVAL{latind}'] * cunit2) - cdelt1, cdelt2 = (header[f'CDELT{lonind}'] * (cunit1 / u.pix), - header[f'CDELT{latind}'] * (cunit2 / u.pix)) - pc = np.matrix([[header[f'PC{lonind}_{lonind}'], header[f'PC{lonind}_{latind}']], - [header[f'PC{latind}_{lonind}'], header[f'PC{latind}_{latind}']]]) * cunit1 - - return spatial_model_from_quantity(crpix1, crpix2, cdelt1, cdelt2, pc, crval1, crval2, - projection=latproj) - - -def spatial_model_from_quantity(crpix1, crpix2, cdelt1, cdelt2, pc, crval1, crval2, - projection='TAN'): - """ - Given quantity representations of a HPLx FITS WCS return a model for the - spatial transform. - - The ordering of ctype1 and ctype2 should be LON, LAT - """ - - # TODO: Find this from somewhere else or extend it or something - projections = {'TAN': Pix2Sky_TAN()} - - shiftu = Shift(-crpix1) & Shift(-crpix2) - scale = Multiply(cdelt1) & Multiply(cdelt2) - rotu = AffineTransformation2D(pc, translation=(0, 0)*u.arcsec) - tanu = projections[projection] - skyrotu = RotateNative2Celestial(crval1, crval2, 180*u.deg) - return shiftu | scale | rotu | tanu | skyrotu - - -@u.quantity_input -def linear_spectral_model(spectral_width: u.nm, reference_val: u.nm): - """ - A linear model in a spectral dimension. The reference pixel is always 0. 
- """ - return Linear1D(slope=spectral_width/(1*u.pix), intercept=reference_val) - - -@u.quantity_input -def linear_time_model(cadence: u.s, reference_val: u.s = 0*u.s): - """ - A linear model in a temporal dimension. The reference pixel is always 0. - """ - if not reference_val: - reference_val = 0 * cadence.unit - return Linear1D(slope=cadence / (1 * u.pix), intercept=reference_val) - - -def generate_lookup_table(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): - if not isinstance(lookup_table, u.Quantity): - raise TypeError("lookup_table must be a Quantity.") - - # The integer location is at the centre of the pixel. - points = (np.arange(lookup_table.size) - 0) * points_unit - - kwargs = { - 'bounds_error': False, - 'fill_value': np.nan, - 'method': interpolation, - **kwargs - } - - return Tabular1D(points, lookup_table, **kwargs) - - -def time_model_from_date_obs(date_obs, date_bgn=None): - """ - Return a time model that best fits a list of dateobs's. - """ - if not date_bgn: - date_bgn = date_obs[0] - date_obs = Time(date_obs) - date_bgn = Time(date_bgn) - - deltas = date_bgn - date_obs - - # Work out if we have a uniform delta (i.e. a linear model) - ddelta = (deltas.to(u.s)[:-1] - deltas.to(u.s)[1:]) - - # If the length of the axis is one, then return a very simple model - if ddelta.size == 0: - return linear_time_model(cadence=0*u.s, reference_val=0*u.s) - elif u.allclose(ddelta[0], ddelta): - slope = ddelta[0] - intercept = 0 * u.s - return linear_time_model(cadence=slope, reference_val=intercept) - else: - print(f"creating tabular temporal axis. ddeltas: {ddelta}") - return generate_lookup_table(deltas.to(u.s)) - - -def spectral_model_from_framewave(framewav): - """ - Construct a linear or lookup table model for wavelength based on the - framewav keys. - """ - framewav = u.Quantity(framewav, unit=u.nm) - wave_bgn = framewav[0] - - deltas = wave_bgn - framewav - ddeltas = (deltas[:-1] - deltas[1:]) - # If the length of the axis is one, then return a very simple model - if ddeltas.size == 0: - return linear_spectral_model(0*u.nm, wave_bgn) - if u.allclose(ddeltas[0], ddeltas): - slope = ddeltas[0] - return linear_spectral_model(slope, wave_bgn) - else: - print(f"creating tabular wavelength axis. ddeltas: {ddeltas}") - return generate_lookup_table(framewav) - - -def make_asdf(filename, *, dataset, **kwargs): - """ - Save an asdf file. - - All keyword arguments become keys in the top level of the asdf tree. 
- """ - tree = { - 'dataset': dataset, - **kwargs - } - - with asdf.AsdfFile(tree) as ff: - ff.write_to(filename) - - return filename diff --git a/dkist/asdf_maker/tests/test_generator.py b/dkist/asdf_maker/tests/test_generator.py deleted file mode 100644 index 887127f4..00000000 --- a/dkist/asdf_maker/tests/test_generator.py +++ /dev/null @@ -1,101 +0,0 @@ -import pathlib - -import pytest - -import gwcs -import gwcs.coordinate_frames as cf -from astropy.modeling import Model, models - -from dkist.asdf_maker.generator import (asdf_tree_from_filenames, dataset_from_fits, - gwcs_from_headers, headers_from_filenames, - table_from_headers, validate_headers) -from dkist.dataset import Dataset - - -@pytest.fixture -def wcs(header_filenames): - wcs = gwcs_from_headers(headers_from_filenames(header_filenames)) - assert isinstance(wcs, gwcs.WCS) - return wcs - - -def test_reset(transform_builder): - transform_builder._i = 2 - transform_builder.reset() - assert transform_builder._i == 0 - - -def test_transform(transform_builder): - assert isinstance(transform_builder.transform, Model) - - -def test_frames(transform_builder): - frames = transform_builder.frames - assert all([isinstance(frame, cf.CoordinateFrame) for frame in frames]) - - -def test_input_name_ordering(wcs): - # Check the ordering of the input and output frames - allowed_pixel_names = (('spatial x', 'spatial y', 'wavelength position', 'scan number', - 'stokes'), ('wavelength', 'slit position', 'raster position', - 'scan number', 'stokes')) - assert wcs.input_frame.axes_names in allowed_pixel_names - - -def test_output_name_ordering(wcs): - allowed_world_names = (('latitude', 'longitude', 'wavelength', 'time', 'stokes'), - ('wavelength', 'latitude', 'longitude', 'time', 'stokes')) - assert wcs.output_frame.axes_names in allowed_world_names - - -def test_output_frames(wcs): - allowed_frame_orders = ((cf.CelestialFrame, cf.SpectralFrame, cf.TemporalFrame, cf.StokesFrame), - (cf.SpectralFrame, cf.CelestialFrame, cf.TemporalFrame, cf.StokesFrame)) - types = tuple((type(frame) for frame in wcs.output_frame.frames)) - assert types in allowed_frame_orders - - -def test_transform_models(wcs): - # Test that there is one lookup table and two linear models for both the - # wcses - sms = wcs.forward_transform._leaflist - smtypes = [type(m) for m in sms] - assert sum(mt is models.Linear1D for mt in smtypes) == 2 - assert sum(mt is models.Tabular1D for mt in smtypes) == 1 - - -def test_array_container_shape(header_filenames): - from dkist.asdf_maker.generator import _preprocess_headers, references_from_filenames - from dkist.io import DaskFITSArrayContainer, AstropyFITSLoader - - headers = headers_from_filenames(header_filenames, hdu=0) - table_headers, sorted_filenames, sorted_headers = _preprocess_headers(headers, header_filenames) - # Get the array shape - shape = tuple((headers[0][f'DNAXIS{n}'] for n in range(headers[0]['DNAXIS'], - headers[0]['DAAXES'], -1))) - # References from filenames - array_container = references_from_filenames(sorted_filenames, sorted_headers, array_shape=shape, - hdu_index=0, relative_to=".") - - assert len(array_container.output_shape) == 5 - assert array_container.output_shape == array_container.array.shape - - -def test_asdf_tree(header_filenames): - tree = asdf_tree_from_filenames(header_filenames, "test_file.asdf") - assert isinstance(tree, dict) - - -def test_validator(header_filenames): - headers = headers_from_filenames(header_filenames) - headers[10]['NAXIS'] = 5 - with pytest.raises(ValueError) as excinfo: - 
validate_headers(table_from_headers(headers)) - assert "NAXIS" in str(excinfo) - - -def test_make_asdf(header_filenames, tmpdir): - path = pathlib.Path(header_filenames[0]) - dataset_from_fits(path.parent, "test.asdf") - assert (path.parent / "test.asdf").exists() - assert isinstance(Dataset.from_directory(str(path.parent)), Dataset) diff --git a/dkist/asdf_maker/tests/test_helpers.py b/dkist/asdf_maker/tests/test_helpers.py deleted file mode 100644 index 708b524a..00000000 --- a/dkist/asdf_maker/tests/test_helpers.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -from pathlib import Path - -import numpy as np -import pytest - -import asdf -import astropy.units as u -from astropy.io import fits -from astropy.modeling import Model, models -from astropy.time import Time - -from dkist.asdf_maker.generator import asdf_tree_from_filenames, headers_from_filenames -from dkist.asdf_maker.helpers import (linear_spectral_model, linear_time_model, make_asdf, - references_from_filenames, spatial_model_from_header, - spectral_model_from_framewave, time_model_from_date_obs) - - -def test_references_from_filesnames_shape_error(header_filenames): - headers = headers_from_filenames(header_filenames, hdu=0) - with pytest.raises(ValueError) as exc: - references_from_filenames(header_filenames, headers, [2, 3]) - - assert "incorrect number" in str(exc) - assert "2, 3" in str(exc) - assert str(len(header_filenames)) in str(exc) - - -def test_references_from_filenames(header_filenames): - headers = headers_from_filenames(header_filenames, hdu=0) - base = os.path.split(header_filenames[0])[0] - refs = references_from_filenames(header_filenames, np.array(headers, dtype=object), - (len(header_filenames),), relative_to=base) - - for ref in refs.fileuris: - assert base not in ref - - -def first_header(header_filenames): - return fits.getheader(header_filenames[0]) - - -def test_spatial_model(header_filenames): - spatial = spatial_model_from_header(first_header(header_filenames)) - assert isinstance(spatial, Model) - - -def test_spatial_model_fail(header_filenames): - header = first_header(header_filenames) - header['CTYPE2'] = 'WAVE' - with pytest.raises(ValueError): - spatial_model_from_header(header) - - -def test_linear_spectral(): - lin = linear_spectral_model(10*u.nm, 0*u.nm) - assert isinstance(lin, models.Linear1D) - assert u.allclose(lin.slope, 10*u.nm/u.pix) - assert u.allclose(lin.intercept, 0*u.nm) - - -def test_linear_time(): - lin = linear_time_model(10*u.s) - assert isinstance(lin, models.Linear1D) - assert u.allclose(lin.slope, 10*u.s/u.pix) - assert u.allclose(lin.intercept, 0*u.s) - - -def test_time_from_dateobs(header_filenames): - date_obs = [fits.getheader(f)['DATE-OBS'] for f in header_filenames] - time = time_model_from_date_obs(date_obs) - assert isinstance(time, models.Linear1D) - - -def test_time_from_dateobs_lookup(header_filenames): - date_obs = [fits.getheader(f)['DATE-OBS'] for f in header_filenames] - date_obs[5] = (Time(date_obs[5]) + 10*u.s).isot - time = time_model_from_date_obs(date_obs) - assert isinstance(time, models.Tabular1D) - - -def test_spectral_framewave(header_filenames): - head = first_header(header_filenames) - - # Skip the VISP headers - if "FRAMEWAV" not in head: - return - - nwave = head['DNAXIS3'] - framewave = [fits.getheader(h)['FRAMEWAV'] for h in header_filenames] - - m = spectral_model_from_framewave(framewave[:nwave]) - assert isinstance(m, models.Linear1D) - - m2 = spectral_model_from_framewave(framewave) - assert isinstance(m2, models.Tabular1D) - - -def 
test_make_asdf(header_filenames, tmpdir): - tree = asdf_tree_from_filenames(header_filenames, "test.asdf") - fname = Path(tmpdir.join("test.asdf")) - asdf_file = make_asdf(fname, dataset=tree['dataset']) - - with open(asdf_file, "rb") as fd: - af = asdf.AsdfFile() - af = asdf.AsdfFile._open_asdf(af, fd) diff --git a/dkist/io/asdf/generator/__init__.py b/dkist/io/asdf/generator/__init__.py new file mode 100644 index 00000000..ea29e882 --- /dev/null +++ b/dkist/io/asdf/generator/__init__.py @@ -0,0 +1,4 @@ +""" +Functions for building asdf files from a set of calibrated DKIST FITS files and their headers. +""" +from .generator import * diff --git a/dkist/io/asdf/generator/generator.py b/dkist/io/asdf/generator/generator.py new file mode 100644 index 00000000..3cc1c5df --- /dev/null +++ b/dkist/io/asdf/generator/generator.py @@ -0,0 +1,166 @@ +import os +import pathlib + +import numpy as np + +import asdf +from astropy.io.fits.hdu.base import BITPIX2DTYPE + +from dkist.io.asdf.generator.helpers import headers_from_filenames, preprocess_headers +from dkist.io.asdf.generator.simulated_data import generate_datset_inventory_from_headers +from dkist.io.asdf.generator.transforms import TransformBuilder +from dkist.dataset import Dataset +from dkist.io import AstropyFITSLoader, DaskFITSArrayContainer + +try: + from importlib import resources # >= py 3.7 +except ImportError: + import importlib_resources as resources + + +__all__ = ['references_from_filenames', 'dataset_from_fits', 'asdf_tree_from_filenames'] + + +def references_from_filenames(filenames, headers, array_shape, hdu_index=0, relative_to=None): + """ + Given an array of paths to FITS files create a `dkist.io.DaskFITSArrayContainer`. + + Parameters + ---------- + filenames : `numpy.ndarray` + An array of filenames, in numpy order for the output array (i.e. ``.flat``) + + headers : `list` + A list of headers for files + + array_shape : `tuple` + The desired output shape of the reference array. (i.e the shape of the + data minus the HDU dimensions.) + + hdu_index : `int` (optional, default 0) + The index of the HDU to reference. (Zero indexed) + + relative_to : `str` (optional) + If set convert the filenames to be relative to this path. + + Returns + ------- + `dkist.io.DaskFITSArrayContainer` + A container that represents a set of FITS files, and can generate a `dask.array.Array` from them. + """ + + filenames = np.asanyarray(filenames) + filepaths = np.empty(array_shape, dtype=object) + if filenames.size != filepaths.size: + raise ValueError(f"An incorrect number of filenames ({filenames.size})" + f" supplied for array_shape ({array_shape})") + + dtypes = [] + shapes = [] + for i, (filepath, head) in enumerate(zip(filenames.flat, headers.flat)): + dtypes.append(BITPIX2DTYPE[head['BITPIX']]) + shapes.append(tuple([int(head[f"NAXIS{a}"]) for a in range(head["NAXIS"], 0, -1)])) + + # Convert paths to relative paths + relative_path = filepath + if relative_to: + relative_path = os.path.relpath(filepath, str(relative_to)) + + filepaths.flat[i] = str(relative_path) + + # Validate all shapes and dtypes are consistent. 
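+ # A single dtype and shape must describe every file, because the container lazy-loads them all as one array.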
+ dtype = set(dtypes) + if len(dtype) != 1: + raise ValueError("Not all the dtypes of these files are the same.") + dtype = list(dtype)[0] + + shape = set(shapes) + if len(shape) != 1: + raise ValueError("Not all the shapes of these files are the same.") + shape = list(shape)[0] + + return DaskFITSArrayContainer(filepaths.tolist(), hdu_index, dtype, shape, loader=AstropyFITSLoader) + + +def asdf_tree_from_filenames(filenames, asdf_filename, inventory=None, hdu=0, + relative_to=None, extra_inventory=None): + """ + Build a DKIST asdf tree from a list of (unsorted) filenames. + + Parameters + ---------- + + filenames : `list` + The filenames to process into a DKIST asdf dataset. + + asdf_filename : `str` + The filename of the asdf file, used when generating a dummy inventory. + + inventory : `dict`, optional + The dataset inventory. If `None` a dummy one is generated from the headers. + + hdu : `int` + The HDU to read from the FITS files. + + relative_to : `pathlib.Path` or `str`, optional + The base path to use in the asdf references. + + extra_inventory : `dict`, optional + Extra items to merge into the inventory. + """ + # In case filenames is a generator we cast to list. + filenames = list(filenames) + + # Read the header for every file. + headers = headers_from_filenames(filenames, hdu=hdu) + + table_headers, sorted_filenames, sorted_headers = preprocess_headers(headers, filenames) + + if not inventory: + inventory = generate_datset_inventory_from_headers(table_headers, asdf_filename) + if extra_inventory: + inventory.update(extra_inventory) + + ds_wcs = TransformBuilder(sorted_headers).gwcs + + # Get the array shape + shape = tuple((headers[0][f'DNAXIS{n}'] for n in range(headers[0]['DNAXIS'], + headers[0]['DAAXES'], -1))) + # References from filenames + array_container = references_from_filenames(sorted_filenames, sorted_headers, array_shape=shape, + hdu_index=hdu, relative_to=relative_to) + + ds = Dataset(array_container.array, ds_wcs, meta=inventory, headers=table_headers) + + ds._array_container = array_container + + tree = {'dataset': ds} + + return tree + + +def dataset_from_fits(path, asdf_filename, inventory=None, hdu=0, relative_to=None, **kwargs): + """ + Given a path containing FITS files write an asdf file in the same path. + + Parameters + ---------- + path : `pathlib.Path` or `str` + The path to read the FITS files (with a `.fits` file extension) from + and save the asdf file. + + asdf_filename : `str` + The filename to save the asdf with in the path. + + inventory : `dict`, optional + The dataset inventory for this collection of FITS. If `None` a random one will be generated. + + hdu : `int`, optional + The HDU to read from the FITS files. + + relative_to : `pathlib.Path` or `str`, optional + The base path to use in the asdf references. + + kwargs + Additional kwargs are passed to `asdf.AsdfFile.write_to`. + """ + path = pathlib.Path(path) + + files = path.glob("*fits") + + tree = asdf_tree_from_filenames(list(files), asdf_filename, inventory=inventory, + hdu=hdu, relative_to=relative_to) + + with resources.path("dkist.io", "level_1_dataset_schema.yaml") as schema_path: + with asdf.AsdfFile(tree, custom_schema=schema_path.as_posix()) as afile: + afile.write_to(path / asdf_filename, **kwargs) diff --git a/dkist/io/asdf/generator/helpers.py b/dkist/io/asdf/generator/helpers.py new file mode 100644 index 00000000..45a6eee7 --- /dev/null +++ b/dkist/io/asdf/generator/helpers.py @@ -0,0 +1,95 @@ +""" +Helper functions for parsing files and processing headers. +""" + +import numpy as np + +from astropy.io import fits +from astropy.table import Table + +__all__ = ['preprocess_headers', 'make_sorted_table', 'validate_headers', + 'table_from_headers', 'headers_from_filenames'] + + +def headers_from_filenames(filenames, hdu=0): + """ + Get the headers from a list of filenames, returning a `list` of header dictionaries.
+ """ + return [dict(fits.getheader(fname, ext=hdu)) for fname in filenames] + + +def table_from_headers(headers): + return Table(rows=headers, names=list(headers[0].keys())) + + +def validate_headers(table_headers): + """ + Given a table of headers, validate that they form a coherent set. + + Parameters + ---------- + + table_headers : `astropy.table.Table` + A table with one row per header. + + Returns + ------- + table_headers : `astropy.table.Table` + The input table, unmodified. + """ + t = table_headers + + # Do roughly the minimal amount of verification here. + + # For some keys all the values must be the same + same_keys = ['NAXIS', 'DNAXIS'] + naxis_same_keys = ['NAXISn', 'CTYPEn', 'CRVALn'] # 'CRPIXn' + dnaxis_same_keys = ['DNAXISn', 'DTYPEn', 'DPNAMEn', 'DWNAMEn'] + # Expand n in NAXIS keys + for nsk in naxis_same_keys: + for naxis in range(1, t['NAXIS'][0] + 1): + same_keys.append(nsk.replace('n', str(naxis))) + # Expand n in DNAXIS keys + for dsk in dnaxis_same_keys: + for dnaxis in range(1, t['DNAXIS'][0] + 1): + same_keys.append(dsk.replace('n', str(dnaxis))) + + validate_t = t[same_keys] + + for col in validate_t.columns.values(): + if not all(col == col[0]): + raise ValueError(f"The {col.name} values did not all match:\n {col}") + + return table_headers + + +def make_sorted_table(headers, filenames): + """ + Return an `astropy.table.Table` instance where the rows are sorted into dataset (DINDEX) order. + """ + theaders = table_from_headers(headers) + theaders['filenames'] = filenames + theaders['headers'] = headers + dataset_axes = headers[0]['DNAXIS'] + array_axes = headers[0]['DAAXES'] + keys = [f'DINDEX{k}' for k in range(dataset_axes, array_axes, -1)] + t = np.array(theaders[keys]) + return theaders[np.argsort(t, order=keys)] + + +def preprocess_headers(headers, filenames): + """ + Sort and validate the headers, returning the header table and the filenames and headers in dataset order. + """ + table_headers = make_sorted_table(headers, filenames) + + validate_headers(table_headers) + + # Extract the filenames and headers in dataset order. + sorted_filenames = np.array(table_headers['filenames']) + sorted_headers = np.array(table_headers['headers']) + + table_headers.remove_columns(["headers", "filenames"]) + + return table_headers, sorted_filenames, sorted_headers diff --git a/dkist/io/asdf/generator/simulated_data.py b/dkist/io/asdf/generator/simulated_data.py new file mode 100644 index 00000000..60d356f3 --- /dev/null +++ b/dkist/io/asdf/generator/simulated_data.py @@ -0,0 +1,99 @@ +""" +Functions and helpers relating to working with simulated data. +""" +import random +import string + +from astropy.time import Time +from sunpy.time import parse_time + +__all__ = ['generate_datset_inventory_from_headers'] + + +def _gen_type(gen_type, max_int=1e6, max_float=1e6, len_str=30): + if gen_type is bool: + return bool(random.randint(0, 1)) + elif gen_type is int: + return random.randint(0, max_int) + elif gen_type is float: + return random.random() * max_float + elif gen_type is list: + return [_gen_type(str)] + elif gen_type is Time: + return parse_time("now") + elif gen_type is str: + return ''.join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(len_str)) + else: + raise ValueError("Type {} is not supported".format(gen_type)) # pragma: no cover + + +def generate_datset_inventory_from_headers(headers, asdf_name): + """ + Generate a dummy dataset inventory from headers. + + .. note:: + This is just for test data; it should not be used on real data.
+ + Parameters + ---------- + + headers: `astropy.table.Table` + asdf_name: `str` + + """ + + schema = [ + ('asdf_object_key', str), + ('browse_movie_object_key', str), + ('browse_movie_url', str), + ('bucket', str), + ('contributing_experiment_ids', list), + ('contributing_proposal_ids', list), + ('dataset_id', str), + ('dataset_inventory_id', int), + ('dataset_size', int), + ('end_time', Time), + ('exposure_time', float), + ('filter_wavelengths', list), + ('frame_count', int), + ('has_all_stokes', bool), + ('instrument_name', str), + ('observables', list), + ('original_frame_count', int), + ('primary_experiment_id', str), + ('primary_proposal_id', str), + ('quality_average_fried_parameter', float), + ('quality_average_polarimetric_accuracy', float), + ('recipe_id', int), + ('recipe_instance_id', int), + ('recipe_run_id', int), + ('start_time', Time), + # ('stokes_parameters', str), + ('target_type', str), + ('wavelength_max', float), + ('wavelength_min', float) + ] + + header_mapping = { + 'start_time': 'DATE-BGN', + 'end_time': 'DATE-END', + 'filter_wavelengths': 'WAVELNGTH'} + + constants = { + 'frame_count': len(headers), + 'bucket': 'data', + 'asdf_object_key': str(asdf_name) + } + + output = {} + + for key, ktype in schema: + if key in header_mapping: + hdict = dict(zip(headers.colnames, headers[0])) + output[key] = ktype(hdict.get(header_mapping[key], _gen_type(ktype))) + else: + output[key] = _gen_type(ktype) + + output.update(constants) + return output diff --git a/dkist/asdf_maker/tests/__init__.py b/dkist/io/asdf/generator/tests/__init__.py similarity index 100% rename from dkist/asdf_maker/tests/__init__.py rename to dkist/io/asdf/generator/tests/__init__.py diff --git a/dkist/asdf_maker/tests/conftest.py b/dkist/io/asdf/generator/tests/conftest.py similarity index 74% rename from dkist/asdf_maker/tests/conftest.py rename to dkist/io/asdf/generator/tests/conftest.py index ccab93c2..eb7a1426 100644 --- a/dkist/asdf_maker/tests/conftest.py +++ b/dkist/io/asdf/generator/tests/conftest.py @@ -6,36 +6,32 @@ import pytest -from dkist.asdf_maker.generator import TransformBuilder, headers_from_filenames +from dkist.io.asdf.generator.generator import headers_from_filenames +from dkist.io.asdf.generator.transforms import TransformBuilder from dkist.data.test import rootdir DATA_DIR = os.path.join(rootdir, 'datasettestfiles') -def extract(name): +@pytest.fixture(scope="session", params=["vtf.zip", "visp.zip"]) +def header_directory(request): atmpdir = tempfile.mkdtemp() - with ZipFile(os.path.join(DATA_DIR, name)) as myzip: + with ZipFile(os.path.join(DATA_DIR, request.param)) as myzip: myzip.extractall(atmpdir) return atmpdir -@pytest.fixture(scope="session", params=["vtf.zip", "visp.zip"]) -def header_filenames(request): - tdir = extract(request.param) - files = glob.glob(os.path.join(tdir, '*')) +@pytest.fixture +def header_filenames(header_directory): + files = glob.glob(os.path.join(header_directory, '*')) files.sort() - yield files - shutil.rmtree(tdir) + return files -@pytest.fixture(params=["vtf.zip", "visp.zip"]) -def transform_builder(request): - tdir = extract(request.param) - files = glob.glob(os.path.join(tdir, '*')) - files.sort() - headers = headers_from_filenames(files) - yield TransformBuilder(headers) - shutil.rmtree(tdir) +@pytest.fixture +def transform_builder(header_filenames): + headers = headers_from_filenames(header_filenames) + return TransformBuilder(headers) def make_header_files(): diff --git a/dkist/io/asdf/generator/tests/test_generator.py 
b/dkist/io/asdf/generator/tests/test_generator.py new file mode 100644 index 00000000..c2a0cea0 --- /dev/null +++ b/dkist/io/asdf/generator/tests/test_generator.py @@ -0,0 +1,56 @@ +import pathlib + +import pytest + +import asdf +import gwcs +import gwcs.coordinate_frames as cf +from astropy.modeling import Model, models + +from dkist import Dataset +from dkist.io.asdf.generator.generator import (asdf_tree_from_filenames, dataset_from_fits, + references_from_filenames) +from dkist.io.asdf.generator.helpers import (headers_from_filenames, preprocess_headers, + table_from_headers, validate_headers) +from dkist.dataset import Dataset +from dkist.io import AstropyFITSLoader, DaskFITSArrayContainer + + +def test_array_container_shape(header_filenames): + + headers = headers_from_filenames(header_filenames, hdu=0) + table_headers, sorted_filenames, sorted_headers = preprocess_headers(headers, header_filenames) + # Get the array shape + shape = tuple((headers[0][f'DNAXIS{n}'] for n in range(headers[0]['DNAXIS'], + headers[0]['DAAXES'], -1))) + # References from filenames + array_container = references_from_filenames(sorted_filenames, sorted_headers, array_shape=shape, + hdu_index=0, relative_to=".") + + assert len(array_container.output_shape) == 5 + assert array_container.output_shape == array_container.array.shape + + +def test_asdf_tree(header_filenames): + tree = asdf_tree_from_filenames(header_filenames, "test_file.asdf") + assert isinstance(tree, dict) + + +def test_validator(header_filenames): + headers = headers_from_filenames(header_filenames) + headers[10]['NAXIS'] = 5 + with pytest.raises(ValueError) as excinfo: + validate_headers(table_from_headers(headers)) + assert "NAXIS" in str(excinfo) + + +def test_dataset_from_fits(header_directory): + dataset_from_fits(header_directory, "test_asdf.asdf") + + asdf_file = pathlib.Path(header_directory) / "test_asdf.asdf" + assert asdf_file.exists() + + with asdf.open(asdf_file) as adf: + assert isinstance(adf['dataset'], Dataset) + + asdf_file.unlink() diff --git a/dkist/io/asdf/generator/tests/test_helpers.py b/dkist/io/asdf/generator/tests/test_helpers.py new file mode 100644 index 00000000..7b464c41 --- /dev/null +++ b/dkist/io/asdf/generator/tests/test_helpers.py @@ -0,0 +1,34 @@ +import os +from pathlib import Path + +import numpy as np +import pytest + +import asdf +import astropy.units as u +from astropy.io import fits +from astropy.modeling import Model, models +from astropy.time import Time + +from dkist.io.asdf.generator.helpers import headers_from_filenames +from dkist.io.asdf.generator.generator import asdf_tree_from_filenames, references_from_filenames + + +def test_references_from_filesnames_shape_error(header_filenames): + headers = headers_from_filenames(header_filenames, hdu=0) + with pytest.raises(ValueError) as exc: + references_from_filenames(header_filenames, headers, [2, 3]) + + assert "incorrect number" in str(exc) + assert "2, 3" in str(exc) + assert str(len(header_filenames)) in str(exc) + + +def test_references_from_filenames(header_filenames): + headers = headers_from_filenames(header_filenames, hdu=0) + base = os.path.split(header_filenames[0])[0] + refs = references_from_filenames(header_filenames, np.array(headers, dtype=object), + (len(header_filenames),), relative_to=base) + + for ref in refs.fileuris: + assert base not in ref diff --git a/dkist/io/asdf/generator/tests/test_transforms.py b/dkist/io/asdf/generator/tests/test_transforms.py new file mode 100644 index 00000000..3663f3c4 --- /dev/null +++ 
b/dkist/io/asdf/generator/tests/test_transforms.py @@ -0,0 +1,123 @@ + +import pytest + +import gwcs.coordinate_frames as cf +from astropy.modeling import Model, models +from astropy.io import fits +from astropy.time import Time +import astropy.units as u + +from dkist.io.asdf.generator.helpers import headers_from_filenames +from dkist.io.asdf.generator.transforms import (linear_spectral_model, linear_time_model, + spatial_model_from_header, spectral_model_from_framewave, + time_model_from_date_obs) + + +@pytest.fixture +def wcs(transform_builder): + return transform_builder.gwcs + + +def test_reset(transform_builder): + transform_builder._i = 2 + transform_builder.reset() + assert transform_builder._i == 0 + + +def test_transform(transform_builder): + assert isinstance(transform_builder.transform, Model) + + +def test_frames(transform_builder): + frames = transform_builder.frames + assert all([isinstance(frame, cf.CoordinateFrame) for frame in frames]) + + +def test_input_name_ordering(wcs): + # Check the ordering of the input and output frames + allowed_pixel_names = (('spatial x', 'spatial y', 'wavelength position', 'scan number', + 'stokes'), ('wavelength', 'slit position', 'raster position', + 'scan number', 'stokes')) + assert wcs.input_frame.axes_names in allowed_pixel_names + + +def test_output_name_ordering(wcs): + allowed_world_names = (('latitude', 'longitude', 'wavelength', 'time', 'stokes'), + ('wavelength', 'latitude', 'longitude', 'time', 'stokes')) + assert wcs.output_frame.axes_names in allowed_world_names + + +def test_output_frames(wcs): + allowed_frame_orders = ((cf.CelestialFrame, cf.SpectralFrame, cf.TemporalFrame, cf.StokesFrame), + (cf.SpectralFrame, cf.CelestialFrame, cf.TemporalFrame, cf.StokesFrame)) + types = tuple((type(frame) for frame in wcs.output_frame.frames)) + assert types in allowed_frame_orders + + +def test_transform_models(wcs): + # Test that there is one lookup table and two linear models for both the + # wcses + sms = wcs.forward_transform._leaflist + smtypes = [type(m) for m in sms] + assert sum(mt is models.Linear1D for mt in smtypes) == 2 + assert sum(mt is models.Tabular1D for mt in smtypes) == 1 + + +def first_header(header_filenames): + return fits.getheader(header_filenames[0]) + + +def test_spatial_model(header_filenames): + spatial = spatial_model_from_header(first_header(header_filenames)) + assert isinstance(spatial, Model) + + +def test_spatial_model_fail(header_filenames): + header = first_header(header_filenames) + header['CTYPE2'] = 'WAVE' + with pytest.raises(ValueError): + spatial_model_from_header(header) + + +def test_linear_spectral(): + lin = linear_spectral_model(10*u.nm, 0*u.nm) + assert isinstance(lin, models.Linear1D) + assert u.allclose(lin.slope, 10*u.nm/u.pix) + assert u.allclose(lin.intercept, 0*u.nm) + + +def test_linear_time(): + lin = linear_time_model(10*u.s) + assert isinstance(lin, models.Linear1D) + assert u.allclose(lin.slope, 10*u.s/u.pix) + assert u.allclose(lin.intercept, 0*u.s) + + +def test_time_from_dateobs(header_filenames): + date_obs = [fits.getheader(f)['DATE-OBS'] for f in header_filenames] + time = time_model_from_date_obs(date_obs) + assert isinstance(time, models.Linear1D) + + +def test_time_from_dateobs_lookup(header_filenames): + date_obs = [fits.getheader(f)['DATE-OBS'] for f in header_filenames] + date_obs[5] = (Time(date_obs[5]) + 10*u.s).isot + time = time_model_from_date_obs(date_obs) + assert isinstance(time, models.Tabular1D) + + +def test_spectral_framewave(header_filenames): + head = 
first_header(header_filenames) + + # Skip the VISP headers + if "FRAMEWAV" not in head: + return + + nwave = head['DNAXIS3'] + framewave = [fits.getheader(h)['FRAMEWAV'] for h in header_filenames] + + m = spectral_model_from_framewave(framewave[:nwave]) + assert isinstance(m, models.Linear1D) + + m2 = spectral_model_from_framewave(framewave) + assert isinstance(m2, models.Tabular1D) diff --git a/dkist/io/asdf/generator/transforms.py b/dkist/io/asdf/generator/transforms.py new file mode 100644 index 00000000..6851d3f8 --- /dev/null +++ b/dkist/io/asdf/generator/transforms.py @@ -0,0 +1,375 @@ +""" +Functionality relating to creating gWCS frames and Astropy models from SPEC 214 headers. +""" +import numpy as np + +import astropy.units as u +import gwcs +import gwcs.coordinate_frames as cf +from astropy.modeling.models import (AffineTransformation2D, Linear1D, Multiply, + Pix2Sky_TAN, RotateNative2Celestial, Shift, Tabular1D) +from astropy.time import Time +from sunpy.coordinates import Helioprojective + +__all__ = ['TransformBuilder', 'spectral_model_from_framewave', + 'time_model_from_date_obs', 'generate_lookup_table', + 'linear_time_model', 'linear_spectral_model', + 'spatial_model_from_quantity', 'spatial_model_from_header'] + + +def spatial_model_from_quantity(crpix1, crpix2, cdelt1, cdelt2, pc, crval1, crval2, + projection='TAN'): + """ + Given quantity representations of a HPLx FITS WCS return a model for the + spatial transform. + + The ordering of ctype1 and ctype2 should be LON, LAT + """ + + # TODO: Find this from somewhere else or extend it or something + projections = {'TAN': Pix2Sky_TAN()} + + shiftu = Shift(-crpix1) & Shift(-crpix2) + scale = Multiply(cdelt1) & Multiply(cdelt2) + rotu = AffineTransformation2D(pc, translation=(0, 0)*u.arcsec) + tanu = projections[projection] + skyrotu = RotateNative2Celestial(crval1, crval2, 180*u.deg) + return shiftu | scale | rotu | tanu | skyrotu + + +def spatial_model_from_header(header): + """ + Given a FITS compliant header with CTYPEx,y as HPLN, HPLT return a + `~astropy.modeling.CompositeModel` for the transform. + + This function finds the HPLN and HPLT keys in the header and returns a + model in Lon, Lat order. + """ + latind = None + lonind = None + for k, v in header.items(): + if isinstance(v, str) and "HPLN" in v: + lonind = int(k[5:]) + if isinstance(v, str) and "HPLT" in v: + latind = int(k[5:]) + + if latind is None or lonind is None: + raise ValueError("Could not extract HPLN and HPLT from the header.") + + latproj = header[f'CTYPE{latind}'][5:] + lonproj = header[f'CTYPE{lonind}'][5:] + + if latproj != lonproj: + raise ValueError("The projection of the two spatial axes did not match.") # pragma: no cover + + cunit1, cunit2 = u.Unit(header[f'CUNIT{lonind}']), u.Unit(header[f'CUNIT{latind}']) + crpix1, crpix2 = header[f'CRPIX{lonind}'] * u.pix, header[f'CRPIX{latind}'] * u.pix + crval1, crval2 = (header[f'CRVAL{lonind}'] * cunit1, header[f'CRVAL{latind}'] * cunit2) + cdelt1, cdelt2 = (header[f'CDELT{lonind}'] * (cunit1 / u.pix), + header[f'CDELT{latind}'] * (cunit2 / u.pix)) + pc = np.matrix([[header[f'PC{lonind}_{lonind}'], header[f'PC{lonind}_{latind}']], + [header[f'PC{latind}_{lonind}'], header[f'PC{latind}_{latind}']]]) * cunit1 + + return spatial_model_from_quantity(crpix1, crpix2, cdelt1, cdelt2, pc, crval1, crval2, + projection=latproj) + + +@u.quantity_input +def linear_spectral_model(spectral_width: u.nm, reference_val: u.nm): + """ + A linear model in a spectral dimension. The reference pixel is always 0. 
+ """ + return Linear1D(slope=spectral_width/(1*u.pix), intercept=reference_val) + + +@u.quantity_input +def linear_time_model(cadence: u.s, reference_val: u.s = 0*u.s): + """ + A linear model in a temporal dimension. The reference pixel is always 0. + """ + if not reference_val: + reference_val = 0 * cadence.unit + return Linear1D(slope=cadence / (1 * u.pix), intercept=reference_val) + + +def generate_lookup_table(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): + if not isinstance(lookup_table, u.Quantity): + raise TypeError("lookup_table must be a Quantity.") + + # The integer location is at the centre of the pixel. + points = (np.arange(lookup_table.size) - 0) * points_unit + + kwargs = { + 'bounds_error': False, + 'fill_value': np.nan, + 'method': interpolation, + **kwargs + } + + return Tabular1D(points, lookup_table, **kwargs) + + +def time_model_from_date_obs(date_obs, date_bgn=None): + """ + Return a time model that best fits a list of dateobs's. + """ + if not date_bgn: + date_bgn = date_obs[0] + date_obs = Time(date_obs) + date_bgn = Time(date_bgn) + + deltas = date_bgn - date_obs + + # Work out if we have a uniform delta (i.e. a linear model) + ddelta = (deltas.to(u.s)[:-1] - deltas.to(u.s)[1:]) + + # If the length of the axis is one, then return a very simple model + if ddelta.size == 0: + return linear_time_model(cadence=0*u.s, reference_val=0*u.s) + elif u.allclose(ddelta[0], ddelta): + slope = ddelta[0] + intercept = 0 * u.s + return linear_time_model(cadence=slope, reference_val=intercept) + else: + print(f"Creating tabular temporal axis. ddeltas: {ddelta}") + return generate_lookup_table(deltas.to(u.s)) + + +def spectral_model_from_framewave(framewav): + """ + Construct a linear or lookup table model for wavelength based on the + framewav keys. + """ + framewav = u.Quantity(framewav, unit=u.nm) + wave_bgn = framewav[0] + + deltas = wave_bgn - framewav + ddeltas = (deltas[:-1] - deltas[1:]) + # If the length of the axis is one, then return a very simple model + if ddeltas.size == 0: + return linear_spectral_model(0*u.nm, wave_bgn) + if u.allclose(ddeltas[0], ddeltas): + slope = ddeltas[0] + return linear_spectral_model(slope, wave_bgn) + else: + print(f"creating tabular wavelength axis. ddeltas: {ddeltas}") + return generate_lookup_table(framewav) + + +class TransformBuilder: + """ + This class builds compound models and frames in order when given axes types. + """ + + def __init__(self, headers): + self.header = headers[0] + + # Reshape the headers to match the Dataset shape, so we can extract headers along various axes. + shape = tuple((self.header[f'DNAXIS{n}'] for n in range(self.header['DNAXIS'], + self.header['DAAXES'], -1))) + arr_headers = np.empty(shape, dtype=object) + for i in range(arr_headers.size): + arr_headers.flat[i] = headers[i] + + self.headers = arr_headers + self.reset() + self._build() + + @property + def pixel_frame(self): + """ + A `gwcs.coordinate_frames.CoordinateFrame` object describing the pixel frame. + """ + return cf.CoordinateFrame(naxes=self.header['DNAXIS'], + axes_type=self.axes_types, + axes_order=range(self.header['DNAXIS']), + unit=[u.pixel]*self.header['DNAXIS'], + axes_names=[self.header[f'DPNAME{n}'] for n in range(1, self.header['DNAXIS'] + 1)], + name='pixel') + + @property + def gwcs(self): + """ + A `gwcs.WCS` object representing these headers. 
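+ Combines the compound `transform` with `pixel_frame` as the input frame and a composite frame of `frames` as the output frame.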
+ """ + world_frame = cf.CompositeFrame(self.frames) + + return gwcs.WCS(forward_transform=self.transform, + input_frame=self.pixel_frame, + output_frame=world_frame) + + @property + def frames(self): + """ + The coordinate frames, in Python order. + """ + return self._frames + + @property + def transform(self): + """ + Return the compound model. + """ + tf = self._transforms[0] + + for i in range(1, len(self._transforms)): + tf = tf & self._transforms[i] + + return tf + + # Internal stuff + + def _build(self): + """ + Build the frames and transforms for each world axis type. + """ + type_map = {'STOKES': self.make_stokes, + 'TEMPORAL': self.make_temporal, + 'SPECTRAL': self.make_spectral, + 'SPATIAL': self.make_spatial} + + xx = 0 + while self._i < self.header['DNAXIS']: # self._i is zero-based, DNAXIS is a count + atype = self.axes_types[self._i] + type_map[atype]() + xx += 1 + if xx > 100: + raise ValueError("Infinite loop in header parsing") # pragma: no cover + + @property + def axes_types(self): + """ + The list of DTYPEn for the first header. + """ + return [self.header[f'DTYPE{n}'] for n in range(1, self.header['DNAXIS'] + 1)] + + def reset(self): + """ + Reset the builder. + """ + self._i = 0 + self._frames = [] + self._transforms = [] + + @property + def n(self): + """ + The FITS index of the current dimension. + """ + return self._n(self._i) + + def _n(self, i): + """ + Convert a Python index ``i`` to a FITS order index for keywords ``n``. + """ + return i + 1 + + @property + def slice_for_n(self): + i = self._i - self.header['DAAXES'] + naxes = self.header['DEAXES'] + ss = [0] * naxes + ss[i] = slice(None) + return ss[::-1] + + @property + def slice_headers(self): + """ + The headers that vary along the current dataset axis. + """ + return self.headers[self.slice_for_n] + + def get_units(self, *iargs): + """ + Get the DUNITn units for the given (Python-indexed) axes. + """ + units = [self.header.get(f'DUNIT{self._n(i)}', None) for i in iargs] + + return units + + def make_stokes(self): + """ + Add a Stokes axis to the builder. + """ + name = self.header[f'DWNAME{self.n}'] + self._frames.append(cf.StokesFrame(axes_order=(self._i,), name=name)) + self._transforms.append(generate_lookup_table([0, 1, 2, 3] * u.one, interpolation='nearest')) + self._i += 1 + + def make_temporal(self): + """ + Add a temporal axis to the builder. + """ + name = self.header[f'DWNAME{self.n}'] + self._frames.append(cf.TemporalFrame(axes_order=(self._i,), + name=name, + axes_names=(name,), + unit=self.get_units(self._i), + reference_frame=Time(self.header['DATE-BGN']))) + self._transforms.append(time_model_from_date_obs([e['DATE-OBS'] for e in self.slice_headers], + self.header['DATE-BGN'])) + + self._i += 1 + + def make_spatial(self): + """ + Add a helioprojective spatial pair to the builder. + + .. note:: + This increments the counter by two.
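+ The two spatial axes are coupled, so they share a single `~gwcs.coordinate_frames.CelestialFrame` and one 2D spatial model.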
+ + """ + i = self._i + name = self.header[f'DWNAME{self.n}'] + name = name.split(' ')[0] + axes_names = [(self.header[f'DWNAME{nn}'].rsplit(' ')[1]) for nn in (self.n, self._n(i+1))] + + obstime = Time(self.header['DATE-BGN']) + axes_types = ["lat" if "LT" in self.axes_types[i] else "lon", "lon" if "LN" in self.axes_types[i] else "lat"] + self._frames.append(cf.CelestialFrame(axes_order=(i, i+1), name=name, + reference_frame=Helioprojective(obstime=obstime), + axes_names=axes_names, + unit=self.get_units(self._i, self._i+1), + axis_physical_types=(f"custom:pos.helioprojective.{axes_types[0]}", + f"custom:pos.helioprojective.{axes_types[1]}"))) + + self._transforms.append(spatial_model_from_header(self.header)) + + self._i += 2 + + def make_spectral(self): + """ + Decide how to make a spectral axes. + """ + name = self.header[f'DWNAME{self.n}'] + self._frames.append(cf.SpectralFrame(axes_order=(self._i,), + axes_names=(name,), + unit=self.get_units(self._i), + name=name)) + + if "WAVE" in self.header.get(f'CTYPE{self.n}', ''): + transform = self.make_spectral_from_wcs() + elif "FRAMEWAV" in self.header.keys(): + transform = self.make_spectral_from_dataset() + else: + raise ValueError("Could not parse spectral WCS information from this header.") # pragma: no cover + + self._transforms.append(transform) + + self._i += 1 + + def make_spectral_from_dataset(self): + """ + Make a spectral axes from (VTF) dataset info. + """ + framewave = [h['FRAMEWAV'] for h in self.slice_headers[:self.header[f'DNAXIS{self.n}']]] + return spectral_model_from_framewave(framewave) + + def make_spectral_from_wcs(self): + """ + Add a spectral axes from the FITS-WCS keywords. + """ + return linear_spectral_model(self.header[f'CDELT{self.n}']*u.nm, + self.header[f'CRVAL{self.n}']*u.nm) diff --git a/docs/asdfmaker.rst b/docs/asdfmaker.rst deleted file mode 100644 index dabd7206..00000000 --- a/docs/asdfmaker.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _dkist-asdf-maker: - -asdf Maker Module -================= - -This module provides tools to make DKIST asdf datasets. It provides tools to -read DKIST FITS files and generate the gWCS and array structures that are then -saved in the asdf file. - - -.. automodapi:: dkist.asdf_maker diff --git a/docs/index.rst b/docs/index.rst index fbe397cf..48b6d81f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,5 @@ A Python library of tools for obtaining, processing and interacting with DKIST d self dataset io - asdfmaker generated/gallery/index whatsnew/index diff --git a/docs/io.rst b/docs/io.rst index 90ddece2..287cc371 100644 --- a/docs/io.rst +++ b/docs/io.rst @@ -24,3 +24,6 @@ API Reference .. automodapi:: dkist.io.array_containers :headings: #^ + +.. automodapi:: dkist.io.asdf.generator + :headings: #^ diff --git a/pyproject.toml b/pyproject.toml index 67fd6a4c..0035a675 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ length_sort_stdlib = true help_url = "https://github.com/DKISTDC/dkist/changelog/README.rst" [ tool.gilesbot.milestones ] - enabled = true + enabled = false [tool.towncrier] package = "dkist" diff --git a/tox.ini b/tox.ini index e1384a15..7ff22650 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{36,37,38}, build_docs +envlist = py{36,37,38},build_docs isolated_build = True [testenv] @@ -14,7 +14,6 @@ extras = tests commands = {env:PYTEST_COMMAND} {posargs} [testenv:build_docs] -basepython = python3.7 extras = docs deps = {[testenv]deps}
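Usage sketch for the relocated API, assuming a directory of calibrated FITS files; the directory path, the output filename and the `fits_filenames` list below are all illustrative:

    from dkist.io.asdf.generator import dataset_from_fits

    # Globs "*fits" in the directory, builds the array container and the
    # gWCS, and writes the asdf file alongside the FITS files.
    dataset_from_fits("/data/pid_1_123/", "my_dataset.asdf")

    # The lower-level pieces can also be driven directly, for example to
    # inspect the generated WCS without writing a file.
    from dkist.io.asdf.generator.helpers import headers_from_filenames, preprocess_headers
    from dkist.io.asdf.generator.transforms import TransformBuilder

    headers = headers_from_filenames(fits_filenames)
    table_headers, sorted_filenames, sorted_headers = preprocess_headers(headers, fits_filenames)
    wcs = TransformBuilder(sorted_headers).gwcs

`dataset_from_fits` is the same public entry point the deleted `dkist.asdf_maker` package exposed; the main structural change is that `gwcs_from_headers` and `build_pixel_frame` have been folded into `TransformBuilder`, which now exposes the complete WCS via its `gwcs` property.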