Skip to content

Commit

Permalink
fix all link identity methods
Browse files Browse the repository at this point in the history
  • Loading branch information
silentworks committed Jun 21, 2024
1 parent 6563daf commit 3570ace
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 19 deletions.
44 changes: 34 additions & 10 deletions supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
model_validate,
parse_auth_otp_response,
parse_auth_response,
parse_link_identity_response,
parse_sso_response,
parse_user_response,
)
Expand Down Expand Up @@ -70,6 +71,7 @@
SignUpWithPasswordCredentials,
Subscription,
UserAttributes,
UserIdentity,
UserResponse,
VerifyOtpParams,
)
Expand Down Expand Up @@ -366,10 +368,14 @@ async def sign_in_with_oauth(
params["redirect_to"] = redirect_to
if scopes:
params["scopes"] = scopes
url = await self._get_url_for_provider(provider, params)
url = await self._get_url_for_provider(
f"{self._url}/authorize", provider, params
)
return OAuthResponse(provider=provider, url=url)

async def link_identity(self, credentials):
async def link_identity(
self, credentials: SignInWithOAuthCredentials
) -> OAuthResponse:
provider = credentials.get("provider")
options = credentials.get("options", {})
redirect_to = options.get("redirect_to")
Expand All @@ -379,23 +385,40 @@ async def link_identity(self, credentials):
params["redirect_to"] = redirect_to
if scopes:
params["scopes"] = scopes
params["skip_browser_redirect"] = True
params["skip_http_redirect"] = "true"
url = await self._get_url_for_provider(
"user/identities/authorize", provider, params
)

url = await self._get_url_for_provider(provider, params)
return OAuthResponse(provider=provider, url=url)
session = await self.get_session()
if not session:
raise AuthSessionMissingError()

response = await self._request(
method="GET",
path=url,
jwt=session.access_token,
xform=parse_link_identity_response,
)
return OAuthResponse(provider=provider, url=response.url)

async def get_user_identities(self):
response = self.get_user()
response = await self.get_user()
return (
IdentitiesResponse(identities=response.user.identities)
if response.user
else AuthSessionMissingError()
)

async def unlink_identity(self, identity):
async def unlink_identity(self, identity: UserIdentity):
session = await self.get_session()
if not session:
raise AuthSessionMissingError()

return await self._request(
"POST",
f"/user/identities/{identity.id}",
"DELETE",
f"user/identities/{identity.identity_id}",
jwt=session.access_token,
)

async def sign_in_with_otp(
Expand Down Expand Up @@ -975,6 +998,7 @@ def _is_implicit_grant_flow(self, url: str) -> bool:

async def _get_url_for_provider(
self,
url: str,
provider: Provider,
params: Dict[str, str],
) -> str:
Expand All @@ -992,7 +1016,7 @@ async def _get_url_for_provider(

params["provider"] = provider
query = urlencode(params)
return f"{self._url}/authorize?{query}"
return f"{url}?{query}"

def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
"""
Expand Down
36 changes: 27 additions & 9 deletions supabase_auth/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
model_validate,
parse_auth_otp_response,
parse_auth_response,
parse_link_identity_response,
parse_sso_response,
parse_user_response,
)
Expand Down Expand Up @@ -70,6 +71,7 @@
SignUpWithPasswordCredentials,
Subscription,
UserAttributes,
UserIdentity,
UserResponse,
VerifyOtpParams,
)
Expand Down Expand Up @@ -366,10 +368,10 @@ def sign_in_with_oauth(
params["redirect_to"] = redirect_to
if scopes:
params["scopes"] = scopes
url = self._get_url_for_provider(provider, params)
url = self._get_url_for_provider(f"{self._url}/authorize", provider, params)
return OAuthResponse(provider=provider, url=url)

def link_identity(self, credentials):
def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
provider = credentials.get("provider")
options = credentials.get("options", {})
redirect_to = options.get("redirect_to")
Expand All @@ -379,10 +381,20 @@ def link_identity(self, credentials):
params["redirect_to"] = redirect_to
if scopes:
params["scopes"] = scopes
params["skip_browser_redirect"] = True
params["skip_http_redirect"] = "true"
url = self._get_url_for_provider("user/identities/authorize", provider, params)

url = self._get_url_for_provider(provider, params)
return OAuthResponse(provider=provider, url=url)
session = self.get_session()
if not session:
raise AuthSessionMissingError()

response = self._request(
method="GET",
path=url,
jwt=session.access_token,
xform=parse_link_identity_response,
)
return OAuthResponse(provider=provider, url=response.url)

def get_user_identities(self):
response = self.get_user()
Expand All @@ -392,10 +404,15 @@ def get_user_identities(self):
else AuthSessionMissingError()
)

def unlink_identity(self, identity):
def unlink_identity(self, identity: UserIdentity):
session = self.get_session()
if not session:
raise AuthSessionMissingError()

return self._request(
"POST",
f"/user/identities/{identity.id}",
"DELETE",
f"user/identities/{identity.identity_id}",
jwt=session.access_token,
)

def sign_in_with_otp(
Expand Down Expand Up @@ -973,6 +990,7 @@ def _is_implicit_grant_flow(self, url: str) -> bool:

def _get_url_for_provider(
self,
url: str,
provider: Provider,
params: Dict[str, str],
) -> str:
Expand All @@ -988,7 +1006,7 @@ def _get_url_for_provider(

params["provider"] = provider
query = urlencode(params)
return f"{self._url}/authorize?{query}"
return f"{url}?{query}"

def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
"""
Expand Down
5 changes: 5 additions & 0 deletions supabase_auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AuthResponse,
GenerateLinkProperties,
GenerateLinkResponse,
LinkIdentityResponse,
Session,
SSOResponse,
User,
Expand Down Expand Up @@ -77,6 +78,10 @@ def parse_auth_otp_response(data: Any) -> AuthOtpResponse:
return model_validate(AuthOtpResponse, data)


def parse_link_identity_response(data: Any) -> LinkIdentityResponse:
return model_validate(LinkIdentityResponse, data)


def parse_link_response(data: Any) -> GenerateLinkResponse:
properties = GenerateLinkProperties(
action_link=data.get("action_link"),
Expand Down
5 changes: 5 additions & 0 deletions supabase_auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class SSOResponse(BaseModel):
url: str


class LinkIdentityResponse(BaseModel):
url: str


class IdentitiesResponse(BaseModel):
identities: List[UserIdentity]

Expand Down Expand Up @@ -151,6 +155,7 @@ def validator(cls, values: dict) -> dict:

class UserIdentity(BaseModel):
id: str
identity_id: str
user_id: str
identity_data: Dict[str, Any]
provider: str
Expand Down

0 comments on commit 3570ace

Please sign in to comment.