Skip to content

Commit

Permalink
Variety of issues raised by an ATLAS xAOD. (#124)
Browse files Browse the repository at this point in the history
* Fix version encoding/decoding.

* Allow streamer base version to be -1 (meaning None).

* Parse the 'long long' rules before the 'long' rules (because you want greedy matches).

* AsGrouped subbranches can have missing interpretations.

* Handle zero-leaf branches.

* Logic for handling AsGrouped.

* Implemented AsGrouped.

* Strip extraneous spaces from classnames so that the lookup matches.
  • Loading branch information
jpivarski authored Oct 7, 2020
1 parent 7e42e2f commit 57fafcf
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 107 deletions.
4 changes: 1 addition & 3 deletions tests/test_0017-multi-basket-multi-branch-fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,14 @@ def test_cache():
assert list(f.file.array_cache) == []
i4.array(uproot4.interpretation.numerical.AsDtype(">i4"), library="np")
assert list(f.file.array_cache) == [
"db4be408-93ad-11ea-9027-d201a8c0beef:/sample;1:i4(16):AsDtype(Bi4(),Li4()):0-30:np"
"db4be408-93ad-11ea-9027-d201a8c0beef:/sample;1:i4(16):i4:AsDtype(Bi4(),Li4()):0-30:np"
]

with pytest.raises(OSError):
i4.array(
uproot4.interpretation.numerical.AsDtype(">i4"), entry_start=3, library="np"
)

i4.array(uproot4.interpretation.numerical.AsDtype(">i4"), library="np")


def test_pandas():
pandas = pytest.importorskip("pandas")
Expand Down
6 changes: 0 additions & 6 deletions tests/test_0118-fix-name-fetch-again.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def test_numpy():
assert list(t.arrays("P3.Px", library="np").keys()) == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="np")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="np")
assert list(t.arrays("evt/P3/P3.Px", library="np").keys()) == ["evt/P3/P3.Px"]
assert list(t.arrays("/evt/P3/P3.Px", library="np").keys()) == ["/evt/P3/P3.Px"]
assert list(t["evt"].arrays("P3.Px", library="np").keys()) == ["P3.Px"]
Expand Down Expand Up @@ -104,8 +102,6 @@ def test_awkward():
assert t.arrays("P3.Px", library="ak").fields == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="ak")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="ak")
assert t.arrays("evt/P3/P3.Px", library="ak").fields == ["evt/P3/P3.Px"]
assert t.arrays("/evt/P3/P3.Px", library="ak").fields == ["/evt/P3/P3.Px"]
assert t["evt"].arrays("P3.Px", library="ak").fields == ["P3.Px"]
Expand Down Expand Up @@ -160,8 +156,6 @@ def test_pandas():
assert t.arrays("P3.Px", library="pd").columns.tolist() == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="pd")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="pd")
assert t.arrays("evt/P3/P3.Px", library="pd").columns.tolist() == [
"evt/P3/P3.Px"
]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_0123-atlas-issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE

from __future__ import absolute_import

import pytest
import skhep_testdata

import uproot4


def test_version():
    """Round-trip a class name through ``classname_encode``/``classname_decode``.

    The class name used here itself ends in ``_v2``.  Encoding it without a
    version must decode back to ``(name, None)``, and encoding it with
    version ``9`` must decode back to ``(name, 9)``.
    """
    classname = "xAOD::MissingETAuxAssociationMap_v2"

    # No explicit version: the decoded version must come back as None.
    unversioned = uproot4.classname_encode(classname)
    assert uproot4.classname_decode(unversioned) == (classname, None)

    # Explicit version: the same number must survive the round trip.
    versioned = uproot4.classname_encode(classname, 9)
    assert uproot4.classname_decode(versioned) == (classname, 9)
148 changes: 119 additions & 29 deletions uproot4/behaviors/TBranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,10 @@ def get_from_cache(branchname, interpretation):
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, how
)

if array_cache is not None:
checked = set()
for expression, context in expression_context:
Expand Down Expand Up @@ -1353,6 +1357,10 @@ def iterate(
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, how
)

output = language.compute_expressions(
self,
arrays,
Expand Down Expand Up @@ -1999,7 +2007,6 @@ def array(
interpretation = self.interpretation
else:
interpretation = _regularize_interpretation(interpretation)
branchid_interpretation = {self.cache_key: interpretation}

entry_start, entry_stop = _regularize_entries_start_stop(
self.num_entries, entry_start, entry_stop
Expand All @@ -2010,25 +2017,48 @@ def array(
array_cache = _regularize_array_cache(array_cache, self._file)
library = uproot4.interpretation.library._regularize_library(library)

cache_key = "{0}:{1}:{2}-{3}:{4}".format(
self.cache_key,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
def get_from_cache(branchname, interpretation):
if array_cache is not None:
cache_key = "{0}:{1}:{2}:{3}-{4}:{5}".format(
self.cache_key,
branchname,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
)
return array_cache.get(cache_key)
else:
return None

arrays = {}
expression_context = []
branchid_interpretation = {}
_regularize_branchname(
self,
self.name,
self,
interpretation,
get_from_cache,
arrays,
expression_context,
branchid_interpretation,
True,
False,
)
if array_cache is not None:
got = array_cache.get(cache_key)
if got is not None:
return got

ranges_or_baskets = []
for basket_num, range_or_basket in self.entries_to_ranges_or_baskets(
entry_start, entry_stop
):
ranges_or_baskets.append((self, basket_num, range_or_basket))
checked = set()
for expression, context in expression_context:
for branch in context["branches"]:
if branch.cache_key not in checked:
checked.add(branch.cache_key)
for (
basket_num,
range_or_basket,
) in branch.entries_to_ranges_or_baskets(entry_start, entry_stop):
ranges_or_baskets.append((branch, basket_num, range_or_basket))

arrays = {}
_ranges_or_baskets_to_arrays(
self,
ranges_or_baskets,
Expand All @@ -2041,7 +2071,19 @@ def array(
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, None
)

if array_cache is not None:
cache_key = "{0}:{1}:{2}:{3}-{4}:{5}".format(
self.cache_key,
self.name,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
)
array_cache[cache_key] = arrays[self.cache_key]

return arrays[self.cache_key]
Expand Down Expand Up @@ -2189,7 +2231,13 @@ def entry_offsets(self):
for basket in self.embedded_baskets:
out.append(out[-1] + basket.num_entries)

if out[-1] != self.num_entries and self.interpretation is not None:
if (
out[-1] != self.num_entries
and self.interpretation is not None
and not isinstance(
self.interpretation, uproot4.interpretation.grouped.AsGrouped
)
):
raise ValueError(
"""entries in normal baskets ({0}) plus embedded baskets ({1}) """
"""don't add up to expected number of entries ({2})
Expand Down Expand Up @@ -2922,6 +2970,29 @@ def _regularize_branchname(

is_jagged = isinstance(interpretation, uproot4.interpretation.jagged.AsJagged)

if isinstance(interpretation, uproot4.interpretation.grouped.AsGrouped):
branches = []
for subname, subinterp in interpretation.subbranches.items():
_regularize_branchname(
hasbranches,
subname,
branch[subname],
subinterp,
get_from_cache,
arrays,
expression_context,
branchid_interpretation,
False,
is_cut,
)
branches.extend(expression_context[-1][1]["branches"])

branches.append(branch)
arrays[branch.cache_key] = None

else:
branches = [branch]

if branch.cache_key in branchid_interpretation:
if (
branchid_interpretation[branch.cache_key].cache_key
Expand All @@ -2937,18 +3008,14 @@ def _regularize_branchname(
else:
branchid_interpretation[branch.cache_key] = interpretation

expression_context.append(
(
branchname,
{
"is_primary": is_primary,
"is_cut": is_cut,
"is_jagged": is_jagged,
"is_branch": True,
"branches": [branch],
},
)
)
c = {
"is_primary": is_primary,
"is_cut": is_cut,
"is_jagged": is_jagged,
"is_branch": True,
"branches": branches,
}
expression_context.append((branchname, c))


def _regularize_expression(
Expand Down Expand Up @@ -3309,6 +3376,29 @@ def basket_to_array(basket):
raise AssertionError(obj)


def _fix_asgrouped(arrays, expression_context, branchid_interpretation, library, how):
    """Replace the placeholder of every ``AsGrouped`` branch with a grouped array.

    Walks ``expression_context`` in order.  For each branch entry whose
    interpretation (looked up in ``branchid_interpretation`` by the cache key
    of the entry's last branch) is an ``AsGrouped``, the arrays of its
    sub-branches are collected and passed to ``library.group``; the result is
    stored in ``arrays`` under the group branch's cache key.

    Args:
        arrays (dict): Maps branch cache keys to arrays; modified in place.
            The slot of a group branch must hold ``None`` on entry.
        expression_context (list of (str, dict)): Ordered
            ``(expression, context)`` pairs; each context is read through its
            ``"is_branch"`` and ``"branches"`` keys.
        branchid_interpretation (dict): Maps branch cache keys to
            interpretations.
        library: Object providing ``group(subarrays, subcontext, how)``.
        how: Passed through unchanged to ``library.group``.

    Returns:
        None.
    """
    window_start = 0
    for position, (_, ctx) in enumerate(expression_context):
        if not ctx["is_branch"]:
            continue

        # The last branch listed in the context is the one the entry is about.
        group_branch = ctx["branches"][-1]
        interp = branchid_interpretation[group_branch.cache_key]
        if not isinstance(interp, uproot4.interpretation.grouped.AsGrouped):
            continue

        # The group's own slot is expected to still be an unfilled placeholder.
        assert arrays[group_branch.cache_key] is None

        # Only entries seen since the previous group (exclusive of this one)
        # are candidates for this group's members.
        preceding = dict(expression_context[window_start:position])

        members = {}
        member_context = []
        for member_name in interp.subbranches:
            member_branch = group_branch[member_name]
            members[member_name] = arrays[member_branch.cache_key]
            member_context.append((member_name, preceding[member_name]))

        arrays[group_branch.cache_key] = library.group(members, member_context, how)

        # The next group's window starts at this group's entry.
        window_start = position


def _hasbranches_num_entries_for(
hasbranches, target_num_bytes, entry_start, entry_stop, branchid_interpretation
):
Expand Down
1 change: 1 addition & 0 deletions uproot4/deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def compile_class(file, classes, class_code, class_name):
new_scope[cls.__name__] = cls

def c(name, version=None):
name = uproot4.model.classname_regularize(name)
cls = new_scope.get(uproot4.model.classname_encode(name, version))
if cls is None:
cls = new_scope.get(uproot4.model.classname_encode(name))
Expand Down
15 changes: 10 additions & 5 deletions uproot4/interpretation/grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@


class AsGrouped(uproot4.interpretation.Interpretation):
"""
u"""
Args:
branch (:py:class:`~uproot4.behavior.TBranch.TBranch`): The ``TBranch`` that
represents the group.
subbranches (list of :py:class:`~uproot4.behavior.TBranch.TBranch`): The
``TBranches`` that contain the actual data.
subbranches (dict of str \u2192 :py:class:`~uproot4.behavior.TBranch.TBranch`): Names
and interpretations of the ``TBranches`` that actually contain data.
typename (None or str): If None, construct a plausible C++ typename.
Otherwise, take the suggestion as given.
Expand Down Expand Up @@ -69,6 +69,7 @@ def cache_key(self):
",".join(
"{0}:{1}".format(repr(x), y.cache_key)
for x, y in self._subbranches.items()
if y is not None
),
)

Expand All @@ -81,6 +82,7 @@ def typename(self):
", ".join(
"{0}:{1}".format(x, y.typename)
for x, y in self._subbranches.items()
if y is not None
)
)

Expand All @@ -90,8 +92,11 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
names = []
fields = []
for x, y in self._subbranches.items():
names.append(x)
fields.append(y.awkward_form(file, index_format, header, tobject_header))
if y is not None:
names.append(x)
fields.append(
y.awkward_form(file, index_format, header, tobject_header)
)

return awkward1.forms.RecordForm(fields, names)

Expand Down
Loading

0 comments on commit 57fafcf

Please sign in to comment.