Skip to content

Commit

Permalink
Check for memberwise serialization and raise NotImplementedError as n…
Browse files Browse the repository at this point in the history
…eeded. (#87)

* Check for memberwise serialization and raise NotImplementedError as needed.

* Added test.
  • Loading branch information
jpivarski authored Aug 31, 2020
1 parent 7f9257b commit fe0add3
Show file tree
Hide file tree
Showing 21 changed files with 465 additions and 57 deletions.
32 changes: 32 additions & 0 deletions tests/test_0087-memberwise-splitting-not-implemented-messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE

from __future__ import absolute_import

import pytest
import skhep_testdata

import uproot4


def test_issue510b():
with pytest.raises(NotImplementedError):
with uproot4.open(skhep_testdata.data_path("uproot-issue510b.root"))[
"EDepSimEvents"
] as t:
t["Event"]["Trajectories.Points"].array()


def test_issue403():
with pytest.raises(NotImplementedError):
with uproot4.open(skhep_testdata.data_path("uproot-issue403.root"))[
"Model"
] as t:
t["Model.collimatorInfo"].array()


def test_issue475():
with pytest.raises(NotImplementedError):
with uproot4.open(skhep_testdata.data_path("uproot-issue475.root"))[
"Event/Elec/ElecEvent"
] as t:
t["fElecChannels"].array()
4 changes: 4 additions & 0 deletions uproot4/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@
############# IOFeatures

kGenerateOffsetMap = numpy.uint8(1)

############# other

kStreamedMemberWise = numpy.uint16(1 << 14)
149 changes: 99 additions & 50 deletions uproot4/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,11 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
def read(self, chunk, cursor, context, file, selffile, parent, header=True):
if self._header and header:
start_cursor = cursor.copy()
num_bytes, instance_version = uproot4.deserialization.numbytes_version(
chunk, cursor, context
)
(
num_bytes,
instance_version,
is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)

if self._length_bytes == "1-5":
out = cursor.string(chunk, context)
Expand Down Expand Up @@ -378,9 +380,19 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
def read(self, chunk, cursor, context, file, selffile, parent, header=True):
if self._header and header:
start_cursor = cursor.copy()
num_bytes, instance_version = uproot4.deserialization.numbytes_version(
chunk, cursor, context
)
(
num_bytes,
instance_version,
is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)

if is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, selffile.file_path
)
)

if isinstance(self._values, numpy.dtype):
remainder = chunk.get(
Expand Down Expand Up @@ -530,8 +542,20 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
def read(self, chunk, cursor, context, file, selffile, parent, header=True):
if self._header and header:
start_cursor = cursor.copy()
num_bytes, instance_version = uproot4.deserialization.numbytes_version(
chunk, cursor, context
(
num_bytes,
instance_version,
is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)
else:
is_memberwise = False

if is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, selffile.file_path
)
)

length = cursor.field(chunk, _stl_container_size, context)
Expand Down Expand Up @@ -675,8 +699,20 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
def read(self, chunk, cursor, context, file, selffile, parent, header=True):
if self._header and header:
start_cursor = cursor.copy()
num_bytes, instance_version = uproot4.deserialization.numbytes_version(
chunk, cursor, context
(
num_bytes,
instance_version,
is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)
else:
is_memberwise = False

if is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, selffile.file_path
)
)

length = cursor.field(chunk, _stl_container_size, context)
Expand Down Expand Up @@ -864,55 +900,68 @@ def awkward_form(self, file, index_format="i64", header=False, tobject_header=Tr
def read(self, chunk, cursor, context, file, selffile, parent, header=True):
if self._header and header:
start_cursor = cursor.copy()
num_bytes, instance_version = uproot4.deserialization.numbytes_version(
chunk, cursor, context
)
cursor.skip(6)

length = cursor.field(chunk, _stl_container_size, context)

if _has_nested_header(self._keys) and header:
(
num_bytes,
instance_version,
is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)
cursor.skip(6)
keys = _read_nested(
self._keys,
length,
chunk,
cursor,
context,
file,
selffile,
parent,
header=False,
)
else:
is_memberwise = False

if _has_nested_header(self._values) and header:
cursor.skip(6)
values = _read_nested(
self._values,
length,
chunk,
cursor,
context,
file,
selffile,
parent,
header=False,
)
if is_memberwise:
length = cursor.field(chunk, _stl_container_size, context)

out = STLMap(keys, values)
if _has_nested_header(self._keys) and header:
cursor.skip(6)
keys = _read_nested(
self._keys,
length,
chunk,
cursor,
context,
file,
selffile,
parent,
header=False,
)

if self._header and header:
uproot4.deserialization.numbytes_check(
if _has_nested_header(self._values) and header:
cursor.skip(6)
values = _read_nested(
self._values,
length,
chunk,
start_cursor,
cursor,
num_bytes,
self.typename,
context,
file.file_path,
file,
selffile,
parent,
header=False,
)

return out
out = STLMap(keys, values)

if self._header and header:
uproot4.deserialization.numbytes_check(
chunk,
start_cursor,
cursor,
num_bytes,
self.typename,
context,
file.file_path,
)

return out

else:
raise NotImplementedError(
"""non-memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, selffile.file_path
)
)

def __eq__(self, other):
if not isinstance(other, AsMap):
Expand Down
6 changes: 5 additions & 1 deletion uproot4/deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ def numbytes_version(chunk, cursor, context, move=True):
num_bytes = None
version = cursor.field(chunk, _numbytes_version_2, context, move=move)

return num_bytes, version
is_memberwise = version & uproot4.const.kStreamedMemberWise
if is_memberwise:
version = version & ~uproot4.const.kStreamedMemberWise

return num_bytes, version, is_memberwise


def numbytes_check(
Expand Down
9 changes: 8 additions & 1 deletion uproot4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def empty(cls):
self._bases = []
self._num_bytes = None
self._instance_version = None
self._is_memberwise = False
return self

@classmethod
Expand All @@ -103,6 +104,7 @@ def read(cls, chunk, cursor, context, file, selffile, parent, concrete=None):
self._bases = []
self._num_bytes = None
self._instance_version = None
self._is_memberwise = False

old_breadcrumbs = context.get("breadcrumbs", ())
context["breadcrumbs"] = old_breadcrumbs + (self,)
Expand Down Expand Up @@ -152,6 +154,7 @@ def read_numbytes_version(self, chunk, cursor, context):
(
self._num_bytes,
self._instance_version,
self._is_memberwise,
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)

def read_members(self, chunk, cursor, context, file):
Expand Down Expand Up @@ -235,6 +238,10 @@ def num_bytes(self):
def instance_version(self):
return self._instance_version

@property
def is_memberwise(self):
return self._is_memberwise

@property
def members(self):
return self._members
Expand Down Expand Up @@ -487,7 +494,7 @@ def read(cls, chunk, cursor, context, file, selffile, parent, concrete=None):
import uproot4.deserialization

start_cursor = cursor.copy()
num_bytes, version = uproot4.deserialization.numbytes_version(
(num_bytes, version, is_memberwise,) = uproot4.deserialization.numbytes_version(
chunk, cursor, context, move=False
)

Expand Down
8 changes: 8 additions & 0 deletions uproot4/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@

class Model_ROOT_3a3a_Experimental_3a3a_RNTuple(uproot4.model.Model):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)

cursor.skip(4)
(
self._members["fVersion"],
Expand Down
7 changes: 7 additions & 0 deletions uproot4/models/TArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def read_numbytes_version(self, chunk, cursor, context):
pass

def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
self._members["fN"] = cursor.field(chunk, _tarray_format1, context)
self._data = cursor.array(chunk, self._members["fN"], self.dtype, context)

Expand Down
35 changes: 35 additions & 0 deletions uproot4/models/TAtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@

class Model_TAttLine_v1(uproot4.model.VersionedModel):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
(
self._members["fLineColor"],
self._members["fLineStyle"],
Expand Down Expand Up @@ -69,6 +76,13 @@ def awkward_form(cls, file, index_format="i64", header=False, tobject_header=Tru

class Model_TAttLine_v2(uproot4.model.VersionedModel):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
(
self._members["fLineColor"],
self._members["fLineStyle"],
Expand Down Expand Up @@ -127,6 +141,13 @@ def awkward_form(cls, file, index_format="i64", header=False, tobject_header=Tru

class Model_TAttFill_v1(uproot4.model.VersionedModel):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
self._members["fFillColor"], self._members["fFillStyle"] = cursor.fields(
chunk, _tattfill1_format1, context
)
Expand Down Expand Up @@ -175,6 +196,13 @@ def awkward_form(cls, file, index_format="i64", header=False, tobject_header=Tru

class Model_TAttFill_v2(uproot4.model.VersionedModel):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
self._members["fFillColor"], self._members["fFillStyle"] = cursor.fields(
chunk, _tattfill2_format1, context
)
Expand Down Expand Up @@ -226,6 +254,13 @@ def awkward_form(cls, file, index_format="i64", header=False, tobject_header=Tru

class Model_TAttMarker_v2(uproot4.model.VersionedModel):
def read_members(self, chunk, cursor, context, file):
if self.is_memberwise:
raise NotImplementedError(
"""memberwise serialization of {0}
in file {1}""".format(
type(self).__name__, self.file.file_path
)
)
(
self._members["fMarkerColor"],
self._members["fMarkerStyle"],
Expand Down
Loading

0 comments on commit fe0add3

Please sign in to comment.