Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variety of issues raised by an ATLAS xAOD. #124

Merged
merged 8 commits into from
Oct 7, 2020
4 changes: 1 addition & 3 deletions tests/test_0017-multi-basket-multi-branch-fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,14 @@ def test_cache():
assert list(f.file.array_cache) == []
i4.array(uproot4.interpretation.numerical.AsDtype(">i4"), library="np")
assert list(f.file.array_cache) == [
"db4be408-93ad-11ea-9027-d201a8c0beef:/sample;1:i4(16):AsDtype(Bi4(),Li4()):0-30:np"
"db4be408-93ad-11ea-9027-d201a8c0beef:/sample;1:i4(16):i4:AsDtype(Bi4(),Li4()):0-30:np"
]

with pytest.raises(OSError):
i4.array(
uproot4.interpretation.numerical.AsDtype(">i4"), entry_start=3, library="np"
)

i4.array(uproot4.interpretation.numerical.AsDtype(">i4"), library="np")


def test_pandas():
pandas = pytest.importorskip("pandas")
Expand Down
6 changes: 0 additions & 6 deletions tests/test_0118-fix-name-fetch-again.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def test_numpy():
assert list(t.arrays("P3.Px", library="np").keys()) == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="np")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="np")
assert list(t.arrays("evt/P3/P3.Px", library="np").keys()) == ["evt/P3/P3.Px"]
assert list(t.arrays("/evt/P3/P3.Px", library="np").keys()) == ["/evt/P3/P3.Px"]
assert list(t["evt"].arrays("P3.Px", library="np").keys()) == ["P3.Px"]
Expand Down Expand Up @@ -104,8 +102,6 @@ def test_awkward():
assert t.arrays("P3.Px", library="ak").fields == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="ak")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="ak")
assert t.arrays("evt/P3/P3.Px", library="ak").fields == ["evt/P3/P3.Px"]
assert t.arrays("/evt/P3/P3.Px", library="ak").fields == ["/evt/P3/P3.Px"]
assert t["evt"].arrays("P3.Px", library="ak").fields == ["P3.Px"]
Expand Down Expand Up @@ -160,8 +156,6 @@ def test_pandas():
assert t.arrays("P3.Px", library="pd").columns.tolist() == ["P3.Px"]
with pytest.raises(Exception):
t.arrays("/P3.Px", library="pd")
with pytest.raises(Exception):
t.arrays("P3/P3.Px", library="pd")
assert t.arrays("evt/P3/P3.Px", library="pd").columns.tolist() == [
"evt/P3/P3.Px"
]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_0123-atlas-issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE

from __future__ import absolute_import

import pytest
import skhep_testdata

import uproot4


def test_version():
    """Check that a class name ending in ``_v2`` survives an encode/decode round trip.

    The name is encoded once without a version and once with version ``9``;
    decoding must give back the untouched name together with ``None`` or ``9``
    respectively.
    """
    name = "xAOD::MissingETAuxAssociationMap_v2"

    # No explicit version: the decoded version slot is expected to be None.
    unversioned = uproot4.classname_encode(name)
    assert uproot4.classname_decode(unversioned) == (
        "xAOD::MissingETAuxAssociationMap_v2",
        None,
    )

    # Explicit version 9: it must come back alongside the unchanged name.
    versioned = uproot4.classname_encode(name, 9)
    assert uproot4.classname_decode(versioned) == (
        "xAOD::MissingETAuxAssociationMap_v2",
        9,
    )
148 changes: 119 additions & 29 deletions uproot4/behaviors/TBranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,10 @@ def get_from_cache(branchname, interpretation):
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, how
)

if array_cache is not None:
checked = set()
for expression, context in expression_context:
Expand Down Expand Up @@ -1353,6 +1357,10 @@ def iterate(
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, how
)

output = language.compute_expressions(
self,
arrays,
Expand Down Expand Up @@ -1999,7 +2007,6 @@ def array(
interpretation = self.interpretation
else:
interpretation = _regularize_interpretation(interpretation)
branchid_interpretation = {self.cache_key: interpretation}

entry_start, entry_stop = _regularize_entries_start_stop(
self.num_entries, entry_start, entry_stop
Expand All @@ -2010,25 +2017,48 @@ def array(
array_cache = _regularize_array_cache(array_cache, self._file)
library = uproot4.interpretation.library._regularize_library(library)

cache_key = "{0}:{1}:{2}-{3}:{4}".format(
self.cache_key,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
def get_from_cache(branchname, interpretation):
if array_cache is not None:
cache_key = "{0}:{1}:{2}:{3}-{4}:{5}".format(
self.cache_key,
branchname,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
)
return array_cache.get(cache_key)
else:
return None

arrays = {}
expression_context = []
branchid_interpretation = {}
_regularize_branchname(
self,
self.name,
self,
interpretation,
get_from_cache,
arrays,
expression_context,
branchid_interpretation,
True,
False,
)
if array_cache is not None:
got = array_cache.get(cache_key)
if got is not None:
return got

ranges_or_baskets = []
for basket_num, range_or_basket in self.entries_to_ranges_or_baskets(
entry_start, entry_stop
):
ranges_or_baskets.append((self, basket_num, range_or_basket))
checked = set()
for expression, context in expression_context:
for branch in context["branches"]:
if branch.cache_key not in checked:
checked.add(branch.cache_key)
for (
basket_num,
range_or_basket,
) in branch.entries_to_ranges_or_baskets(entry_start, entry_stop):
ranges_or_baskets.append((branch, basket_num, range_or_basket))

arrays = {}
_ranges_or_baskets_to_arrays(
self,
ranges_or_baskets,
Expand All @@ -2041,7 +2071,19 @@ def array(
arrays,
)

_fix_asgrouped(
arrays, expression_context, branchid_interpretation, library, None
)

if array_cache is not None:
cache_key = "{0}:{1}:{2}:{3}-{4}:{5}".format(
self.cache_key,
self.name,
interpretation.cache_key,
entry_start,
entry_stop,
library.name,
)
array_cache[cache_key] = arrays[self.cache_key]

return arrays[self.cache_key]
Expand Down Expand Up @@ -2189,7 +2231,13 @@ def entry_offsets(self):
for basket in self.embedded_baskets:
out.append(out[-1] + basket.num_entries)

if out[-1] != self.num_entries and self.interpretation is not None:
if (
out[-1] != self.num_entries
and self.interpretation is not None
and not isinstance(
self.interpretation, uproot4.interpretation.grouped.AsGrouped
)
):
raise ValueError(
"""entries in normal baskets ({0}) plus embedded baskets ({1}) """
"""don't add up to expected number of entries ({2})
Expand Down Expand Up @@ -2922,6 +2970,29 @@ def _regularize_branchname(

is_jagged = isinstance(interpretation, uproot4.interpretation.jagged.AsJagged)

if isinstance(interpretation, uproot4.interpretation.grouped.AsGrouped):
branches = []
for subname, subinterp in interpretation.subbranches.items():
_regularize_branchname(
hasbranches,
subname,
branch[subname],
subinterp,
get_from_cache,
arrays,
expression_context,
branchid_interpretation,
False,
is_cut,
)
branches.extend(expression_context[-1][1]["branches"])

branches.append(branch)
arrays[branch.cache_key] = None

else:
branches = [branch]

if branch.cache_key in branchid_interpretation:
if (
branchid_interpretation[branch.cache_key].cache_key
Expand All @@ -2937,18 +3008,14 @@ def _regularize_branchname(
else:
branchid_interpretation[branch.cache_key] = interpretation

expression_context.append(
(
branchname,
{
"is_primary": is_primary,
"is_cut": is_cut,
"is_jagged": is_jagged,
"is_branch": True,
"branches": [branch],
},
)
)
c = {
"is_primary": is_primary,
"is_cut": is_cut,
"is_jagged": is_jagged,
"is_branch": True,
"branches": branches,
}
expression_context.append((branchname, c))


def _regularize_expression(
Expand Down Expand Up @@ -3309,6 +3376,29 @@ def basket_to_array(basket):
raise AssertionError(obj)


def _fix_asgrouped(arrays, expression_context, branchid_interpretation, library, how):
    """Fill in the placeholder array of every ``AsGrouped`` branch, in place.

    Walks ``expression_context`` in order. For each entry flagged
    ``is_branch`` whose last listed branch has an ``AsGrouped``
    interpretation, the arrays of its sub-branches are collected and combined
    with ``library.group``; the result replaces the ``None`` stored in
    ``arrays`` under the group branch's ``cache_key``.

    Args:
        arrays (dict): Maps branch ``cache_key`` to its array. The entry for
            a grouped branch must be ``None`` on entry and is overwritten.
        expression_context (list of (str, dict)): Pairs of expression name
            and context dict; the context is read for ``"is_branch"`` and
            ``"branches"``.
        branchid_interpretation (dict): Maps branch ``cache_key`` to the
            interpretation used for that branch.
        library: Object providing ``group(subarrays, subcontext, how)``.
        how: Passed through unchanged to ``library.group``.

    Returns:
        None. ``arrays`` is modified in place.

    Raises:
        AssertionError: If a grouped branch's entry in ``arrays`` is not
            ``None``.
        KeyError: If a sub-branch name is missing from the entries that
            precede the group entry, or its array is missing from ``arrays``.
    """
    window_begin = 0
    for position, (_, info) in enumerate(expression_context):
        if not info["is_branch"]:
            continue

        # The group branch itself is the last one listed in its context.
        group_branch = info["branches"][-1]
        group_interp = branchid_interpretation[group_branch.cache_key]
        if not isinstance(group_interp, uproot4.interpretation.grouped.AsGrouped):
            continue

        assert arrays[group_branch.cache_key] is None

        # Only entries since the previously handled group (and before this
        # one) are searched for the sub-branch contexts.
        recent = dict(expression_context[window_begin:position])

        members = {}
        member_context = []
        for member_name in group_interp.subbranches:
            member_branch = group_branch[member_name]
            members[member_name] = arrays[member_branch.cache_key]
            member_context.append((member_name, recent[member_name]))

        arrays[group_branch.cache_key] = library.group(members, member_context, how)

        # The next group's window starts at this group's own entry.
        window_begin = position


def _hasbranches_num_entries_for(
hasbranches, target_num_bytes, entry_start, entry_stop, branchid_interpretation
):
Expand Down
1 change: 1 addition & 0 deletions uproot4/deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def compile_class(file, classes, class_code, class_name):
new_scope[cls.__name__] = cls

def c(name, version=None):
name = uproot4.model.classname_regularize(name)
cls = new_scope.get(uproot4.model.classname_encode(name, version))
if cls is None:
cls = new_scope.get(uproot4.model.classname_encode(name))
Expand Down
15 changes: 10 additions & 5 deletions uproot4/interpretation/grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@


class AsGrouped(uproot4.interpretation.Interpretation):
"""
u"""
Args:
branch (:py:class:`~uproot4.behavior.TBranch.TBranch`): The ``TBranch`` that
represents the group.
subbranches (list of :py:class:`~uproot4.behavior.TBranch.TBranch`): The
``TBranches`` that contain the actual data.
subbranches (dict of str \u2192 :py:class:`~uproot4.behavior.TBranch.TBranch`): Names
and interpretations of the ``TBranches`` that actually contain data.
typename (None or str): If None, construct a plausible C++ typename.
Otherwise, take the suggestion as given.

Expand Down Expand Up @@ -69,6 +69,7 @@ def cache_key(self):
",".join(
"{0}:{1}".format(repr(x), y.cache_key)
for x, y in self._subbranches.items()
if y is not None
),
)

Expand All @@ -81,6 +82,7 @@ def typename(self):
", ".join(
"{0}:{1}".format(x, y.typename)
for x, y in self._subbranches.items()
if y is not None
)
)

Expand All @@ -90,8 +92,11 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
names = []
fields = []
for x, y in self._subbranches.items():
names.append(x)
fields.append(y.awkward_form(file, index_format, header, tobject_header))
if y is not None:
names.append(x)
fields.append(
y.awkward_form(file, index_format, header, tobject_header)
)

return awkward1.forms.RecordForm(fields, names)

Expand Down
Loading