Skip to content

Commit

Permalink
Fix Required If (Not) Properties (#89)
Browse files Browse the repository at this point in the history
* add bug fix

add second bug fix

add test for serialization of dictionary with None values and required_if properties

* bug fix and tests for required if not

* remove code comments

* add explanatory comments

* add function comment

* increase clarity of function doc and add link to constraint definitions
  • Loading branch information
mfleader authored Apr 13, 2023
1 parent 5fee110 commit f97757c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 3 deletions.
32 changes: 29 additions & 3 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5004,13 +5004,33 @@ def _validate_property(

@staticmethod
def _validate_not_set(data, object_property: PropertyType, path: typing.Tuple[str]):
"""
Validate required_if and required_if_not constraints on a property in the given
data object. If a constraint has been broken, then raise a ConstraintException.
For a description of the required_if constraint visit
[https://arcalot.io/arcaflow/plugins/python/schema/?h=required_if#objecttype].
For a description of the required_if_not constraint visit
[https://arcalot.io/arcaflow/plugins/python/schema/?h=required_if_not#objecttype].
:param data: a dictionary representation of an ObjectType
:param object_property: a property of an ObjectType
:param path: a traversal from data to object_property
"""
if object_property.required:
raise ConstraintException(path, "This field is required")
if object_property.required_if is not None:
for required_if in object_property.required_if:
if (isinstance(data, dict) and required_if in data) or (
hasattr(data, required_if) and getattr(data, required_if) is None
if (isinstance(data, dict) and required_if in data and data[required_if] is not None) or (
hasattr(data, required_if) and getattr(data, required_if) is not None
):
# (here, required_if refers to its value)
# if data is a dict, has this required_if as a key, and the
# dict value paired with this required_if key is not None
# or
# if data is an object with attribute required_if, and
# data.required_if is not None
raise ConstraintException(
path,
"This field is required because '{}' is set".format(
Expand All @@ -5023,10 +5043,16 @@ def _validate_not_set(data, object_property: PropertyType, path: typing.Tuple[st
):
none_set = True
for required_if_not in object_property.required_if_not:
if (isinstance(data, dict) and required_if_not in data) or (
if (isinstance(data, dict) and required_if_not in data and data[required_if_not] is not None) or (
hasattr(data, required_if_not)
and getattr(data, required_if_not) is not None
):
# (here, required_if_not refers to its value)
# if data is a dict, has this required_if_not as a key, and the
# dict value paired with this required_if_not key is not None
# or
# if data is an object with attribute required_if_not, and
# data.required_if_not is not None
none_set = False
break
if none_set:
Expand Down
56 changes: 56 additions & 0 deletions src/arcaflow_plugin_sdk/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ class OneOfData2:
self.assertIsInstance(unserialized_data, OneOfData2)



class SerializationTest(unittest.TestCase):
def test_serialization_cycle(self):
@dataclasses.dataclass
Expand Down Expand Up @@ -753,13 +754,68 @@ class TestData1:
self.assertIsNone(unserialized.A)
self.assertIsNone(unserialized.B)

unserialized = s.unserialize({"A": None, "B": None})
self.assertIsNone(unserialized.A)
self.assertIsNone(unserialized.B)

unserialized = s.unserialize({"A": "Foo"})
self.assertEqual(unserialized.A, "Foo")
self.assertIsNone(unserialized.B)

with self.assertRaises(schema.ConstraintException):
s.unserialize({"B": "Foo"})

with self.assertRaises(schema.ConstraintException):
s.validate(TestData1(B="Foo"))

with self.assertRaises(schema.ConstraintException):
s.serialize(TestData1(B="Foo"))

def test_required_if_not(self):
@dataclasses.dataclass
class TestData1:
A: typing.Optional[str] = None
B: typing.Annotated[typing.Optional[str], schema.required_if_not("A")] = None

s = schema.build_object_schema(TestData1)

with self.assertRaises(schema.ConstraintException):
s.unserialize({})

with self.assertRaises(schema.ConstraintException):
s.unserialize({"A": None, "B": None})

unserialized = s.unserialize({"A": "Foo"})
self.assertEqual(unserialized.A, "Foo")
self.assertIsNone(unserialized.B)

unserialized = s.unserialize({"B": "Foo"})
self.assertEqual(unserialized.B, "Foo")
self.assertIsNone(unserialized.A)

s.validate(TestData1(B="Foo"))
s.serialize(TestData1(B="Foo"))

@dataclasses.dataclass
class TestData2:
A: typing.Optional[str] = None
B: typing.Optional[str] = None
C: typing.Annotated[typing.Optional[str], schema.required_if_not("A"), schema.required_if_not("B")] = None

s = schema.build_object_schema(TestData2)

with self.assertRaises(schema.ConstraintException):
s.unserialize({"A": None, "B": None, "C": None})

unserialized = s.unserialize({"C": "Foo"})
self.assertIsNone(unserialized.A)
self.assertIsNone(unserialized.B)
self.assertEqual(unserialized.C, "Foo")

td2_c = TestData2(C="Foo")
s.validate(td2_c)
s.serialize(td2_c)

def test_int_optional(self):
@dataclasses.dataclass
class TestData1:
Expand Down

0 comments on commit f97757c

Please sign in to comment.