diff --git a/aries_cloudagent/messaging/valid.py b/aries_cloudagent/messaging/valid.py index 6253c6b304..5af2a2d7be 100644 --- a/aries_cloudagent/messaging/valid.py +++ b/aries_cloudagent/messaging/valid.py @@ -23,60 +23,41 @@ class StrOrDictField(Field): """URI or Dict field for Marshmallow.""" - def _serialize(self, value, attr, obj, **kwargs): - return value - def _deserialize(self, value, attr, data, **kwargs): - if isinstance(value, (str, dict)): - return value - else: + if not isinstance(value, (str, dict)): raise ValidationError("Field should be str or dict") + return super()._deserialize(value, attr, data, **kwargs) class StrOrNumberField(Field): """String or Number field for Marshmallow.""" - def _serialize(self, value, attr, obj, **kwargs): - return value - def _deserialize(self, value, attr, data, **kwargs): - if isinstance(value, (str, float, int)): - return value - else: + if not isinstance(value, (str, float, int)): raise ValidationError("Field should be str or int or float") + return super()._deserialize(value, attr, data, **kwargs) class DictOrDictListField(Field): """Dict or Dict List field for Marshmallow.""" - def _serialize(self, value, attr, obj, **kwargs): - return value - def _deserialize(self, value, attr, data, **kwargs): - # dict - if isinstance(value, dict): - return value - # list of dicts - elif isinstance(value, list) and all(isinstance(item, dict) for item in value): - return value - else: - raise ValidationError("Field should be dict or list of dicts") + if not isinstance(value, dict): + if not isinstance(value, list) or not all( + isinstance(item, dict) for item in value + ): + raise ValidationError("Field should be dict or list of dicts") + return super()._deserialize(value, attr, data, **kwargs) class UriOrDictField(StrOrDictField): """URI or Dict field for Marshmallow.""" - def __init__(self, *args, **kwargs): - """Initialize new UriOrDictField instance.""" - super().__init__(*args, **kwargs) - - # Insert validation into self.validators so that multiple errors can be stored. - self.validators.insert(0, self._uri_validator) - - def _uri_validator(self, value): - # Check if URI when + def _deserialize(self, value, attr, data, **kwargs): if isinstance(value, str): - return Uri()(value) + # Check regex + Uri()(value) + return super()._deserialize(value, attr, data, **kwargs) class IntEpoch(Range): @@ -775,7 +756,7 @@ def __call__(self, value): except ValidationError: raise ValidationError( f"credential subject id {value[0]} must be URI" - ) + ) from None return value diff --git a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py index cdbac51cd7..96a64abed1 100644 --- a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py +++ b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py @@ -23,7 +23,7 @@ def _serialize(self, value, attr, obj, **kwargs): """ return value.serialize() - def _deserialize(self, value, attr, data, **kwargs): + def _deserialize(self, value, attr=None, data=None, **kwargs): """ Deserialize a value into a DIDDoc. diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py index e461201975..be3c0b80a8 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py @@ -96,13 +96,17 @@ def _serialize(self, value, attr, obj, **kwargs): def _deserialize(self, value, attr, data, **kwargs): if isinstance(value, dict): return Service.deserialize(value) + elif isinstance(value, Service): + return value elif isinstance(value, str): - if bool(DIDValidation.PATTERN.match(value)): - return value - else: + if not DIDValidation.PATTERN.match(value): raise ValidationError( "Service item must be a valid decentralized identifier (DID)" ) + return value + raise ValidationError( + "Service item must be a valid decentralized identifier (DID) or object" + ) class InvitationMessage(AgentMessage): @@ -221,9 +225,6 @@ class Meta: fields.Str( description="Handshake protocol", example=DIDCommPrefix.qualify_current(HSProto.RFC23.name), - validate=lambda hsp: ( - DIDCommPrefix.unqualify(hsp) in [p.name for p in HSProto] - ), ), required=False, ) @@ -276,13 +277,10 @@ def validate_fields(self, data, **kwargs): """ handshake_protocols = data.get("handshake_protocols") requests_attach = data.get("requests_attach") - if not ( - (handshake_protocols and len(handshake_protocols) > 0) - or (requests_attach and len(requests_attach) > 0) - ): + if not handshake_protocols and not requests_attach: raise ValidationError( "Model must include non-empty " - "handshake_protocols or requests_attach or both" + "handshake_protocols or requests~attach or both" ) # services = data.get("services") diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index 801b17680e..004c7d5ca5 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -139,9 +139,8 @@ def test_invalid_invi_wrong_type_services(self): "services": [123], } - invi_schema = InvitationMessageSchema() - with pytest.raises(test_module.ValidationError): - invi_schema.validate_fields(obj_x) + errs = InvitationMessageSchema().validate(obj_x) + assert errs and "services" in errs def test_assign_msg_type_version_to_model_inst(self): test_msg = InvitationMessage()