Skip to content

Commit

Permalink
Merge pull request #534 from vmarkovtsev/patch-3
Browse files Browse the repository at this point in the history
Allow typing.NamedTuple to be serialized
  • Loading branch information
drdavella authored Aug 29, 2018
2 parents 317a4ee + 975503f commit c44d92a
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

- Store ASDF-in-FITS data inside a 1x1 BINTABLE HDU. [#519]

- Allow implicit conversion of ``namedtuple`` into serializable types. [#534]

2.0.3 (unreleased)
------------------

Expand Down
22 changes: 17 additions & 5 deletions asdf/asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ class AsdfFile(versioning.VersionedMixin):
The main class that represents an ASDF file object.
"""
def __init__(self, tree=None, uri=None, extensions=None, version=None,
ignore_version_mismatch=True, ignore_unrecognized_tag=False,
copy_arrays=False, custom_schema=None):
ignore_version_mismatch=True, ignore_unrecognized_tag=False,
ignore_implicit_conversion=False, copy_arrays=False,
custom_schema=None):
"""
Parameters
----------
Expand Down Expand Up @@ -81,6 +82,12 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
When `True`, do not raise warnings for unrecognized tags. Set to
`False` by default.
ignore_implicit_conversion : bool, optional
When `True`, do not raise warnings when types in the tree are
implicitly converted into a serializable object. The motivating
case for this is currently `namedtuple`, which cannot be serialized
as-is.
copy_arrays : bool, optional
When `False`, when reading files, attempt to memmap underlying data
arrays when possible.
Expand All @@ -90,6 +97,7 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
validation pass. This can be used to ensure that particular ASDF
files follow custom conventions beyond those enforced by the
standard.
"""

if custom_schema is not None:
Expand All @@ -104,6 +112,7 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
self._process_extensions(extensions)
self._ignore_version_mismatch = ignore_version_mismatch
self._ignore_unrecognized_tag = ignore_unrecognized_tag
self._ignore_implicit_conversion = ignore_implicit_conversion

self._file_format_version = None

Expand Down Expand Up @@ -1068,9 +1077,12 @@ def find_references(self):
Finds all external "JSON References" in the tree and converts
them to `reference.Reference` objects.
"""
# Set directly to self._tree, since it doesn't need to be
# re-validated.
self._tree = reference.find_references(self._tree, self)
# Since this is the first place that the tree is processed when
# creating a new ASDF object, this is where we pass the option to
# ignore warnings about implicit type conversions.
# Set directly to self._tree, since it doesn't need to be re-validated.
self._tree = reference.find_references(self._tree, self,
ignore_implicit_conversion=self._ignore_implicit_conversion)

def resolve_references(self, do_not_fill_defaults=False):
"""
Expand Down
5 changes: 3 additions & 2 deletions asdf/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def validate(self, data):
pass


def find_references(tree, ctx):
def find_references(tree, ctx, ignore_implicit_conversion=False):
"""
Find all of the JSON references in the tree, and convert them into
`Reference` objects.
Expand All @@ -138,7 +138,8 @@ def do_find(tree, json_id):
return Reference(tree['$ref'], json_id, asdffile=ctx)
return tree

return treeutil.walk_and_modify(tree, do_find)
return treeutil.walk_and_modify(
tree, do_find, ignore_implicit_conversion=ignore_implicit_conversion)


def resolve_references(tree, ctx, do_not_fill_defaults=False):
Expand Down
13 changes: 7 additions & 6 deletions asdf/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def recurse(old, new):


def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
raw_yaml_check_func=None, write_options={}, extensions=None,
raw_yaml_check_func=None, write_options={},
init_options={}, extensions=None,
tree_match_func='assert_equal'):
"""
Assert that a given tree saves to ASDF and, when loaded back,
Expand All @@ -171,7 +172,7 @@ def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,

# First, test writing/reading a BytesIO buffer
buff = io.BytesIO()
AsdfFile(tree, extensions=extensions).write_to(buff, **write_options)
AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
assert not buff.closed
buff.seek(0)
with AsdfFile.open(buff, mode='rw', extensions=extensions) as ff:
Expand All @@ -184,7 +185,7 @@ def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
asdf_check_func(ff)

buff.seek(0)
ff = AsdfFile(extensions=extensions)
ff = AsdfFile(extensions=extensions, **init_options)
content = AsdfFile._open_impl(ff, buff, _get_yaml_content=True)
buff.close()
# We *never* want to get any raw python objects out
Expand All @@ -195,7 +196,7 @@ def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
raw_yaml_check_func(content)

# Then, test writing/reading to a real file
ff = AsdfFile(tree, extensions=extensions)
ff = AsdfFile(tree, extensions=extensions, **init_options)
ff.write_to(fname, **write_options)
with AsdfFile.open(fname, mode='rw', extensions=extensions) as ff:
assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
Expand All @@ -205,7 +206,7 @@ def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
# Make sure everything works without a block index
write_options['include_block_index'] = False
buff = io.BytesIO()
AsdfFile(tree, extensions=extensions).write_to(buff, **write_options)
AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
assert not buff.closed
buff.seek(0)
with AsdfFile.open(buff, mode='rw', extensions=extensions) as ff:
Expand All @@ -219,7 +220,7 @@ def assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
if not INTERNET_OFF and not sys.platform.startswith('win'):
server = RangeHTTPServer()
try:
ff = AsdfFile(tree, extensions=extensions)
ff = AsdfFile(tree, extensions=extensions, **init_options)
ff.write_to(os.path.join(server.tmpdir, 'test.asdf'), **write_options)
with AsdfFile.open(server.url + 'test.asdf', mode='r',
extensions=extensions) as ff:
Expand Down
89 changes: 84 additions & 5 deletions asdf/tests/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# -*- coding: utf-8 -*-

import io
from collections import OrderedDict
from collections import namedtuple, OrderedDict
from typing import NamedTuple

import numpy as np

Expand Down Expand Up @@ -88,6 +89,21 @@ class Foo(object):
ff.write_to(buff)


def run_tuple_test(tree, tmpdir):
    """Round-trip ``tree`` and check that ``tree['val']`` is read back as a list.

    The raw YAML content is also checked to make sure it never mentions
    ``tuple``.
    """
    def _loaded_value_is_list(af):
        assert isinstance(af.tree['val'], list)

    def _yaml_has_no_tuple(raw):
        assert b'tuple' not in raw

    # The implicit-conversion warning has a dedicated test, so it is
    # silenced for every test that goes through this helper.
    helpers.assert_roundtrip_tree(
        tree,
        tmpdir,
        asdf_check_func=_loaded_value_is_list,
        raw_yaml_check_func=_yaml_has_no_tuple,
        init_options={'ignore_implicit_conversion': True},
    )


def test_python_tuple(tmpdir):
# We don't want to store tuples as tuples, because that's not a
# built-in YAML data type. This test ensures that they are
Expand All @@ -97,14 +113,77 @@ def test_python_tuple(tmpdir):
"val": (1, 2, 3)
}

run_tuple_test(tree, tmpdir)


def test_named_tuple_collections(tmpdir):
    """Ensure that a ``collections.namedtuple`` instance can be serialized."""
    record_type = namedtuple("TestNamedTuple1", ("one", "two", "three"))
    run_tuple_test({"val": record_type(1, 2, 3)}, tmpdir)

def test_named_tuple_typing(tmpdir):
    """Ensure that a ``typing.NamedTuple`` instance can be serialized."""
    fields = (("one", int), ("two", int), ("three", int))
    record_type = NamedTuple("TestNamedTuple2", fields)
    run_tuple_test({"val": record_type(1, 2, 3)}, tmpdir)


def test_named_tuple_collections_recursive(tmpdir):
    """Round-trip a ``collections.namedtuple`` that contains a numpy array.

    The value must come back as a list whose third element still equals
    the original array.
    """
    record_type = namedtuple("TestNamedTuple3", ("one", "two", "three"))
    tree = {"val": record_type(1, 2, np.ones(3))}

    def _verify(af):
        assert isinstance(af.tree['val'], list)
        assert (af.tree['val'][2] == np.ones(3)).all()

    # Silence the implicit-conversion warning; it is not under test here.
    helpers.assert_roundtrip_tree(
        tree,
        tmpdir,
        asdf_check_func=_verify,
        init_options={'ignore_implicit_conversion': True},
    )

def check_raw_yaml(content):
assert b'tuple' not in content

def test_named_tuple_typing_recursive(tmpdir):
    """Round-trip a ``typing.NamedTuple`` that contains a numpy array.

    Mirrors ``test_named_tuple_collections_recursive``: the value must come
    back as a plain list whose third element still equals the original
    array.
    """
    nt = NamedTuple("TestNamedTuple4",
                    (("one", int), ("two", int), ("three", np.ndarray)))

    tree = {
        "val": nt(1, 2, np.ones(3))
    }

    def check_asdf(asdf):
        # Same expectation as the collections.namedtuple variant: the
        # named tuple is implicitly converted to a list.
        assert isinstance(asdf.tree['val'], list)
        assert (asdf.tree['val'][2] == np.ones(3)).all()

    # Silence the implicit-conversion warning; it is not under test here.
    # No raw-YAML check is passed: no such callback is defined in this scope.
    init_options = dict(ignore_implicit_conversion=True)
    helpers.assert_roundtrip_tree(tree, tmpdir, asdf_check_func=check_asdf,
                                  init_options=init_options)


def test_implicit_conversion_warning():
    """Check that the implicit-conversion warning is emitted and can be silenced.

    Creating an ``AsdfFile`` from a tree holding a ``namedtuple`` must warn
    by default, and must emit no warning at all when
    ``ignore_implicit_conversion=True`` is passed.
    """
    # Local import so this block does not depend on the module's import list.
    import warnings

    nt = namedtuple("TestTupleWarning", ("one", "two", "three"))

    tree = {
        "val": nt(1, 2, np.ones(3))
    }

    with pytest.warns(UserWarning, match="Failed to serialize instance"):
        with asdf.AsdfFile(tree):
            pass

    # ``pytest.warns(None)`` is deprecated in pytest 7 and rejected by
    # pytest 8, so record warnings with the standard library instead.
    # "always" keeps earlier filters from hiding a warning we want to see.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        with asdf.AsdfFile(tree, ignore_implicit_conversion=True):
            pass

    # Assert only after the recording context has closed, so every warning
    # raised while the file was open (including on close) is counted.
    assert len(caught) == 0


def test_tags_removed_after_load(tmpdir):
Expand Down
32 changes: 27 additions & 5 deletions asdf/treeutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import inspect
import warnings

from .tagged import tag_object

Expand Down Expand Up @@ -80,7 +81,7 @@ def recurse(tree):
return recurse(top)


def walk_and_modify(top, callback):
def walk_and_modify(top, callback, ignore_implicit_conversion=False):
"""Modify a tree by walking it with a callback function. It also has
the effect of doing a deep copy.
Expand All @@ -105,6 +106,13 @@ def walk_and_modify(top, callback):
The callback is called on an instance after all of its
children have been visited (depth-first order).
ignore_implicit_conversion : bool, optional
Controls whether warnings should be issued when implicitly converting a
given type instance in the tree into a serializable object. The primary
case for this is currently `namedtuple`.
Defaults to `False`.
Returns
-------
tree : object
Expand Down Expand Up @@ -134,8 +142,13 @@ def recurse(tree):
result = tag_object(tree._tag, result)
elif isinstance(tree, (list, tuple)):
seen.add(id_tree)
result = tree.__class__(
[recurse(val) for val in tree])
contents = [recurse(val) for val in tree]
try:
result = tree.__class__(contents)
except TypeError:
# the derived class' signature is different
# erase the type
result = contents
seen.remove(id_tree)
if hasattr(tree, '_tag'):
result = tag_object(tree._tag, result)
Expand Down Expand Up @@ -166,8 +179,17 @@ def recurse_with_json_ids(tree, json_id):
result = tag_object(tree._tag, result)
elif isinstance(tree, (list, tuple)):
seen.add(id_tree)
result = tree.__class__(
[recurse_with_json_ids(val, json_id) for val in tree])
contents = [recurse_with_json_ids(val, json_id) for val in tree]
try:
result = tree.__class__(contents)
except TypeError:
# The derived class signature is different, so simply store the
# list representing the contents. Currently this is primarily
# intended to handle namedtuple and NamedTuple instances.
if not ignore_implicit_conversion:
msg = "Failed to serialize instance of {}, converting to list instead"
warnings.warn(msg.format(type(tree)))
result = contents
seen.remove(id_tree)
if hasattr(tree, '_tag'):
result = tag_object(tree._tag, result)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ open_files_ignore = test.fits asdf.fits
# Account for both the astropy test runner case and the native pytest case
asdf_schema_root = asdf-standard/schemas asdf/schemas
asdf_schema_skip_names = asdf-schema-1.0.0 draft-01
addopts = --doctest-rst
#addopts = --doctest-rst

[ah_bootstrap]
auto_use = True
Expand Down

0 comments on commit c44d92a

Please sign in to comment.