From 8e3011210f6c3eabab6dc105f2b19a5c49c08eb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Wed, 11 Dec 2024 16:30:53 +0100 Subject: [PATCH] feat: implement Schema[attribute_name] to access schema attributes --- doc/changelog.rst | 4 ++++ scim2_models/rfc7643/schema.py | 12 ++++++++++++ tests/test_schema.py | 14 ++++++++++++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 1286e9b..f3ebfc4 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,6 +4,10 @@ Changelog [0.3.0] - Unreleased -------------------- +Added +^^^^^ +- :meth:`Attribute.get_attribute ` can be called with brackets. + Changed ^^^^^^^ - Add a :paramref:`~scim2_models.BaseModel.model_validate.original` diff --git a/scim2_models/rfc7643/schema.py b/scim2_models/rfc7643/schema.py index 2390f5e..8e49276 100644 --- a/scim2_models/rfc7643/schema.py +++ b/scim2_models/rfc7643/schema.py @@ -245,6 +245,12 @@ def get_attribute(self, attribute_name: str) -> Optional["Attribute"]: return sub_attribute return None + def __getitem__(self, name): + """Find an attribute by its name.""" + if attribute := self.get_attribute(name): + return attribute + raise KeyError(f"This attribute has no '{name}' sub-attribute") + class Schema(Resource): schemas: Annotated[list[str], Required.true] = [ @@ -280,3 +286,9 @@ def get_attribute(self, attribute_name: str) -> Optional[Attribute]: if attribute.name == attribute_name: return attribute return None + + def __getitem__(self, name): + """Find an attribute by its name.""" + if attribute := self.get_attribute(name): + return attribute + raise KeyError(f"This schema has no '{name}' attribute") diff --git a/tests/test_schema.py b/tests/test_schema.py index 8d2bba9..6721940 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -92,13 +92,18 @@ def test_get_schema_attribute(load_sample): payload = load_sample("rfc7643-8.7.1-schema-user.json") schema = Schema.model_validate(payload) assert schema.get_attribute("invalid") is None + with pytest.raises(KeyError): + schema["invalid"] assert schema.attributes[0].name == "userName" assert schema.attributes[0].mutability == Mutability.read_write - schema.get_attribute("userName").mutability = Mutability.read_only + schema.get_attribute("userName").mutability = Mutability.read_only assert schema.attributes[0].mutability == Mutability.read_only + schema["userName"].mutability = Mutability.read_write + assert schema.attributes[0].mutability == Mutability.read_write + def test_get_attribute_attribute(load_sample): """Test the Schema.get_attribute method.""" @@ -107,9 +112,14 @@ def test_get_attribute_attribute(load_sample): attribute = schema.get_attribute("members") assert attribute.get_attribute("invalid") is None + with pytest.raises(KeyError): + attribute["invalid"] assert attribute.sub_attributes[0].name == "value" assert attribute.sub_attributes[0].mutability == Mutability.immutable - attribute.get_attribute("value").mutability = Mutability.read_only + attribute.get_attribute("value").mutability = Mutability.read_only assert attribute.sub_attributes[0].mutability == Mutability.read_only + + attribute["value"].mutability = Mutability.read_write + assert attribute.sub_attributes[0].mutability == Mutability.read_write