From 6644dcd123af9ebc6d4b40dd003e0750534b8d8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20=C4=8Ciha=C5=99?= Date: Thu, 13 Feb 2025 19:49:46 +0100 Subject: [PATCH] fix: revert API changes from #986 This was a breaking change in API that social-app-django and other storages rely on. --- social_core/backends/discourse.py | 6 +++--- social_core/backends/open_id_connect.py | 6 +++--- social_core/storage.py | 11 +++++------ social_core/tests/models.py | 8 ++++---- social_core/tests/test_storage.py | 2 +- 5 files changed, 16 insertions(+), 17 deletions(-) diff --git a/social_core/backends/discourse.py b/social_core/backends/discourse.py index 8ebe99e29..398e1a743 100644 --- a/social_core/backends/discourse.py +++ b/social_core/backends/discourse.py @@ -52,8 +52,8 @@ def get_user_details(self, response): def add_nonce(self, nonce): self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce) - def get_nonce(self, nonce): - return self.strategy.storage.nonce.get_nonce(self.setting("SERVER_URL"), nonce) + def get(self, nonce): + return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce) def delete_nonce(self, nonce): self.strategy.storage.nonce.delete(nonce) @@ -79,7 +79,7 @@ def auth_complete(self, *args, **kwargs): # Validate the nonce to ensure the request was not modified response = parse_qs(decoded_params) - nonce_obj = self.get_nonce(response.get("nonce")) + nonce_obj = self.get(response.get("nonce")) if nonce_obj: self.delete_nonce(nonce_obj) else: diff --git a/social_core/backends/open_id_connect.py b/social_core/backends/open_id_connect.py index 3c9108a19..b7ee429a3 100644 --- a/social_core/backends/open_id_connect.py +++ b/social_core/backends/open_id_connect.py @@ -139,9 +139,9 @@ def get_and_store_nonce(self, url, state): self.strategy.storage.association.store(url, association) return nonce - def get_nonce(self, nonce): + def get(self, nonce): try: - return self.strategy.storage.association.get_association( + return self.strategy.storage.association.get( server_url=self.authorization_url(), handle=nonce )[0] except IndexError: @@ -166,7 +166,7 @@ def validate_claims(self, id_token): if not nonce: raise AuthTokenError(self, "Incorrect id_token: nonce") - nonce_obj = self.get_nonce(nonce) + nonce_obj = self.get(nonce) if nonce_obj: self.remove_nonce(nonce_obj.id) else: diff --git a/social_core/storage.py b/social_core/storage.py index 979acab35..57f583f6e 100644 --- a/social_core/storage.py +++ b/social_core/storage.py @@ -1,5 +1,7 @@ """Models mixins for Social Auth""" +from __future__ import annotations + import base64 import re import uuid @@ -198,7 +200,7 @@ def use(cls, server_url, timestamp, salt): raise NotImplementedError("Implement in subclass") @classmethod - def get_nonce(cls, server_url, salt): + def get(cls, server_url, salt): """Retrieve a Nonce instance""" raise NotImplementedError("Implement in subclass") @@ -224,10 +226,7 @@ def oids(cls, server_url, handle=None): if handle is not None: kwargs["handle"] = handle return sorted( - ( - (assoc.id, cls.openid_association(assoc)) - for assoc in cls.get_association(**kwargs) - ), + ((assoc.id, cls.openid_association(assoc)) for assoc in cls.get(**kwargs)), key=lambda x: x[1].issued, reverse=True, ) @@ -251,7 +250,7 @@ def store(cls, server_url, association): raise NotImplementedError("Implement in subclass") @classmethod - def get_association(cls, *args, **kwargs): + def get(cls, server_url: str | None = None, handle: str | None = None): """Get an Association instance""" raise NotImplementedError("Implement in subclass") diff --git a/social_core/tests/models.py b/social_core/tests/models.py index 2d05bc12b..66a7f596d 100644 --- a/social_core/tests/models.py +++ b/social_core/tests/models.py @@ -169,7 +169,7 @@ def use(cls, server_url, timestamp, salt): return nonce @classmethod - def get_nonce(cls, server_url, salt): + def get(cls, server_url, salt): return TestNonce.cache[server_url] @classmethod @@ -204,10 +204,10 @@ def store(cls, server_url, association): assoc.save() @classmethod - def get_association( + def get( # type: ignore[override] cls: type[TestAssociation], - server_url=None, - handle=None, + server_url: str | None = None, + handle: str | None = None, ) -> list[AssociationMixin]: result = [] for assoc in TestAssociation.cache.values(): diff --git a/social_core/tests/test_storage.py b/social_core/tests/test_storage.py index e4f4ece7d..3e328d041 100644 --- a/social_core/tests/test_storage.py +++ b/social_core/tests/test_storage.py @@ -90,7 +90,7 @@ def test_store(self): def test_get(self): with self.assertRaisesRegex(NotImplementedError, NOT_IMPLEMENTED_MSG): - self.association.get_association() + self.association.get() def test_remove(self): with self.assertRaisesRegex(NotImplementedError, NOT_IMPLEMENTED_MSG):