Skip to content

Commit

Permalink
Update behaviour of subject-id requirements entity attribute
Browse files Browse the repository at this point in the history
When the subject-id requiment is "any", both the subject-id and pairwise-id should be processsed.

Signed-off-by: Ivan Kanakarakis <[email protected]>
  • Loading branch information
c00kiemon5ter committed Feb 14, 2023
1 parent 936ce58 commit a9fe345
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 33 deletions.
11 changes: 6 additions & 5 deletions src/saml2/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,11 +556,12 @@ def restrict(self, ava, sp_entity_id, metadata=None):

metadata_store = metadata or self.metadata_store
spec = metadata_store.attribute_requirement(sp_entity_id) or {} if metadata_store else {}
required_attributes = spec.get("required", [])
optional_attributes = spec.get("optional", [])
required_subject_id = metadata_store.subject_id_requirement(sp_entity_id) if metadata_store else None
if required_subject_id and required_subject_id not in required_attributes:
required_attributes.append(required_subject_id)
required_attributes = spec.get("required") or []
optional_attributes = spec.get("optional") or []
requirements_subject_id = metadata_store.subject_id_requirement(sp_entity_id) if metadata_store else []
for r in requirements_subject_id:
if r not in required_attributes:
required_attributes.extend(r)
return self.filter(
ava,
sp_entity_id,
Expand Down
70 changes: 47 additions & 23 deletions src/saml2/mdstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,15 @@ def all_locations(srvs):
return values


def attribute_requirement(entity, index=None):
def attribute_requirement(entity_descriptor, index=None):
res = {"required": [], "optional": []}
for acs in entity["attribute_consuming_service"]:
acss = entity_descriptor.get("attribute_consuming_service") or []
for acs in acss:
if index is not None and acs["index"] != index:
continue

for attr in acs["requested_attribute"]:
if "is_required" in attr and attr["is_required"] == "true":
if attr.get("is_required") == "true":
res["required"].append(attr)
else:
res["optional"].append(attr)
Expand Down Expand Up @@ -676,24 +677,26 @@ def service(self, entity_id, typ, service, binding=None):
return res

def attribute_requirement(self, entity_id, index=None):
"""Returns what attributes the SP requires and which are optional
"""
Returns what attributes the SP requires and which are optional
if any such demands are registered in the Metadata.
In case the metadata have multiple SPSSODescriptor elements,
the sum of the required and optional attributes is returned.
:param entity_id: The entity id of the SP
:param index: which of the attribute consumer services its all about
if index=None then return all attributes expected by all
attribute_consuming_services.
:return: 2-tuple, list of required and list of optional attributes
:return: dict of required and optional list of attributes
"""
res = {"required": [], "optional": []}

try:
for sp in self[entity_id]["spsso_descriptor"]:
_res = attribute_requirement(sp, index)
res["required"].extend(_res["required"])
res["optional"].extend(_res["optional"])
except KeyError:
return None
sp_descriptors = self[entity_id].get("spsso_descriptor") or []
for sp_desc in sp_descriptors:
_res = attribute_requirement(sp_desc, index)
res["required"].extend(_res.get("required") or [])
res["optional"].extend(_res.get("optional") or [])

return res

Expand Down Expand Up @@ -1297,35 +1300,56 @@ def discovery_response(self, entity_id, binding=None, _="spsso"):
)

def attribute_requirement(self, entity_id, index=None):
for _md in self.metadata.values():
if entity_id in _md:
return _md.attribute_requirement(entity_id, index)
for md_source in self.metadata.values():
if entity_id in md_source:
return md_source.attribute_requirement(entity_id, index)

def subject_id_requirement(self, entity_id):
try:
entity_attributes = self.entity_attributes(entity_id)
except KeyError:
return None
return []

if "urn:oasis:names:tc:SAML:profiles:subject-id:req" in entity_attributes:
subject_id_req = entity_attributes["urn:oasis:names:tc:SAML:profiles:subject-id:req"][0]
if subject_id_req == "any" or subject_id_req == "pairwise-id":
return {
subject_id_reqs = entity_attributes.get("urn:oasis:names:tc:SAML:profiles:subject-id:req") or []
subject_id_req = next(iter(subject_id_reqs), None)
if subject_id_req == "any":
return [
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:pairwise-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "pairwise-id",
"is_required": "true",
},
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:subject-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "subject-id",
"is_required": "true",
}
]
elif subject_id_req == "pairwise-id":
return [
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:pairwise-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "pairwise-id",
"is_required": "true",
}
elif subject_id_req == "subject-id":
return {
]
elif subject_id_req == "subject-id":
return [
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:subject-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "subject-id",
"is_required": "true",
}
return None
]
return []

def keys(self):
res = []
Expand Down
23 changes: 18 additions & 5 deletions tests/test_30_mdstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,11 +664,24 @@ def test_subject_id_requirement():
mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True)
mds.imp(METADATACONF["17"])
required_subject_id = mds.subject_id_requirement(entity_id="https://esi-coco.example.edu/saml2/metadata/")
assert required_subject_id["__class__"] == "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute"
assert required_subject_id["name"] == "urn:oasis:names:tc:SAML:attribute:pairwise-id"
assert required_subject_id["name_format"] == "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
assert required_subject_id["friendly_name"] == "pairwise-id"
assert required_subject_id["is_required"] == "true"
expected = [
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:pairwise-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "pairwise-id",
"is_required": "true",
},
{
"__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute",
"name": "urn:oasis:names:tc:SAML:attribute:subject-id",
"name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
"friendly_name": "subject-id",
"is_required": "true",
},
]
assert required_subject_id
assert all(e in expected for e in required_subject_id)


def test_extension():
Expand Down

0 comments on commit a9fe345

Please sign in to comment.