Skip to content

Commit

Permalink
Merge pull request #24 from Cadair/plotting_1
Browse files Browse the repository at this point in the history
Bug fixes and cleanup on gwcs slicing and dataset creation
  • Loading branch information
Cadair authored Nov 21, 2018
2 parents 79a1a61 + 8f02ada commit e65a056
Show file tree
Hide file tree
Showing 18 changed files with 465 additions and 523 deletions.
1 change: 1 addition & 0 deletions changelog/24.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a lot of bugs in dataset generation and wcs slicing.
54 changes: 24 additions & 30 deletions dkist/asdf_maker/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,14 @@ def headers_from_filenames(filenames, hdu=0):
"""
Return a list of the headers read from the given filenames.
"""
return [fits.getheader(fname, ext=hdu) for fname in filenames]
return [dict(fits.getheader(fname, ext=hdu)) for fname in filenames]


def table_from_headers(headers):
h0 = headers[0]
return Table(rows=headers, names=list(headers[0].keys()))

t = Table(list(zip(h0.values())), names=list(h0.keys()))

for h in headers[1:]:
t.add_row(h)

return t


def validate_headers(headers):
def validate_headers(table_headers):
"""
Given a bunch of headers, validate that they form a coherent set. This
function also adds the headers to a list as they are read from the file.
Expand All @@ -55,7 +48,7 @@ def validate_headers(headers):
out_headers : `list`
A list of headers.
"""
t = table_from_headers(headers)
t = table_headers

"""
Let's do roughly the minimal amount of verification here.
Expand All @@ -80,7 +73,7 @@ def validate_headers(headers):
if not all(col == col[0]):
raise ValueError(f"The {col.name} values did not all match:\n {col}")

return headers
return table_headers


def build_pixel_frame(header):
Expand Down Expand Up @@ -201,7 +194,7 @@ def slice_for_n(self):
naxes = self.header['DEAXES']
ss = [0] * naxes
ss[i] = slice(None)
return ss
return ss[::-1]

@property
def slice_headers(self):
Expand Down Expand Up @@ -322,16 +315,18 @@ def gwcs_from_headers(headers):
output_frame=world_frame)


def sorter_DINDEX(headers):
def make_sorted_table(headers, filenames):
"""
A sorting function based on the values of DINDEX in the header.
Return an `astropy.table.Table` instance where the rows are correctly sorted.
"""
t = table_from_headers(headers)
theaders = table_from_headers(headers)
theaders['filenames'] = filenames
theaders['headers'] = headers
dataset_axes = headers[0]['DNAXIS']
array_axes = headers[0]['DAAXES']
keys = [f'DINDEX{k}' for k in range(dataset_axes, array_axes, -1)]
t = np.array(t[keys])
return np.argsort(t, order=keys)
t = np.array(theaders[keys])
return theaders[np.argsort(t, order=keys)]


def asdf_tree_from_filenames(filenames, hdu=0, relative_to=None):
Expand All @@ -350,34 +345,30 @@ def asdf_tree_from_filenames(filenames, hdu=0, relative_to=None):
# headers is a list with one header per file
headers = headers_from_filenames(filenames, hdu=hdu)

# headers is now a list
headers = validate_headers(headers)
table_headers = make_sorted_table(headers, filenames)

sort_inds = sorter_DINDEX(headers)

sort_heads = ((head, sort_inds[i]) for i, head in enumerate(headers))
heads = sorted(sort_heads, key=lambda h: h[1])
headers = [head[0] for head in heads]
validate_headers(table_headers)

# Sort the filenames into DS order.
sorted_filenames = np.array(filenames)[sort_inds]
sorted_filenames = np.array(table_headers['filenames'])
sorted_headers = np.array(table_headers['headers'])

# Get the array shape
shape = tuple((headers[0][f'DNAXIS{n}'] for n in range(headers[0]['DNAXIS'],
headers[0]['DAAXES'], -1)))
# References from filenames
reference_array = references_from_filenames(sorted_filenames, array_shape=shape,
reference_array = references_from_filenames(sorted_filenames, sorted_headers, array_shape=shape,
hdu_index=hdu, relative_to=relative_to)

tree = {'dataset': reference_array,
'gwcs': gwcs_from_headers(headers)}
'gwcs': gwcs_from_headers(sorted_headers)}

# TODO: Write a schema for the tree.

return tree


def dataset_from_fits(path, asdf_filename, hdu=0, relative_to=None):
def dataset_from_fits(path, asdf_filename, hdu=0, relative_to=None, **kwargs):
"""
Given a path containing FITS files write an asdf file in the same path.
Expand All @@ -392,6 +383,9 @@ def dataset_from_fits(path, asdf_filename, hdu=0, relative_to=None):
hdu : `int`
The HDU to read from the FITS files.
kwargs
Additional kwargs are passed to `asdf.AsdfFile.write_to`.
"""
path = pathlib.Path(path)

Expand All @@ -400,4 +394,4 @@ def dataset_from_fits(path, asdf_filename, hdu=0, relative_to=None):
tree = asdf_tree_from_filenames(list(files), hdu=hdu, relative_to=relative_to)

with asdf.AsdfFile(tree) as afile:
afile.write_to(str(path/asdf_filename))
afile.write_to(str(path/asdf_filename), **kwargs)
29 changes: 15 additions & 14 deletions dkist/asdf_maker/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asdf
import astropy.units as u
from astropy.io import fits
from astropy.time import Time
from gwcs.lookup_table import LookupTable
from astropy.modeling.models import (Shift, Linear1D, Multiply, Pix2Sky_TAN,
Expand All @@ -16,7 +15,7 @@
'spatial_model_from_quantity', 'spatial_model_from_header', 'references_from_filenames']


def references_from_filenames(filenames, array_shape, hdu_index=0, relative_to=None):
def references_from_filenames(filenames, headers, array_shape, hdu_index=0, relative_to=None):
"""
Given an array of paths to FITS files create a set of nested lists of
`asdf.external_reference.ExternalArrayReference` objects with the same
Expand All @@ -28,6 +27,9 @@ def references_from_filenames(filenames, array_shape, hdu_index=0, relative_to=N
filenames : `numpy.ndarray`
An array of filenames, in numpy order for the output array (i.e. ``.flat``)
headers : `numpy.ndarray`
An array of headers, one per file, in the same order as ``filenames`` (i.e. ``.flat``)
array_shape : `tuple`
The desired output shape of the reference array. (i.e. the shape of the
data minus the HDU dimensions.)
Expand All @@ -45,20 +47,17 @@ def references_from_filenames(filenames, array_shape, hdu_index=0, relative_to=N
raise ValueError(f"An incorrect number of filenames ({filenames.size})"
f" supplied for array_shape ({array_shape})")

for i, filepath in enumerate(filenames.flat):
with fits.open(filepath) as hdul:
hdu = hdul[hdu_index]
dtype = BITPIX2DTYPE[hdu.header['BITPIX']]
# hdu.shape is already in Python order
shape = tuple(hdu.shape)
for i, (filepath, head) in enumerate(zip(filenames.flat, headers.flat)):
dtype = BITPIX2DTYPE[head['BITPIX']]
shape = tuple([int(head[f"NAXIS{a}"]) for a in range(head["NAXIS"], 0, -1)])

# Convert paths to relative paths
relative_path = filepath
if relative_to:
relative_path = os.path.relpath(filepath, relative_to)
# Convert paths to relative paths
relative_path = filepath
if relative_to:
relative_path = os.path.relpath(filepath, str(relative_to))

reference_array.flat[i] = ExternalArrayReference(
relative_path, hdu_index, dtype, shape)
reference_array.flat[i] = ExternalArrayReference(
relative_path, hdu_index, dtype, shape)

return reference_array.tolist()

Expand Down Expand Up @@ -160,6 +159,7 @@ def time_model_from_date_obs(date_obs, date_bgn=None):
intercept = 0 * u.s
return linear_time_model(cadence=slope, reference_val=intercept)
else:
print(f"creating tabular temporal axis. ddeltas: {ddelta}")
return LookupTable(deltas.to(u.s))


Expand All @@ -180,6 +180,7 @@ def spectral_model_from_framewave(framewav):
slope = ddeltas[0]
return linear_spectral_model(slope, wave_bgn)
else:
print(f"creating tabular wavelength axis. ddeltas: {ddeltas}")
return LookupTable(framewav)


Expand Down
12 changes: 7 additions & 5 deletions dkist/asdf_maker/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from dkist.dataset import Dataset
from dkist.asdf_maker.generator import (validate_headers, dataset_from_fits, gwcs_from_headers,
headers_from_filenames, asdf_tree_from_filenames)
headers_from_filenames, asdf_tree_from_filenames,
table_from_headers)


@pytest.fixture
Expand All @@ -36,8 +37,9 @@ def test_frames(transform_builder):

def test_input_name_ordering(wcs):
# Check the ordering of the input and output frames
allowed_pixel_names = (('spatial x', 'spatial y', 'wavelength position', 'scan number', 'stokes'),
('wavelength', 'slit position', 'raster position', 'scan number', 'stokes'))
allowed_pixel_names = (('spatial x', 'spatial y', 'wavelength position', 'scan number',
'stokes'), ('wavelength', 'slit position', 'raster position',
'scan number', 'stokes'))
assert wcs.input_frame.axes_names in allowed_pixel_names


Expand Down Expand Up @@ -72,12 +74,12 @@ def test_validator(header_filenames):
headers = headers_from_filenames(header_filenames)
headers[10]['NAXIS'] = 5
with pytest.raises(ValueError) as excinfo:
validate_headers(headers)
validate_headers(table_from_headers(headers))
assert "NAXIS" in str(excinfo)


def test_make_asdf(header_filenames, tmpdir):
path = pathlib.Path(header_filenames[0])
dataset_from_fits(path.parent, "test.asdf")
assert (path.parent/"test.asdf").exists()
assert (path.parent / "test.asdf").exists()
assert isinstance(Dataset.from_directory(str(path.parent)), Dataset)
10 changes: 7 additions & 3 deletions dkist/asdf_maker/tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
import numpy as np

import asdf
import astropy.units as u
Expand All @@ -13,21 +14,24 @@
from dkist.asdf_maker.helpers import (make_asdf, linear_time_model, linear_spectral_model,
time_model_from_date_obs, references_from_filenames,
spatial_model_from_header, spectral_model_from_framewave)
from dkist.asdf_maker.generator import asdf_tree_from_filenames
from dkist.asdf_maker.generator import asdf_tree_from_filenames, headers_from_filenames


def test_references_from_filesnames_shape_error(header_filenames):
headers = headers_from_filenames(header_filenames, hdu=0)
with pytest.raises(ValueError) as exc:
references_from_filenames(header_filenames, [2, 3])
references_from_filenames(header_filenames, headers, [2, 3])

assert "incorrect number" in str(exc)
assert "2, 3" in str(exc)
assert str(len(header_filenames)) in str(exc)


def test_references_from_filenames(header_filenames):
headers = headers_from_filenames(header_filenames, hdu=0)
base = os.path.split(header_filenames[0])[0]
refs = references_from_filenames(header_filenames, (len(header_filenames),), relative_to=base)
refs = references_from_filenames(header_filenames, np.array(headers, dtype=object),
(len(header_filenames),), relative_to=base)

for ref in refs:
assert base not in ref.fileuri
Expand Down
Loading

0 comments on commit e65a056

Please sign in to comment.