Skip to content

Commit

Permalink
Merge pull request #416 from drdavella/asdf-fits-gzip
Browse files Browse the repository at this point in the history
Fix bug when reading gzipped ASDF-in-FITS file
  • Loading branch information
drdavella authored Dec 21, 2017
2 parents 5c69797 + a5570ab commit e8cd5a5
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
1.4.0 (unreleased)
------------------

- Improve the way URIs are detected for ASDF-in-FITS files in order to fix bug
with reading gzipped ASDF-in-FITS files. [#416]

- Explicitly disallow access to entire tree for ASDF file objects that have
been closed. [#407]

Expand Down
8 changes: 3 additions & 5 deletions asdf/fits_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,15 @@ def open(cls, fd, uri=None, validate_checksums=False, extensions=None,
if isinstance(fd, fits.hdu.hdulist.HDUList):
hdulist = fd
else:
file_obj = generic_io.get_file(fd, uri=uri)
uri = file_obj._uri if uri is None and file_obj._uri else ''
uri = generic_io.get_uri(fd)
try:
hdulist = fits.open(file_obj)
hdulist = fits.open(fd)
# Since we created this HDUList object, we need to be
# responsible for cleaning up upon close() or __exit__
close_hdulist = True
except IOError:
file_obj.close()
msg = "Failed to parse given file '{}'. Is it FITS?"
raise ValueError(msg.format(file_obj.uri))
raise ValueError(msg.format(uri))

self = cls(hdulist, uri=uri, extensions=extensions,
ignore_version_mismatch=ignore_version_mismatch,
Expand Down
18 changes: 17 additions & 1 deletion asdf/generic_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .extern import atomicfile


__all__ = ['get_file', 'resolve_uri', 'relative_uri']
__all__ = ['get_file', 'get_uri', 'resolve_uri', 'relative_uri']


_local_file_schemes = ['', 'file']
Expand Down Expand Up @@ -1083,6 +1083,22 @@ def _make_http_connection(init, mode, uri=None):
return HTTPConnection(connection, size, parsed.path, uri or init,
first_chunk)

def get_uri(file_obj):
"""
Returns the uri of the given file object
Parameters
----------
uri : object
"""
if isinstance(file_obj, six.string_types):
return file_obj
if isinstance(file_obj, GenericFile):
return file_obj.uri

# A catch-all for types from Python's io module that have names
return getattr(file_obj, 'name', '')


def get_file(init, mode='r', uri=None):
"""
Expand Down
Binary file added asdf/tests/data/asdf.fits.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion asdf/tests/setup_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
def get_package_data(): # pragma: no cover
return {
str(_PACKAGE_NAME_ + '.tests'):
['coveragerc', 'data/*.yaml', 'data/*.json', 'data/*.fits']}
['coveragerc', 'data/*.yaml', 'data/*.json', 'data/*.fits', 'data/*.fits.gz']}
18 changes: 16 additions & 2 deletions asdf/tests/test_fits_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ def test_asdf_open(tmpdir):
with asdf_open(hdulist) as ff:
compare_asdfs(asdf_in_fits, ff)

def test_open_gzipped():
testfile = os.path.join(TEST_DATA_PATH, 'asdf.fits.gz')

# Opening as an HDU should work
with fits.open(testfile) as ff:
with asdf.AsdfFile.open(ff) as af:
assert af.tree['stuff'].shape == (20, 20)

with fits_embed.AsdfInFits.open(testfile) as af:
assert af.tree['stuff'].shape == (20, 20)

with asdf.AsdfFile.open(testfile) as af:
assert af.tree['stuff'].shape == (20, 20)

def test_bad_input(tmpdir):
"""Make sure these functions behave properly with bad input"""
text_file = os.path.join(str(tmpdir), 'test.txt')
Expand All @@ -237,7 +251,7 @@ def test_version_mismatch_file():
assert len(w) == 1
assert str(w[0].message) == (
"'tag:stsci.edu:asdf/core/complex' with version 7.0.0 found in file "
"'file://{}', but latest supported version is 1.0.0".format(testfile))
"'{}', but latest supported version is 1.0.0".format(testfile))

# Make sure warning does not occur when warning is ignored (default)
with catch_warnings() as w:
Expand All @@ -252,7 +266,7 @@ def test_version_mismatch_file():
assert len(w) == 1
assert str(w[0].message) == (
"'tag:stsci.edu:asdf/core/complex' with version 7.0.0 found in file "
"'file://{}', but latest supported version is 1.0.0".format(testfile))
"'{}', but latest supported version is 1.0.0".format(testfile))

# Make sure warning does not occur when warning is ignored (default)
with catch_warnings() as w:
Expand Down

0 comments on commit e8cd5a5

Please sign in to comment.