diff --git a/gotrue/__init__.py b/gotrue/__init__.py index 6db087d8..f478859f 100644 --- a/gotrue/__init__.py +++ b/gotrue/__init__.py @@ -2,13 +2,12 @@ __version__ = "0.5.4" -from ._async.api import AsyncGoTrueAPI -from ._async.client import AsyncGoTrueClient -from ._async.storage import AsyncMemoryStorage, AsyncSupportedStorage -from ._sync.api import SyncGoTrueAPI -from ._sync.client import SyncGoTrueClient -from ._sync.storage import SyncMemoryStorage, SyncSupportedStorage -from .types import * - -Client = SyncGoTrueClient -GoTrueAPI = SyncGoTrueAPI +from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI # type: ignore # noqa: F401 +from ._async.gotrue_client import AsyncGoTrueClient # type: ignore # noqa: F401 +from ._async.storage import AsyncMemoryStorage # type: ignore # noqa: F401 +from ._async.storage import AsyncSupportedStorage # type: ignore # noqa: F401 +from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI # type: ignore # noqa: F401 +from ._sync.gotrue_client import SyncGoTrueClient # type: ignore # noqa: F401 +from ._sync.storage import SyncMemoryStorage # type: ignore # noqa: F401 +from ._sync.storage import SyncSupportedStorage # type: ignore # noqa: F401 +from .types import * # type: ignore # noqa: F401, F403 diff --git a/gotrue/_async/api.py b/gotrue/_async/api.py deleted file mode 100644 index e0b1a7b9..00000000 --- a/gotrue/_async/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import AsyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class AsyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[AsyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or AsyncClient() - - async def __aenter__(self) -> AsyncGoTrueAPI: - return self - - async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: - await self.close() - - async def close(self) -> None: - await self.http_client.aclose() - - async def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = await self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - async def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. - Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers, "Authorization": f"Bearer {jwt}"} - return headers - - async def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - await self.http_client.post(url, headers=headers) - - async def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - error : APIError - If an error occurs - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - async def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = await self.http_client.get(url, headers=headers) - return User.parse_response(response) - - async def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = await self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - async def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = await self.http_client.delete(url, headers=headers) - return check_response(response) - - async def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - async def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py deleted file mode 100644 index 6073cc71..00000000 --- a/gotrue/_async/client.py +++ /dev/null @@ -1,648 +0,0 @@ -from __future__ import annotations - -from functools import partial -from json import dumps, loads -from threading import Timer -from time import time -from typing import Any, Callable, Dict, Optional, Tuple, Union, cast -from urllib.parse import parse_qs, urlparse -from uuid import uuid4 - -from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY -from ..exceptions import APIError -from ..types import ( - AuthChangeEvent, - CookieOptions, - Provider, - Session, - Subscription, - User, - UserAttributes, - UserAttributesDict, -) -from .api import AsyncGoTrueAPI -from .storage import AsyncMemoryStorage, AsyncSupportedStorage - - -class AsyncGoTrueClient: - def __init__( - self, - *, - url: str = GOTRUE_URL, - headers: Dict[str, str] = {}, - auto_refresh_token: bool = True, - persist_session: bool = True, - local_storage: AsyncSupportedStorage = AsyncMemoryStorage(), - cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[AsyncGoTrueAPI] = None, - replace_default_headers: bool = False, - ) -> None: - """Create a new client - - url : str - The URL of the GoTrue server. - headers : Dict[str, str] - Any additional headers to send to the GoTrue server. - auto_refresh_token : bool - Set to "true" if you want to automatically refresh the token before - expiring. - persist_session : bool - Set to "true" if you want to automatically save the user session - into local storage. - local_storage : SupportedStorage - The storage engine to use for persisting the session. - cookie_options : CookieOptions - The options for the cookie. - """ - if url.startswith("http://"): - print( - "Warning:\n\nDO NOT USE HTTP IN PRODUCTION FOR GOTRUE EVER!\n" - "GoTrue REQUIRES HTTPS to work securely." - ) - self.state_change_emitters: Dict[str, Subscription] = {} - self.refresh_token_timer: Optional[Timer] = None - self.current_user: Optional[User] = None - self.current_session: Optional[Session] = None - self.auto_refresh_token = auto_refresh_token - self.persist_session = persist_session - self.local_storage = local_storage - empty_or_default_headers = {} if replace_default_headers else DEFAULT_HEADERS - args = { - "url": url, - "headers": {**empty_or_default_headers, **headers}, - "cookie_options": cookie_options, - } - self.api = api or AsyncGoTrueAPI(**args) - - async def __aenter__(self) -> AsyncGoTrueClient: - return self - - async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: - await self.close() - - async def close(self) -> None: - await self.api.close() - - async def init_recover(self) -> None: - """Recover the current session from local storage.""" - await self._recover_session() - await self._recover_and_refresh() - - async def sign_up( - self, - *, - email: Optional[str] = None, - phone: Optional[str] = None, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user. If email and phone are provided, email will be - used and phone will be ignored. - - Parameters - --------- - email : Optional[str] - The user's email address. - phone : Optional[str] - The user's phone number. - password : Optional[str] - The user's password. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - await self._remove_session() - - if email and password: - response = await self.api.sign_up_with_email( - email=email, - password=password, - redirect_to=redirect_to, - data=data, - ) - elif phone and password: - response = await self.api.sign_up_with_phone( - phone=phone, password=password, data=data - ) - elif not password: - raise ValueError("Password must be defined, can't be None.") - else: - raise ValueError("Email or phone must be defined, both can't be None.") - - if isinstance(response, Session): - # The user has confirmed their email or the underlying DB doesn't - # require email confirmation. - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - async def sign_in( - self, - *, - email: Optional[str] = None, - phone: Optional[str] = None, - password: Optional[str] = None, - refresh_token: Optional[str] = None, - provider: Optional[Provider] = None, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - create_user: bool = False, - ) -> Optional[Union[Session, str]]: - """Log in an existing user, or login via a third-party provider. - If email and phone are provided, email will be used and phone will be ignored. - - Parameters - --------- - email : Optional[str] - The user's email address. - phone : Optional[str] - The user's phone number. - password : Optional[str] - The user's password. - refresh_token : Optional[str] - A valid refresh token that was returned on login. - provider : Optional[Provider] - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - response : Optional[Union[Session, str]] - If only email are provided between the email and password, - None is returned and send magic link to email - - If email and password are provided, a logged-in session is returned. - - If only phone are provided between the phone and password, - None is returned and send message to phone - - If phone and password are provided, a logged-in session is returned. - - If refresh_token is provided, a logged-in session is returned. - - If provider is provided, an redirect URL is returned. - - Otherwise, error is raised. - - Raises - ------ - error : APIError - If an error occurs - """ - await self._remove_session() - if email: - if password: - response = await self._handle_email_sign_in( - email=email, - password=password, - redirect_to=redirect_to, - ) - else: - response = await self.api.send_magic_link_email( - email=email, create_user=create_user - ) - elif phone: - if password: - response = await self._handle_phone_sign_in( - phone=phone, password=password - ) - else: - response = await self.api.send_mobile_otp( - phone=phone, create_user=create_user - ) - elif refresh_token: - # current_session and current_user will be updated to latest - # on _call_refresh_token using the passed refresh_token - await self._call_refresh_token(refresh_token=refresh_token) - response = self.current_session - elif provider: - response = await self._handle_provider_sign_in( - provider=provider, - redirect_to=redirect_to, - scopes=scopes, - ) - else: - raise ValueError( - "Email, phone, refresh_token, or provider must be defined, " - "all can't be None." - ) - return response - - async def verify_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Log in a user given a User supplied OTP received via mobile. - - Parameters - ---------- - phone : str - The user's phone number. - token : str - The user's OTP. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - await self._remove_session() - response = await self.api.verify_mobile_otp( - phone=phone, - token=token, - redirect_to=redirect_to, - ) - if isinstance(response, Session): - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def user(self) -> Optional[User]: - """Returns the user data, if there is a logged in user.""" - return self.current_user - - def session(self) -> Optional[Session]: - """Returns the session data, if there is an active session.""" - return self.current_session - - async def refresh_session(self) -> Session: - """Force refreshes the session. - - Force refreshes the session including the user data incase it was - updated in a different session. - """ - if not self.current_session: - raise ValueError("Not logged in.") - return await self._call_refresh_token() - - async def update( - self, *, attributes: Union[UserAttributesDict, UserAttributes] - ) -> User: - """Updates user data, if there is a logged in user. - - Parameters - ---------- - attributes : UserAttributesDict | UserAttributes - Attributes to update, could be: email, password, email_change_token, data - - Returns - ------- - response : User - The updated user data. - - Raises - ------ - error : APIError - If an error occurs - """ - if not self.current_session: - raise ValueError("Not logged in.") - - if isinstance(attributes, dict): - attributes_to_update = UserAttributes(**attributes) - else: - attributes_to_update = attributes - - response = await self.api.update_user( - jwt=self.current_session.access_token, - attributes=attributes_to_update, - ) - self.current_session.user = response - await self._save_session(session=self.current_session) - self._notify_all_subscribers(event=AuthChangeEvent.USER_UPDATED) - return response - - async def set_session(self, *, refresh_token: str) -> Session: - """Sets the session data from refresh_token and returns current Session - - Parameters - ---------- - refresh_token : str - A JWT token - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - response = await self.api.refresh_access_token(refresh_token=refresh_token) - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - async def set_auth(self, *, access_token: str) -> Session: - """Overrides the JWT on the current client. The JWT will then be sent in - all subsequent network requests. - - Parameters - ---------- - access_token : str - A JWT token - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - session = Session( - access_token=access_token, - token_type="bearer", - user=None, - expires_in=None, - expires_at=None, - refresh_token=None, - provider_token=None, - ) - if self.current_session: - session.expires_in = self.current_session.expires_in - session.expires_at = self.current_session.expires_at - session.refresh_token = self.current_session.refresh_token - session.provider_token = self.current_session.provider_token - await self._save_session(session=session) - return session - - async def get_session_from_url( - self, - *, - url: str, - store_session: bool = False, - ) -> Session: - """Gets the session data from a URL string. - - Parameters - ---------- - url : str - The URL string. - store_session : bool - Optionally store the session in the browser - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - data = urlparse(url) - query = parse_qs(data.query) - error_description = query.get("error_description") - access_token = query.get("access_token") - expires_in = query.get("expires_in") - refresh_token = query.get("refresh_token") - token_type = query.get("token_type") - if error_description: - raise APIError(error_description[0], 400) - if not access_token or not access_token[0]: - raise APIError("No access_token detected.", 400) - if not refresh_token or not refresh_token[0]: - raise APIError("No refresh_token detected.", 400) - if not token_type or not token_type[0]: - raise APIError("No token_type detected.", 400) - if not expires_in or not expires_in[0]: - raise APIError("No expires_in detected.", 400) - try: - expires_at = round(time()) + int(expires_in[0]) - except ValueError: - raise APIError("Invalid expires_in.", 400) - response = await self.api.get_user(jwt=access_token[0]) - provider_token = query.get("provider_token") - session = Session( - access_token=access_token[0], - token_type=token_type[0], - user=response, - expires_in=int(expires_in[0]), - expires_at=expires_at, - refresh_token=refresh_token[0], - provider_token=provider_token[0] if provider_token else None, - ) - if store_session: - await self._save_session(session=session) - recovery_mode = query.get("type") - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - if recovery_mode and recovery_mode[0] == "recovery": - self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) - return session - - async def sign_out(self) -> None: - """Log the user out.""" - access_token: Optional[str] = None - if self.current_session: - access_token = self.current_session.access_token - await self._remove_session() - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_OUT) - if access_token: - await self.api.sign_out(jwt=access_token) - - def _unsubscribe(self, *, id: str) -> None: - """Unsubscribe from a subscription.""" - self.state_change_emitters.pop(id) - - def on_auth_state_change( - self, - *, - callback: Callable[[AuthChangeEvent, Optional[Session]], None], - ) -> Subscription: - """Receive a notification every time an auth event happens. - - Parameters - ---------- - callback : Callable[[AuthChangeEvent, Optional[Session]], None] - The callback to call when an auth event happens. - - Returns - ------- - subscription : Subscription - A subscription object which can be used to unsubscribe itself. - - Raises - ------ - error : APIError - If an error occurs - """ - unique_id = uuid4() - subscription = Subscription( - id=unique_id, - callback=callback, - unsubscribe=partial(self._unsubscribe, id=unique_id.hex), - ) - self.state_change_emitters[unique_id.hex] = subscription - return subscription - - async def _handle_email_sign_in( - self, - *, - email: str, - password: str, - redirect_to: Optional[str], - ) -> Session: - """Sign in with email and password.""" - response = await self.api.sign_in_with_email( - email=email, - password=password, - redirect_to=redirect_to, - ) - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - async def _handle_phone_sign_in(self, *, phone: str, password: str) -> Session: - """Sign in with phone and password.""" - response = await self.api.sign_in_with_phone(phone=phone, password=password) - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - async def _handle_provider_sign_in( - self, - *, - provider: Provider, - redirect_to: Optional[str], - scopes: Optional[str], - ) -> str: - """Sign in with provider.""" - return await self.api.get_url_for_provider( - provider=provider, - redirect_to=redirect_to, - scopes=scopes, - ) - - async def _recover_common(self) -> Optional[Tuple[Session, int, int]]: - """Recover common logic""" - json = await self.local_storage.get_item(STORAGE_KEY) - if not json: - return - data = loads(json) - session_raw = data.get("session") - expires_at_raw = data.get("expires_at") - if ( - expires_at_raw - and isinstance(expires_at_raw, int) - and session_raw - and isinstance(session_raw, dict) - ): - session = Session.parse_obj(session_raw) - expires_at = int(expires_at_raw) - time_now = round(time()) - return session, expires_at, time_now - - async def _recover_session(self) -> None: - """Attempts to get the session from LocalStorage""" - result = await self._recover_common() - if not result: - return - session, expires_at, time_now = result - if expires_at >= time_now: - await self._save_session(session=session) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - - async def _recover_and_refresh(self) -> None: - """Recovers the session from LocalStorage and refreshes""" - result = await self._recover_common() - if not result: - return - session, expires_at, time_now = result - if expires_at < time_now and self.auto_refresh_token and session.refresh_token: - try: - await self._call_refresh_token(refresh_token=session.refresh_token) - except APIError: - await self._remove_session() - elif expires_at < time_now or not session or not session.user: - await self._remove_session() - else: - await self._save_session(session=session) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - - async def _call_refresh_token( - self, *, refresh_token: Optional[str] = None - ) -> Session: - if refresh_token is None: - if self.current_session: - refresh_token = self.current_session.refresh_token - else: - raise ValueError("No current session and refresh_token not supplied.") - response = await self.api.refresh_access_token( - refresh_token=cast(str, refresh_token) - ) - await self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.TOKEN_REFRESHED) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def _notify_all_subscribers(self, *, event: AuthChangeEvent) -> None: - """Notify all subscribers that auth event happened.""" - for value in self.state_change_emitters.values(): - value.callback(event, self.current_session) - - async def _save_session(self, *, session: Session) -> None: - """Save session to client.""" - self.current_session = session - self.current_user = session.user - if session.expires_at: - time_now = round(time()) - expire_in = session.expires_at - time_now - refresh_duration_before_expires = 60 if expire_in > 60 else 0.5 - self._start_auto_refresh_token( - value=(expire_in - refresh_duration_before_expires) * 1000 - ) - if self.persist_session and session.expires_at: - await self._persist_session(session=session) - - async def _persist_session(self, *, session: Session) -> None: - data = {"session": session.dict(), "expires_at": session.expires_at} - await self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str)) - - async def _remove_session(self) -> None: - """Remove the session.""" - self.current_session = None - self.current_user = None - if self.refresh_token_timer: - self.refresh_token_timer.cancel() - await self.local_storage.remove_item(STORAGE_KEY) - - def _start_auto_refresh_token(self, *, value: float) -> None: - if self.refresh_token_timer: - self.refresh_token_timer.cancel() - if value <= 0 or not self.auto_refresh_token: - return - self.refresh_token_timer = Timer(value, self._call_refresh_token) - self.refresh_token_timer.start() diff --git a/gotrue/_async/gotrue_admin_api.py b/gotrue/_async/gotrue_admin_api.py new file mode 100644 index 00000000..7a568cf1 --- /dev/null +++ b/gotrue/_async/gotrue_admin_api.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import AsyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import AsyncGoTrueBaseAPI + + +class AsyncGoTrueAdminAPI(AsyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[AsyncClient, None] = None, + ) -> None: + AsyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + async def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return await self._request("POST", "logout", jwt=jwt) + + async def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return await self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return await self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + async def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + async def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data], + ) + + async def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + async def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + async def delete_user(self, id: str) -> UserResponse: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "DELETE", + f"admin/users/{id}", + xform=parse_user_response, + ) diff --git a/gotrue/_async/gotrue_base_api.py b/gotrue/_async/gotrue_base_api.py new file mode 100644 index 00000000..89ed39d7 --- /dev/null +++ b/gotrue/_async/gotrue_base_api.py @@ -0,0 +1,98 @@ +from __future__ import annotations +from typing import Any, Callable, Dict, Literal, TypeVar, Union, overload + +from pydantic import BaseModel +from typing_extensions import Self + +from ..helpers import handle_exception +from ..http_clients import AsyncClient + +T = TypeVar("T") + + +class AsyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[AsyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or AsyncClient() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: + await self.close() + + async def close(self) -> None: + await self._http_client.aclose() + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + xform: Callable[[Any], T], + ) -> T: + ... + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + ) -> None: + ... + + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = await self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_async/gotrue_client.py b/gotrue/_async/gotrue_client.py new file mode 100644 index 00000000..a19772ef --- /dev/null +++ b/gotrue/_async/gotrue_client.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +from base64 import b64decode +from json import loads +from time import time +from typing import Callable, Dict, List, Tuple, Union +from urllib.parse import parse_qs, quote, urlencode, urlparse +from uuid import uuid4 + +from ..constants import ( + DEFAULT_HEADERS, + EXPIRY_MARGIN, + GOTRUE_URL, + MAX_RETRIES, + RETRY_INTERVAL, + STORAGE_KEY, +) +from ..errors import ( + AuthImplicitGrantRedirectError, + AuthInvalidCredentialsError, + AuthRetryableError, + AuthSessionMissingError, +) +from ..helpers import parse_auth_response, parse_user_response +from ..http_clients import AsyncClient +from ..timer import Timer +from ..types import ( + AuthChangeEvent, + AuthResponse, + OAuthResponse, + Options, + Provider, + Session, + SignInWithOAuthCredentials, + SignInWithPasswordCredentials, + SignInWithPasswordlessCredentials, + SignUpWithPasswordCredentials, + Subscription, + UserAttributes, + UserResponse, + VerifyOtpParams, +) +from .gotrue_admin_api import AsyncGoTrueAdminAPI +from .gotrue_base_api import AsyncGoTrueBaseAPI +from .storage import AsyncMemoryStorage, AsyncSupportedStorage + + +class AsyncGoTrueClient(AsyncGoTrueBaseAPI): + def __init__( + self, + *, + url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + storage_key: Union[str, None] = None, + auto_refresh_token: bool = True, + persist_session: bool = True, + storage: Union[AsyncSupportedStorage, None] = None, + http_client: Union[AsyncClient, None] = None, + ) -> None: + AsyncGoTrueBaseAPI.__init__( + self, + url=url or GOTRUE_URL, + headers=headers or DEFAULT_HEADERS, + http_client=http_client, + ) + self._storage_key = storage_key or STORAGE_KEY + self._auto_refresh_token = auto_refresh_token + self._persist_session = persist_session + self._storage = storage or AsyncMemoryStorage() + self._in_memory_session: Union[Session, None] = None + self._refresh_token_timer: Union[Timer, None] = None + self._network_retries = 0 + self._state_change_emitters: Dict[str, Subscription] = {} + + self.admin = AsyncGoTrueAdminAPI( + url=self._url, + headers=self._headers, + http_client=self._http_client, + ) + + # Initializations + + async def initialize(self, *, url: Union[str, None] = None) -> None: + if url and self._is_implicit_grant_flow(url): + await self.initialize_from_url(url) + else: + await self.initialize_from_storage() + + async def initialize_from_storage(self) -> None: + return await self._recover_and_refresh() + + async def initialize_from_url(self, url: str) -> None: + try: + if self._is_implicit_grant_flow(url): + session, redirect_type = await self._get_session_from_url(url) + await self._save_session(session) + self._notify_all_subscribers("SIGNED_IN", session) + if redirect_type == "recovery": + self._notify_all_subscribers("PASSWORD_RECOVERY", session) + except Exception as e: + await self._remove_session() + raise e + + # Public methods + + async def sign_up( + self, + credentials: SignUpWithPasswordCredentials, + ) -> AuthResponse: + """ + Creates a new user. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + response = await self._request( + "POST", + "signup", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=redirect_to, + xform=parse_auth_response, + ) + elif phone: + response = await self._request( + "POST", + "signup", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def sign_in_with_password( + self, + credentials: SignInWithPasswordCredentials, + ) -> AuthResponse: + """ + Log in an existing user with an email or phone and password. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + captcha_token = options.get("captcha_token") + if email: + response = await self._request( + "POST", + "token", + body={ + "email": email, + "password": password, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + elif phone: + response = await self._request( + "POST", + "token", + body={ + "phone": phone, + "password": password, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def sign_in_with_oauth( + self, + credentials: SignInWithOAuthCredentials, + ) -> OAuthResponse: + """ + Log in an existing user via a third-party provider. + """ + await self._remove_session() + provider = credentials.get("provider") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + scopes = options.get("scopes") + params = options.get("query_params", {}) + if redirect_to: + params["redirect_to"] = redirect_to + if scopes: + params["scopes"] = scopes + url = self._get_url_for_provider(provider, params) + return OAuthResponse(provider=provider, url=url) + + async def sign_in_with_otp( + self, + credentials: SignInWithPasswordlessCredentials, + ) -> AuthResponse: + """ + Log in a user using magiclink or a one-time password (OTP). + + If the `{{ .ConfirmationURL }}` variable is specified in + the email template, a magiclink will be sent. + + If the `{{ .Token }}` variable is specified in the email + template, an OTP will be sent. + + If you're using phone sign-ins, only an OTP will be sent. + You won't be able to send a magiclink for phone sign-ins. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + options = credentials.get("options", {}) + email_redirect_to = options.get("email_redirect_to") + should_create_user = options.get("create_user", True) + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + return await self._request( + "POST", + "otp", + body={ + "email": email, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=email_redirect_to, + xform=parse_auth_response, + ) + if phone: + return await self._request( + "POST", + "otp", + body={ + "phone": phone, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number" + ) + + async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: + """ + Log in a user given a User supplied OTP received via mobile. + """ + await self._remove_session() + response = await self._request( + "POST", + "verify", + body={ + "gotrue_meta_security": { + "captcha_token": params.get("options", {}).get("captcha_token"), + }, + **params, + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_auth_response, + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def get_session(self) -> Union[Session, None]: + """ + Returns the session, refreshing it if necessary. + + The session returned can be null if the session is not detected which + can happen in the event a user is not signed-in or has logged out. + """ + current_session: Union[Session, None] = None + if self._persist_session: + maybe_session = await self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(maybe_session) + if not current_session: + await self._remove_session() + else: + current_session = self._in_memory_session + if not current_session: + return None + has_expired = ( + current_session.expires_at <= time() + if current_session.expires_at + else False + ) + if not has_expired: + return current_session + return await self._call_refresh_token(current_session.refresh_token) + + async def get_user(self, jwt: Union[str, None] = None) -> UserResponse: + """ + Gets the current user details if there is an existing session. + + Takes in an optional access token `jwt`. If no `jwt` is provided, + `get_user()` will attempt to get the `jwt` from the current session. + """ + if not jwt: + session = await self.get_session() + if session: + jwt = session.access_token + return await self._request("GET", "user", jwt=jwt, xform=parse_user_response) + + async def update_user(self, attributes: UserAttributes) -> UserResponse: + """ + Updates user data, if there is a logged in user. + """ + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + response = await self._request( + "PUT", + "user", + body=attributes, + jwt=session.access_token, + xform=parse_user_response, + ) + session.user = response.user + await self._save_session(session) + self._notify_all_subscribers("USER_UPDATED", session) + return response + + async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: + """ + Sets the session data from the current session. If the current session + is expired, `set_session` will take care of refreshing it to obtain a + new session. + + If the refresh token in the current session is invalid and the current + session has expired, an error will be thrown. + + If the current session does not contain at `expires_at` field, + `set_session` will use the exp claim defined in the access token. + + The current session that minimally contains an access token, + refresh token and a user. + """ + time_now = round(time()) + expires_at = time_now + has_expired = True + session: Union[Session, None] = None + if access_token and access_token.split(".")[1]: + json_raw = b64decode(access_token.split(".")[1] + "===").decode("utf-8") + payload = loads(json_raw) + if payload.get("exp"): + expires_at = int(payload.get("exp")) + has_expired = expires_at <= time_now + if has_expired: + if not refresh_token: + raise AuthSessionMissingError() + response = await self._refresh_access_token(refresh_token) + if not response.session: + return AuthResponse() + session = response.session + else: + response = await self.get_user(access_token) + session = Session( + access_token=access_token, + refresh_token=refresh_token, + user=response.user, + token_type="bearer", + expires_in=expires_at - time_now, + expires_at=expires_at, + ) + await self._save_session(session) + self._notify_all_subscribers("TOKEN_REFRESHED", session) + return AuthResponse(session=session, user=response.user) + + async def sign_out(self) -> None: + """ + Inside a browser context, `sign_out` will remove the logged in user from the + browser session and log them out - removing all items from localstorage and + then trigger a `"SIGNED_OUT"` event. + + For server-side management, you can revoke all refresh tokens for a user by + passing a user's JWT through to `api.sign_out`. + + There is no way to revoke a user's access token jwt until it expires. + It is recommended to set a shorter expiry on the jwt for this reason. + """ + session = await self.get_session() + access_token = session.access_token if session else None + if access_token: + await self.admin.sign_out(access_token) + await self._remove_session() + self._notify_all_subscribers("SIGNED_OUT", None) + + async def on_auth_state_change( + self, + callback: Callable[[AuthChangeEvent, Union[Session, None]], None], + ) -> Subscription: + """ + Receive a notification every time an auth event happens. + """ + unique_id = str(uuid4()) + + def _unsubscribe() -> None: + self._state_change_emitters.pop(unique_id) + + subscription = Subscription( + id=unique_id, + callback=callback, + unsubscribe=_unsubscribe, + ) + self._state_change_emitters[unique_id] = subscription + return subscription + + async def reset_password_email( + self, + email: str, + options: Options = {}, + ) -> None: + """ + Sends a password reset request to an email address. + """ + raise NotImplementedError + + # Private methods + + async def _remove_session(self) -> None: + if self._persist_session: + await self._storage.remove_item(self._storage_key) + else: + self._in_memory_session = None + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + + async def _get_session_from_url( + self, + url: str, + ) -> Tuple[Session, Union[str, None]]: + if not self._is_implicit_grant_flow(url): + raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") + result = urlparse(url) + params = parse_qs(result.query) + error_description = self._get_param(params, "error_description") + if error_description: + error_code = self._get_param(params, "error_code") + error = self._get_param(params, "error") + if not error_code: + raise AuthImplicitGrantRedirectError("No error_code detected.") + if not error: + raise AuthImplicitGrantRedirectError("No error detected.") + raise AuthImplicitGrantRedirectError( + error_description, + {"code": error_code, "error": error}, + ) + provider_token = self._get_param(params, "provider_token") + provider_refresh_token = self._get_param(params, "provider_refresh_token") + access_token = self._get_param(params, "access_token") + if not access_token: + raise AuthImplicitGrantRedirectError("No access_token detected.") + expires_in = self._get_param(params, "expires_in") + if not expires_in: + raise AuthImplicitGrantRedirectError("No expires_in detected.") + refresh_token = self._get_param(params, "refresh_token") + if not refresh_token: + raise AuthImplicitGrantRedirectError("No refresh_token detected.") + token_type = self._get_param(params, "token_type") + if not token_type: + raise AuthImplicitGrantRedirectError("No token_type detected.") + time_now = round(time()) + expires_at = time_now + int(expires_in) + user = await self.get_user(access_token) + session = Session( + provider_token=provider_token, + provider_refresh_token=provider_refresh_token, + access_token=access_token, + expires_in=int(expires_in), + expires_at=expires_at, + refresh_token=refresh_token, + token_type=token_type, + user=user.user, + ) + redirect_type = self._get_param(params, "type") + return session, redirect_type + + async def _recover_and_refresh(self) -> None: + raw_session = await self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(raw_session) + if not current_session: + if raw_session: + await self._remove_session() + return + time_now = round(time()) + expires_at = current_session.expires_at + if expires_at and expires_at < time_now + EXPIRY_MARGIN: + refresh_token = current_session.refresh_token + if self._auto_refresh_token and refresh_token: + self._network_retries += 1 + try: + await self._call_refresh_token(refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = Timer( + (RETRY_INTERVAL ** (self._network_retries * 100)), + self._recover_and_refresh, + ) + self._refresh_token_timer.start() + return + await self._remove_session() + return + if self._persist_session: + await self._save_session(current_session) + self._notify_all_subscribers("SIGNED_IN", current_session) + + async def _call_refresh_token(self, refresh_token: str) -> Session: + if not refresh_token: + raise AuthSessionMissingError() + response = await self._refresh_access_token(refresh_token) + if not response.session: + raise AuthSessionMissingError() + await self._save_session(response.session) + self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + return response.session + + async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: + return await self._request( + "POST", + "token", + query={"grant_type": "refresh_token"}, + body={"refresh_token": refresh_token}, + xform=parse_auth_response, + ) + + async def _save_session(self, session: Session) -> None: + if not self._persist_session: + self._in_memory_session = session + expire_at = session.expires_at + if expire_at: + pass + time_now = round(time()) + expire_in = expire_at - time_now + refresh_duration_before_expires = ( + EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 + ) + value = (expire_in - refresh_duration_before_expires) * 1000 + await self._start_auto_refresh_token(value) + if self._persist_session and session.expires_at: + await self._storage.set_item(self._storage_key, session.json()) + + async def _start_auto_refresh_token(self, value: float) -> None: + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + if value <= 0 or not self._auto_refresh_token: + return + + async def refresh_token_function(): + self._network_retries += 1 + try: + session = await self.get_session() + if session: + await self._call_refresh_token(session.refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + await self._start_auto_refresh_token( + RETRY_INTERVAL ** (self._network_retries * 100) + ) + + self._refresh_token_timer = Timer(value, refresh_token_function) + self._refresh_token_timer.start() + + def _notify_all_subscribers( + self, + event: AuthChangeEvent, + session: Union[Session, None], + ) -> None: + for subscription in self._state_change_emitters.values(): + subscription.callback(event, session) + + def _get_valid_session( + self, + raw_session: Union[str, None], + ) -> Union[Session, None]: + if not raw_session: + return None + data = loads(raw_session) + if not data: + return None + if not data.get("access_token"): + return None + if not data.get("refresh_token"): + return None + if not data.get("expires_at"): + return None + try: + expires_at = int(data["expires_at"]) + data["expires_at"] = expires_at + except ValueError: + return None + try: + session = Session.parse_obj(data) + return session + except Exception: + return None + + def _get_param( + self, + query_params: Dict[str, List[str]], + name: str, + ) -> Union[str, None]: + if name in query_params: + return query_params[name][0] + return None + + def _is_implicit_grant_flow(self, url: str) -> bool: + result = urlparse(url) + params = parse_qs(result.query) + return "access_token" in params or "error_description" in params + + def _get_url_for_provider( + self, + provider: Provider, + params: Dict[str, str], + ) -> str: + params = {k: quote(v) for k, v in params.items()} + params["provider"] = quote(provider) + query = urlencode(params) + return f"{self._url}/authorize?{query}" + + +async def test(): + client = AsyncGoTrueClient() + await client.initialize() diff --git a/gotrue/_sync/api.py b/gotrue/_sync/api.py deleted file mode 100644 index bec3f70b..00000000 --- a/gotrue/_sync/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import SyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class SyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[SyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or SyncClient() - - def __enter__(self) -> SyncGoTrueAPI: - return self - - def __exit__(self, exc_t, exc_v, exc_tb) -> None: - self.close() - - def close(self) -> None: - self.http_client.aclose() - - def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. - Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers, "Authorization": f"Bearer {jwt}"} - return headers - - def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - self.http_client.post(url, headers=headers) - - def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - error : APIError - If an error occurs - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = self.http_client.get(url, headers=headers) - return User.parse_response(response) - - def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = self.http_client.delete(url, headers=headers) - return check_response(response) - - def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py deleted file mode 100644 index f7f20455..00000000 --- a/gotrue/_sync/client.py +++ /dev/null @@ -1,640 +0,0 @@ -from __future__ import annotations - -from functools import partial -from json import dumps, loads -from threading import Timer -from time import time -from typing import Any, Callable, Dict, Optional, Tuple, Union, cast -from urllib.parse import parse_qs, urlparse -from uuid import uuid4 - -from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY -from ..exceptions import APIError -from ..types import ( - AuthChangeEvent, - CookieOptions, - Provider, - Session, - Subscription, - User, - UserAttributes, - UserAttributesDict, -) -from .api import SyncGoTrueAPI -from .storage import SyncMemoryStorage, SyncSupportedStorage - - -class SyncGoTrueClient: - def __init__( - self, - *, - url: str = GOTRUE_URL, - headers: Dict[str, str] = {}, - auto_refresh_token: bool = True, - persist_session: bool = True, - local_storage: SyncSupportedStorage = SyncMemoryStorage(), - cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[SyncGoTrueAPI] = None, - replace_default_headers: bool = False, - ) -> None: - """Create a new client - - url : str - The URL of the GoTrue server. - headers : Dict[str, str] - Any additional headers to send to the GoTrue server. - auto_refresh_token : bool - Set to "true" if you want to automatically refresh the token before - expiring. - persist_session : bool - Set to "true" if you want to automatically save the user session - into local storage. - local_storage : SupportedStorage - The storage engine to use for persisting the session. - cookie_options : CookieOptions - The options for the cookie. - """ - if url.startswith("http://"): - print( - "Warning:\n\nDO NOT USE HTTP IN PRODUCTION FOR GOTRUE EVER!\n" - "GoTrue REQUIRES HTTPS to work securely." - ) - self.state_change_emitters: Dict[str, Subscription] = {} - self.refresh_token_timer: Optional[Timer] = None - self.current_user: Optional[User] = None - self.current_session: Optional[Session] = None - self.auto_refresh_token = auto_refresh_token - self.persist_session = persist_session - self.local_storage = local_storage - empty_or_default_headers = {} if replace_default_headers else DEFAULT_HEADERS - args = { - "url": url, - "headers": {**empty_or_default_headers, **headers}, - "cookie_options": cookie_options, - } - self.api = api or SyncGoTrueAPI(**args) - - def __enter__(self) -> SyncGoTrueClient: - return self - - def __exit__(self, exc_t, exc_v, exc_tb) -> None: - self.close() - - def close(self) -> None: - self.api.close() - - def init_recover(self) -> None: - """Recover the current session from local storage.""" - self._recover_session() - self._recover_and_refresh() - - def sign_up( - self, - *, - email: Optional[str] = None, - phone: Optional[str] = None, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user. If email and phone are provided, email will be - used and phone will be ignored. - - Parameters - --------- - email : Optional[str] - The user's email address. - phone : Optional[str] - The user's phone number. - password : Optional[str] - The user's password. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - self._remove_session() - - if email and password: - response = self.api.sign_up_with_email( - email=email, - password=password, - redirect_to=redirect_to, - data=data, - ) - elif phone and password: - response = self.api.sign_up_with_phone( - phone=phone, password=password, data=data - ) - elif not password: - raise ValueError("Password must be defined, can't be None.") - else: - raise ValueError("Email or phone must be defined, both can't be None.") - - if isinstance(response, Session): - # The user has confirmed their email or the underlying DB doesn't - # require email confirmation. - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def sign_in( - self, - *, - email: Optional[str] = None, - phone: Optional[str] = None, - password: Optional[str] = None, - refresh_token: Optional[str] = None, - provider: Optional[Provider] = None, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - create_user: bool = False, - ) -> Optional[Union[Session, str]]: - """Log in an existing user, or login via a third-party provider. - If email and phone are provided, email will be used and phone will be ignored. - - Parameters - --------- - email : Optional[str] - The user's email address. - phone : Optional[str] - The user's phone number. - password : Optional[str] - The user's password. - refresh_token : Optional[str] - A valid refresh token that was returned on login. - provider : Optional[Provider] - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - response : Optional[Union[Session, str]] - If only email are provided between the email and password, - None is returned and send magic link to email - - If email and password are provided, a logged-in session is returned. - - If only phone are provided between the phone and password, - None is returned and send message to phone - - If phone and password are provided, a logged-in session is returned. - - If refresh_token is provided, a logged-in session is returned. - - If provider is provided, an redirect URL is returned. - - Otherwise, error is raised. - - Raises - ------ - error : APIError - If an error occurs - """ - self._remove_session() - if email: - if password: - response = self._handle_email_sign_in( - email=email, - password=password, - redirect_to=redirect_to, - ) - else: - response = self.api.send_magic_link_email( - email=email, create_user=create_user - ) - elif phone: - if password: - response = self._handle_phone_sign_in(phone=phone, password=password) - else: - response = self.api.send_mobile_otp( - phone=phone, create_user=create_user - ) - elif refresh_token: - # current_session and current_user will be updated to latest - # on _call_refresh_token using the passed refresh_token - self._call_refresh_token(refresh_token=refresh_token) - response = self.current_session - elif provider: - response = self._handle_provider_sign_in( - provider=provider, - redirect_to=redirect_to, - scopes=scopes, - ) - else: - raise ValueError( - "Email, phone, refresh_token, or provider must be defined, " - "all can't be None." - ) - return response - - def verify_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Log in a user given a User supplied OTP received via mobile. - - Parameters - ---------- - phone : str - The user's phone number. - token : str - The user's OTP. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - self._remove_session() - response = self.api.verify_mobile_otp( - phone=phone, - token=token, - redirect_to=redirect_to, - ) - if isinstance(response, Session): - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def user(self) -> Optional[User]: - """Returns the user data, if there is a logged in user.""" - return self.current_user - - def session(self) -> Optional[Session]: - """Returns the session data, if there is an active session.""" - return self.current_session - - def refresh_session(self) -> Session: - """Force refreshes the session. - - Force refreshes the session including the user data incase it was - updated in a different session. - """ - if not self.current_session: - raise ValueError("Not logged in.") - return self._call_refresh_token() - - def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: - """Updates user data, if there is a logged in user. - - Parameters - ---------- - attributes : UserAttributesDict | UserAttributes - Attributes to update, could be: email, password, email_change_token, data - - Returns - ------- - response : User - The updated user data. - - Raises - ------ - error : APIError - If an error occurs - """ - if not self.current_session: - raise ValueError("Not logged in.") - - if isinstance(attributes, dict): - attributes_to_update = UserAttributes(**attributes) - else: - attributes_to_update = attributes - - response = self.api.update_user( - jwt=self.current_session.access_token, - attributes=attributes_to_update, - ) - self.current_session.user = response - self._save_session(session=self.current_session) - self._notify_all_subscribers(event=AuthChangeEvent.USER_UPDATED) - return response - - def set_session(self, *, refresh_token: str) -> Session: - """Sets the session data from refresh_token and returns current Session - - Parameters - ---------- - refresh_token : str - A JWT token - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - response = self.api.refresh_access_token(refresh_token=refresh_token) - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def set_auth(self, *, access_token: str) -> Session: - """Overrides the JWT on the current client. The JWT will then be sent in - all subsequent network requests. - - Parameters - ---------- - access_token : str - A JWT token - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - session = Session( - access_token=access_token, - token_type="bearer", - user=None, - expires_in=None, - expires_at=None, - refresh_token=None, - provider_token=None, - ) - if self.current_session: - session.expires_in = self.current_session.expires_in - session.expires_at = self.current_session.expires_at - session.refresh_token = self.current_session.refresh_token - session.provider_token = self.current_session.provider_token - self._save_session(session=session) - return session - - def get_session_from_url( - self, - *, - url: str, - store_session: bool = False, - ) -> Session: - """Gets the session data from a URL string. - - Parameters - ---------- - url : str - The URL string. - store_session : bool - Optionally store the session in the browser - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - data = urlparse(url) - query = parse_qs(data.query) - error_description = query.get("error_description") - access_token = query.get("access_token") - expires_in = query.get("expires_in") - refresh_token = query.get("refresh_token") - token_type = query.get("token_type") - if error_description: - raise APIError(error_description[0], 400) - if not access_token or not access_token[0]: - raise APIError("No access_token detected.", 400) - if not refresh_token or not refresh_token[0]: - raise APIError("No refresh_token detected.", 400) - if not token_type or not token_type[0]: - raise APIError("No token_type detected.", 400) - if not expires_in or not expires_in[0]: - raise APIError("No expires_in detected.", 400) - try: - expires_at = round(time()) + int(expires_in[0]) - except ValueError: - raise APIError("Invalid expires_in.", 400) - response = self.api.get_user(jwt=access_token[0]) - provider_token = query.get("provider_token") - session = Session( - access_token=access_token[0], - token_type=token_type[0], - user=response, - expires_in=int(expires_in[0]), - expires_at=expires_at, - refresh_token=refresh_token[0], - provider_token=provider_token[0] if provider_token else None, - ) - if store_session: - self._save_session(session=session) - recovery_mode = query.get("type") - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - if recovery_mode and recovery_mode[0] == "recovery": - self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) - return session - - def sign_out(self) -> None: - """Log the user out.""" - access_token: Optional[str] = None - if self.current_session: - access_token = self.current_session.access_token - self._remove_session() - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_OUT) - if access_token: - self.api.sign_out(jwt=access_token) - - def _unsubscribe(self, *, id: str) -> None: - """Unsubscribe from a subscription.""" - self.state_change_emitters.pop(id) - - def on_auth_state_change( - self, - *, - callback: Callable[[AuthChangeEvent, Optional[Session]], None], - ) -> Subscription: - """Receive a notification every time an auth event happens. - - Parameters - ---------- - callback : Callable[[AuthChangeEvent, Optional[Session]], None] - The callback to call when an auth event happens. - - Returns - ------- - subscription : Subscription - A subscription object which can be used to unsubscribe itself. - - Raises - ------ - error : APIError - If an error occurs - """ - unique_id = uuid4() - subscription = Subscription( - id=unique_id, - callback=callback, - unsubscribe=partial(self._unsubscribe, id=unique_id.hex), - ) - self.state_change_emitters[unique_id.hex] = subscription - return subscription - - def _handle_email_sign_in( - self, - *, - email: str, - password: str, - redirect_to: Optional[str], - ) -> Session: - """Sign in with email and password.""" - response = self.api.sign_in_with_email( - email=email, - password=password, - redirect_to=redirect_to, - ) - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def _handle_phone_sign_in(self, *, phone: str, password: str) -> Session: - """Sign in with phone and password.""" - response = self.api.sign_in_with_phone(phone=phone, password=password) - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def _handle_provider_sign_in( - self, - *, - provider: Provider, - redirect_to: Optional[str], - scopes: Optional[str], - ) -> str: - """Sign in with provider.""" - return self.api.get_url_for_provider( - provider=provider, - redirect_to=redirect_to, - scopes=scopes, - ) - - def _recover_common(self) -> Optional[Tuple[Session, int, int]]: - """Recover common logic""" - json = self.local_storage.get_item(STORAGE_KEY) - if not json: - return - data = loads(json) - session_raw = data.get("session") - expires_at_raw = data.get("expires_at") - if ( - expires_at_raw - and isinstance(expires_at_raw, int) - and session_raw - and isinstance(session_raw, dict) - ): - session = Session.parse_obj(session_raw) - expires_at = int(expires_at_raw) - time_now = round(time()) - return session, expires_at, time_now - - def _recover_session(self) -> None: - """Attempts to get the session from LocalStorage""" - result = self._recover_common() - if not result: - return - session, expires_at, time_now = result - if expires_at >= time_now: - self._save_session(session=session) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - - def _recover_and_refresh(self) -> None: - """Recovers the session from LocalStorage and refreshes""" - result = self._recover_common() - if not result: - return - session, expires_at, time_now = result - if expires_at < time_now and self.auto_refresh_token and session.refresh_token: - try: - self._call_refresh_token(refresh_token=session.refresh_token) - except APIError: - self._remove_session() - elif expires_at < time_now or not session or not session.user: - self._remove_session() - else: - self._save_session(session=session) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - - def _call_refresh_token(self, *, refresh_token: Optional[str] = None) -> Session: - if refresh_token is None: - if self.current_session: - refresh_token = self.current_session.refresh_token - else: - raise ValueError("No current session and refresh_token not supplied.") - response = self.api.refresh_access_token(refresh_token=cast(str, refresh_token)) - self._save_session(session=response) - self._notify_all_subscribers(event=AuthChangeEvent.TOKEN_REFRESHED) - self._notify_all_subscribers(event=AuthChangeEvent.SIGNED_IN) - return response - - def _notify_all_subscribers(self, *, event: AuthChangeEvent) -> None: - """Notify all subscribers that auth event happened.""" - for value in self.state_change_emitters.values(): - value.callback(event, self.current_session) - - def _save_session(self, *, session: Session) -> None: - """Save session to client.""" - self.current_session = session - self.current_user = session.user - if session.expires_at: - time_now = round(time()) - expire_in = session.expires_at - time_now - refresh_duration_before_expires = 60 if expire_in > 60 else 0.5 - self._start_auto_refresh_token( - value=(expire_in - refresh_duration_before_expires) * 1000 - ) - if self.persist_session and session.expires_at: - self._persist_session(session=session) - - def _persist_session(self, *, session: Session) -> None: - data = {"session": session.dict(), "expires_at": session.expires_at} - self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str)) - - def _remove_session(self) -> None: - """Remove the session.""" - self.current_session = None - self.current_user = None - if self.refresh_token_timer: - self.refresh_token_timer.cancel() - self.local_storage.remove_item(STORAGE_KEY) - - def _start_auto_refresh_token(self, *, value: float) -> None: - if self.refresh_token_timer: - self.refresh_token_timer.cancel() - if value <= 0 or not self.auto_refresh_token: - return - self.refresh_token_timer = Timer(value, self._call_refresh_token) - self.refresh_token_timer.start() diff --git a/gotrue/_sync/gotrue_admin_api.py b/gotrue/_sync/gotrue_admin_api.py new file mode 100644 index 00000000..00576d58 --- /dev/null +++ b/gotrue/_sync/gotrue_admin_api.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import SyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import SyncGoTrueBaseAPI + + +class SyncGoTrueAdminAPI(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return self._request("POST", "logout", jwt=jwt) + + def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data], + ) + + def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + def delete_user(self, id: str) -> UserResponse: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "DELETE", + f"admin/users/{id}", + xform=parse_user_response, + ) diff --git a/gotrue/_sync/gotrue_base_api.py b/gotrue/_sync/gotrue_base_api.py new file mode 100644 index 00000000..a1868fce --- /dev/null +++ b/gotrue/_sync/gotrue_base_api.py @@ -0,0 +1,98 @@ +from __future__ import annotations +from typing import Any, Callable, Dict, Literal, TypeVar, Union, overload + +from pydantic import BaseModel +from typing_extensions import Self + +from ..helpers import handle_exception +from ..http_clients import SyncClient + +T = TypeVar("T") + + +class SyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[SyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or SyncClient() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_t, exc_v, exc_tb) -> None: + self.close() + + def close(self) -> None: + self._http_client.aclose() + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + xform: Callable[[Any], T], + ) -> T: + ... + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + ) -> None: + ... + + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Union[bool, None] = None, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py new file mode 100644 index 00000000..15571e19 --- /dev/null +++ b/gotrue/_sync/gotrue_client.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +from base64 import b64decode +from json import loads +from time import time +from typing import Callable, Dict, List, Tuple, Union +from urllib.parse import parse_qs, quote, urlencode, urlparse +from uuid import uuid4 + +from ..constants import ( + DEFAULT_HEADERS, + EXPIRY_MARGIN, + GOTRUE_URL, + MAX_RETRIES, + RETRY_INTERVAL, + STORAGE_KEY, +) +from ..errors import ( + AuthImplicitGrantRedirectError, + AuthInvalidCredentialsError, + AuthRetryableError, + AuthSessionMissingError, +) +from ..helpers import parse_auth_response, parse_user_response +from ..http_clients import SyncClient +from ..timer import Timer +from ..types import ( + AuthChangeEvent, + AuthResponse, + OAuthResponse, + Options, + Provider, + Session, + SignInWithOAuthCredentials, + SignInWithPasswordCredentials, + SignInWithPasswordlessCredentials, + SignUpWithPasswordCredentials, + Subscription, + UserAttributes, + UserResponse, + VerifyOtpParams, +) +from .gotrue_admin_api import SyncGoTrueAdminAPI +from .gotrue_base_api import SyncGoTrueBaseAPI +from .storage import SyncMemoryStorage, SyncSupportedStorage + + +class SyncGoTrueClient(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + storage_key: Union[str, None] = None, + auto_refresh_token: bool = True, + persist_session: bool = True, + storage: Union[SyncSupportedStorage, None] = None, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url or GOTRUE_URL, + headers=headers or DEFAULT_HEADERS, + http_client=http_client, + ) + self._storage_key = storage_key or STORAGE_KEY + self._auto_refresh_token = auto_refresh_token + self._persist_session = persist_session + self._storage = storage or SyncMemoryStorage() + self._in_memory_session: Union[Session, None] = None + self._refresh_token_timer: Union[Timer, None] = None + self._network_retries = 0 + self._state_change_emitters: Dict[str, Subscription] = {} + + self.admin = SyncGoTrueAdminAPI( + url=self._url, + headers=self._headers, + http_client=self._http_client, + ) + + # Initializations + + def initialize(self, *, url: Union[str, None] = None) -> None: + if url and self._is_implicit_grant_flow(url): + self.initialize_from_url(url) + else: + self.initialize_from_storage() + + def initialize_from_storage(self) -> None: + return self._recover_and_refresh() + + def initialize_from_url(self, url: str) -> None: + try: + if self._is_implicit_grant_flow(url): + session, redirect_type = self._get_session_from_url(url) + self._save_session(session) + self._notify_all_subscribers("SIGNED_IN", session) + if redirect_type == "recovery": + self._notify_all_subscribers("PASSWORD_RECOVERY", session) + except Exception as e: + self._remove_session() + raise e + + # Public methods + + def sign_up( + self, + credentials: SignUpWithPasswordCredentials, + ) -> AuthResponse: + """ + Creates a new user. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "signup", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=redirect_to, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "signup", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_password( + self, + credentials: SignInWithPasswordCredentials, + ) -> AuthResponse: + """ + Log in an existing user with an email or phone and password. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "token", + body={ + "email": email, + "password": password, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "token", + body={ + "phone": phone, + "password": password, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_oauth( + self, + credentials: SignInWithOAuthCredentials, + ) -> OAuthResponse: + """ + Log in an existing user via a third-party provider. + """ + self._remove_session() + provider = credentials.get("provider") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + scopes = options.get("scopes") + params = options.get("query_params", {}) + if redirect_to: + params["redirect_to"] = redirect_to + if scopes: + params["scopes"] = scopes + url = self._get_url_for_provider(provider, params) + return OAuthResponse(provider=provider, url=url) + + def sign_in_with_otp( + self, + credentials: SignInWithPasswordlessCredentials, + ) -> AuthResponse: + """ + Log in a user using magiclink or a one-time password (OTP). + + If the `{{ .ConfirmationURL }}` variable is specified in + the email template, a magiclink will be sent. + + If the `{{ .Token }}` variable is specified in the email + template, an OTP will be sent. + + If you're using phone sign-ins, only an OTP will be sent. + You won't be able to send a magiclink for phone sign-ins. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + options = credentials.get("options", {}) + email_redirect_to = options.get("email_redirect_to") + should_create_user = options.get("create_user", True) + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + return self._request( + "POST", + "otp", + body={ + "email": email, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=email_redirect_to, + xform=parse_auth_response, + ) + if phone: + return self._request( + "POST", + "otp", + body={ + "phone": phone, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number" + ) + + def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: + """ + Log in a user given a User supplied OTP received via mobile. + """ + self._remove_session() + response = self._request( + "POST", + "verify", + body={ + "gotrue_meta_security": { + "captcha_token": params.get("options", {}).get("captcha_token"), + }, + **params, + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_auth_response, + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def get_session(self) -> Union[Session, None]: + """ + Returns the session, refreshing it if necessary. + + The session returned can be null if the session is not detected which + can happen in the event a user is not signed-in or has logged out. + """ + current_session: Union[Session, None] = None + if self._persist_session: + maybe_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(maybe_session) + if not current_session: + self._remove_session() + else: + current_session = self._in_memory_session + if not current_session: + return None + has_expired = ( + current_session.expires_at <= time() + if current_session.expires_at + else False + ) + if not has_expired: + return current_session + return self._call_refresh_token(current_session.refresh_token) + + def get_user(self, jwt: Union[str, None] = None) -> UserResponse: + """ + Gets the current user details if there is an existing session. + + Takes in an optional access token `jwt`. If no `jwt` is provided, + `get_user()` will attempt to get the `jwt` from the current session. + """ + if not jwt: + session = self.get_session() + if session: + jwt = session.access_token + return self._request("GET", "user", jwt=jwt, xform=parse_user_response) + + def update_user(self, attributes: UserAttributes) -> UserResponse: + """ + Updates user data, if there is a logged in user. + """ + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "PUT", + "user", + body=attributes, + jwt=session.access_token, + xform=parse_user_response, + ) + session.user = response.user + self._save_session(session) + self._notify_all_subscribers("USER_UPDATED", session) + return response + + def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: + """ + Sets the session data from the current session. If the current session + is expired, `set_session` will take care of refreshing it to obtain a + new session. + + If the refresh token in the current session is invalid and the current + session has expired, an error will be thrown. + + If the current session does not contain at `expires_at` field, + `set_session` will use the exp claim defined in the access token. + + The current session that minimally contains an access token, + refresh token and a user. + """ + time_now = round(time()) + expires_at = time_now + has_expired = True + session: Union[Session, None] = None + if access_token and access_token.split(".")[1]: + json_raw = b64decode(access_token.split(".")[1] + "===").decode("utf-8") + payload = loads(json_raw) + if payload.get("exp"): + expires_at = int(payload.get("exp")) + has_expired = expires_at <= time_now + if has_expired: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + return AuthResponse() + session = response.session + else: + response = self.get_user(access_token) + session = Session( + access_token=access_token, + refresh_token=refresh_token, + user=response.user, + token_type="bearer", + expires_in=expires_at - time_now, + expires_at=expires_at, + ) + self._save_session(session) + self._notify_all_subscribers("TOKEN_REFRESHED", session) + return AuthResponse(session=session, user=response.user) + + def sign_out(self) -> None: + """ + Inside a browser context, `sign_out` will remove the logged in user from the + browser session and log them out - removing all items from localstorage and + then trigger a `"SIGNED_OUT"` event. + + For server-side management, you can revoke all refresh tokens for a user by + passing a user's JWT through to `api.sign_out`. + + There is no way to revoke a user's access token jwt until it expires. + It is recommended to set a shorter expiry on the jwt for this reason. + """ + session = self.get_session() + access_token = session.access_token if session else None + if access_token: + self.admin.sign_out(access_token) + self._remove_session() + self._notify_all_subscribers("SIGNED_OUT", None) + + def on_auth_state_change( + self, + callback: Callable[[AuthChangeEvent, Union[Session, None]], None], + ) -> Subscription: + """ + Receive a notification every time an auth event happens. + """ + unique_id = str(uuid4()) + + def _unsubscribe() -> None: + self._state_change_emitters.pop(unique_id) + + subscription = Subscription( + id=unique_id, + callback=callback, + unsubscribe=_unsubscribe, + ) + self._state_change_emitters[unique_id] = subscription + return subscription + + def reset_password_email( + self, + email: str, + options: Options = {}, + ) -> None: + """ + Sends a password reset request to an email address. + """ + raise NotImplementedError + + # Private methods + + def _remove_session(self) -> None: + if self._persist_session: + self._storage.remove_item(self._storage_key) + else: + self._in_memory_session = None + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + + def _get_session_from_url( + self, + url: str, + ) -> Tuple[Session, Union[str, None]]: + if not self._is_implicit_grant_flow(url): + raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") + result = urlparse(url) + params = parse_qs(result.query) + error_description = self._get_param(params, "error_description") + if error_description: + error_code = self._get_param(params, "error_code") + error = self._get_param(params, "error") + if not error_code: + raise AuthImplicitGrantRedirectError("No error_code detected.") + if not error: + raise AuthImplicitGrantRedirectError("No error detected.") + raise AuthImplicitGrantRedirectError( + error_description, + {"code": error_code, "error": error}, + ) + provider_token = self._get_param(params, "provider_token") + provider_refresh_token = self._get_param(params, "provider_refresh_token") + access_token = self._get_param(params, "access_token") + if not access_token: + raise AuthImplicitGrantRedirectError("No access_token detected.") + expires_in = self._get_param(params, "expires_in") + if not expires_in: + raise AuthImplicitGrantRedirectError("No expires_in detected.") + refresh_token = self._get_param(params, "refresh_token") + if not refresh_token: + raise AuthImplicitGrantRedirectError("No refresh_token detected.") + token_type = self._get_param(params, "token_type") + if not token_type: + raise AuthImplicitGrantRedirectError("No token_type detected.") + time_now = round(time()) + expires_at = time_now + int(expires_in) + user = self.get_user(access_token) + session = Session( + provider_token=provider_token, + provider_refresh_token=provider_refresh_token, + access_token=access_token, + expires_in=int(expires_in), + expires_at=expires_at, + refresh_token=refresh_token, + token_type=token_type, + user=user.user, + ) + redirect_type = self._get_param(params, "type") + return session, redirect_type + + def _recover_and_refresh(self) -> None: + raw_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(raw_session) + if not current_session: + if raw_session: + self._remove_session() + return + time_now = round(time()) + expires_at = current_session.expires_at + if expires_at and expires_at < time_now + EXPIRY_MARGIN: + refresh_token = current_session.refresh_token + if self._auto_refresh_token and refresh_token: + self._network_retries += 1 + try: + self._call_refresh_token(refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = Timer( + (RETRY_INTERVAL ** (self._network_retries * 100)), + self._recover_and_refresh, + ) + self._refresh_token_timer.start() + return + self._remove_session() + return + if self._persist_session: + self._save_session(current_session) + self._notify_all_subscribers("SIGNED_IN", current_session) + + def _call_refresh_token(self, refresh_token: str) -> Session: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + raise AuthSessionMissingError() + self._save_session(response.session) + self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + return response.session + + def _refresh_access_token(self, refresh_token: str) -> AuthResponse: + return self._request( + "POST", + "token", + query={"grant_type": "refresh_token"}, + body={"refresh_token": refresh_token}, + xform=parse_auth_response, + ) + + def _save_session(self, session: Session) -> None: + if not self._persist_session: + self._in_memory_session = session + expire_at = session.expires_at + if expire_at: + pass + time_now = round(time()) + expire_in = expire_at - time_now + refresh_duration_before_expires = ( + EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 + ) + value = (expire_in - refresh_duration_before_expires) * 1000 + self._start_auto_refresh_token(value) + if self._persist_session and session.expires_at: + self._storage.set_item(self._storage_key, session.json()) + + def _start_auto_refresh_token(self, value: float) -> None: + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + if value <= 0 or not self._auto_refresh_token: + return + + def refresh_token_function(): + self._network_retries += 1 + try: + session = self.get_session() + if session: + self._call_refresh_token(session.refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + self._start_auto_refresh_token( + RETRY_INTERVAL ** (self._network_retries * 100) + ) + + self._refresh_token_timer = Timer(value, refresh_token_function) + self._refresh_token_timer.start() + + def _notify_all_subscribers( + self, + event: AuthChangeEvent, + session: Union[Session, None], + ) -> None: + for subscription in self._state_change_emitters.values(): + subscription.callback(event, session) + + def _get_valid_session( + self, + raw_session: Union[str, None], + ) -> Union[Session, None]: + if not raw_session: + return None + data = loads(raw_session) + if not data: + return None + if not data.get("access_token"): + return None + if not data.get("refresh_token"): + return None + if not data.get("expires_at"): + return None + try: + expires_at = int(data["expires_at"]) + data["expires_at"] = expires_at + except ValueError: + return None + try: + session = Session.parse_obj(data) + return session + except Exception: + return None + + def _get_param( + self, + query_params: Dict[str, List[str]], + name: str, + ) -> Union[str, None]: + if name in query_params: + return query_params[name][0] + return None + + def _is_implicit_grant_flow(self, url: str) -> bool: + result = urlparse(url) + params = parse_qs(result.query) + return "access_token" in params or "error_description" in params + + def _get_url_for_provider( + self, + provider: Provider, + params: Dict[str, str], + ) -> str: + params = {k: quote(v) for k, v in params.items()} + params["provider"] = quote(provider) + query = urlencode(params) + return f"{self._url}/authorize?{query}" + + +def test(): + client = SyncGoTrueClient() + client.initialize() diff --git a/gotrue/constants.py b/gotrue/constants.py index e87dd060..ad001f27 100644 --- a/gotrue/constants.py +++ b/gotrue/constants.py @@ -1,18 +1,12 @@ from __future__ import annotations -from gotrue import __version__ +from . import __version__ GOTRUE_URL = "http://localhost:9999" -AUDIENCE = "" DEFAULT_HEADERS = { "X-Client-Info": f"gotrue-py/{__version__}", } -EXPIRY_MARGIN = 60 * 1000 +EXPIRY_MARGIN = 10 # seconds +MAX_RETRIES = 10 +RETRY_INTERVAL = 2 # deciseconds STORAGE_KEY = "supabase.auth.token" -COOKIE_OPTIONS = { - "name": "sb:token", - "lifetime": 60 * 60 * 8, - "domain": "", - "path": "/", - "same_site": "lax", -} diff --git a/gotrue/errors.py b/gotrue/errors.py new file mode 100644 index 00000000..742d5d44 --- /dev/null +++ b/gotrue/errors.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Union + +from typing_extensions import TypedDict + + +class AuthError(Exception): + def __init__(self, message: str) -> None: + Exception.__init__(self, message) + self.message = message + self.name = "AuthError" + + +class AuthApiErrorDict(TypedDict): + name: str + message: str + status: int + + +class AuthApiError(AuthError): + def __init__(self, message: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = "AuthApiError" + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthUnknownError(AuthError): + def __init__(self, message: str, original_error: Exception) -> None: + AuthError.__init__(self, message) + self.name = "AuthUnknownError" + self.original_error = original_error + + +class CustomAuthError(AuthError): + def __init__(self, message: str, name: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = name + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthSessionMissingError(CustomAuthError): + def __init__(self) -> None: + CustomAuthError.__init__( + self, + "Auth session missing!", + "AuthSessionMissingError", + 400, + ) + + +class AuthInvalidCredentialsError(CustomAuthError): + def __init__(self, message: str) -> None: + CustomAuthError.__init__( + self, + message, + "AuthInvalidCredentialsError", + 400, + ) + + +class AuthImplicitGrantRedirectErrorDetails(TypedDict): + error: str + code: str + + +class AuthImplicitGrantRedirectErrorDict(AuthApiErrorDict): + details: Union[AuthImplicitGrantRedirectErrorDetails, None] + + +class AuthImplicitGrantRedirectError(CustomAuthError): + def __init__( + self, + message: str, + details: Union[AuthImplicitGrantRedirectErrorDetails, None] = None, + ) -> None: + CustomAuthError.__init__( + self, + message, + "AuthImplicitGrantRedirectError", + 500, + ) + self.details = details + + def to_dict(self) -> AuthImplicitGrantRedirectErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + "details": self.details, + } + + +class AuthRetryableError(CustomAuthError): + def __init__(self, message: str, status: int) -> None: + CustomAuthError.__init__( + self, + message, + "AuthRetryableError", + status, + ) diff --git a/gotrue/exceptions.py b/gotrue/exceptions.py deleted file mode 100644 index 9979410f..00000000 --- a/gotrue/exceptions.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from dataclasses import asdict, dataclass -from typing import Any, Dict - - -@dataclass -class APIError(Exception): - msg: str - code: int - - def __post_init__(self) -> None: - self.msg = str(self.msg) - self.code = int(str(self.code)) - - @classmethod - def parse_dict(cls, **json: dict) -> APIError: - ret = cls(msg="Unknown error", code=-1) - for new_name, new_val in json.items(): - setattr(ret, new_name, new_val) - return ret - - @classmethod - def from_dict(cls, data: dict) -> APIError: - if "msg" in data and "code" in data: - return APIError( - msg=data["msg"], - code=data["code"], - ) - if "error" in data and "error_description" in data: - try: - code = int(data["error"]) - except ValueError: - code = -1 - return APIError( - msg=data["error_description"], - code=code, - ) - if "message" in data: - try: - code = int(data.get("code", -1)) - except ValueError: - code = -1 - return APIError( - msg=data["message"], - code=code, - ) - return cls.parse_dict(**data) - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) diff --git a/gotrue/helpers.py b/gotrue/helpers.py index 069b79b7..53a5cf7b 100644 --- a/gotrue/helpers.py +++ b/gotrue/helpers.py @@ -1,18 +1,76 @@ from __future__ import annotations -from urllib.parse import quote +from typing import Any, Union, cast -from httpx import HTTPError, Response +from httpx import HTTPStatusError -from .exceptions import APIError +from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError +from .types import ( + AuthResponse, + GenerateLinkProperties, + GenerateLinkResponse, + Session, + User, + UserResponse, +) -def encode_uri_component(uri: str) -> str: - return quote(uri.encode("utf-8")) +def parse_auth_response(data: Any) -> AuthResponse: + session: Union[Session, None] = None + if ( + "access_token" in data + and "refresh_token" in data + and "expires_in" in data + and data["access_token"] + and data["refresh_token"] + and data["expires_in"] + ): + session = Session.parse_obj(data) + user = User.parse_obj(data["user"]) if "user" in data else User.parse_obj(data) + return AuthResponse(session=session, user=user) -def check_response(response: Response) -> None: +def parse_link_response(data: Any) -> GenerateLinkResponse: + properties = GenerateLinkProperties( + action_link=data.get("action_link"), + email_otp=data.get("email_otp"), + hashed_token=data.get("hashed_token"), + redirect_to=data.get("redirect_to"), + verification_type=data.get("verification_type"), + ) + user = User.parse_obj({k: v for k, v in data.items() if k not in properties.dict()}) + return GenerateLinkResponse(properties=properties, user=user) + + +def parse_user_response(data: Any) -> UserResponse: + if "user" not in data: + data = {"user": data} + return UserResponse.parse_obj(data) + + +def get_error_message(error: Any) -> str: + props = ["msg", "message", "error_description", "error"] + filter = ( + lambda prop: prop in error if isinstance(error, dict) else hasattr(error, prop) + ) + return next((error[prop] for prop in props if filter(prop)), str(error)) + + +def looks_like_http_status_error(exception: Exception) -> bool: + return isinstance(exception, HTTPStatusError) + + +def handle_exception(exception: Exception) -> AuthError: + if not looks_like_http_status_error(exception): + return AuthRetryableError(get_error_message(exception), 0) + error = cast(HTTPStatusError, exception) try: - response.raise_for_status() - except HTTPError: - raise APIError.from_dict(response.json()) + network_error_codes = [502, 503, 504] + if error.response.status_code in network_error_codes: + return AuthRetryableError( + get_error_message(error), error.response.status_code + ) + json = error.response.json() + return AuthApiError(get_error_message(json), error.response.status_code or 500) + except Exception as e: + return AuthUnknownError(get_error_message(error), e) diff --git a/gotrue/timer.py b/gotrue/timer.py new file mode 100644 index 00000000..1b6bf9b8 --- /dev/null +++ b/gotrue/timer.py @@ -0,0 +1,44 @@ +import asyncio +from threading import Timer as _Timer +from typing import Any, Callable, Coroutine, Union, cast + + +class Timer: + def __init__( + self, + seconds: float, + function: Callable[[], Union[Coroutine[Any, Any, None], None]], + ) -> None: + self._milliseconds = seconds + self._function = function + self._task: Union[asyncio.Task, None] = None + self._timer: Union[_Timer, None] = None + + def start(self) -> None: + if asyncio.iscoroutinefunction(self._function): + + async def schedule(): + await asyncio.sleep(self._milliseconds / 1000) + await cast(Coroutine[Any, Any, None], self._function()) + + def cleanup(_): + self._task = None + + self._task = asyncio.create_task(schedule()) + self._task.add_done_callback(cleanup) + else: + self._timer = _Timer(self._milliseconds / 1000, self._function) + self._timer.start() + + def cancel(self) -> None: + if self._task is not None: + self._task.cancel() + self._task = None + if self._timer is not None: + self._timer.cancel() + self._timer = None + + def is_alive(self) -> bool: + return self._task is not None or ( + self._timer is not None and self._timer.is_alive() + ) diff --git a/gotrue/types.py b/gotrue/types.py index fdfd294d..59d785bd 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -1,167 +1,371 @@ from __future__ import annotations -import sys from datetime import datetime -from enum import Enum from time import time -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union -from uuid import UUID +from typing import Any, Callable, Dict, List, Literal, Union -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - -from httpx import Response from pydantic import BaseModel, root_validator +from typing_extensions import NotRequired, TypedDict + +Provider = Literal[ + "apple", + "azure", + "bitbucket", + "discord", + "facebook", + "github", + "gitlab", + "google", + "keycloak", + "linkedin", + "notion", + "slack", + "spotify", + "twitch", + "twitter", + "workos", +] + +AuthChangeEvent = Literal[ + "PASSWORD_RECOVERY", + "SIGNED_IN", + "SIGNED_OUT", + "TOKEN_REFRESHED", + "USER_UPDATED", + "USER_DELETED", +] + + +class Options(TypedDict, total=False): + redirect_to: str + data: Any + + +class AuthResponse(BaseModel): + user: Union[User, None] = None + session: Union[Session, None] = None + + +class OAuthResponse(BaseModel): + provider: Provider + url: str + + +class UserResponse(BaseModel): + user: User + + +class Session(BaseModel): + provider_token: Union[str, None] = None + """ + The oauth provider token. If present, this can be used to make external API + requests to the oauth provider used. + """ + provider_refresh_token: Union[str, None] = None + """ + The oauth provider refresh token. If present, this can be used to refresh + the provider_token via the oauth provider's API. + + Not all oauth providers return a provider refresh token. If the + provider_refresh_token is missing, please refer to the oauth provider's + documentation for information on how to obtain the provider refresh token. + """ + access_token: str + refresh_token: str + expires_in: int + """ + The number of seconds until the token expires (since it was issued). + Returned when a login is confirmed. + """ + expires_at: Union[int, None] = None + """ + A timestamp of when the token will expire. Returned when a login is confirmed. + """ + token_type: str + user: User -from gotrue.helpers import check_response + @root_validator + def validator(cls, values: dict) -> dict: + expires_in = values.get("expires_in") + if expires_in and not values.get("expires_at"): + values["expires_at"] = round(time()) + expires_in + return values -T = TypeVar("T", bound=BaseModel) +class UserIdentity(BaseModel): + id: str + user_id: str + identity_data: Dict[str, Any] + provider: str + created_at: datetime + last_sign_in_at: datetime + updated_at: Union[datetime, None] = None -def determine_session_or_user_model_from_response( - response: Response, -) -> Union[Type[Session], Type[User]]: - return Session if "access_token" in response.json() else User +class User(BaseModel): + id: str + app_metadata: Dict[str, Any] + user_metadata: Dict[str, Any] + aud: str + confirmation_sent_at: Union[datetime, None] = None + recovery_sent_at: Union[datetime, None] = None + email_change_sent_at: Union[datetime, None] = None + new_email: Union[str, None] = None + invited_at: Union[datetime, None] = None + action_link: Union[str, None] = None + email: Union[str, None] = None + phone: Union[str, None] = None + created_at: datetime + confirmed_at: Union[datetime, None] = None + email_confirmed_at: Union[datetime, None] = None + phone_confirmed_at: Union[datetime, None] = None + last_sign_in_at: Union[datetime, None] = None + role: Union[str, None] = None + updated_at: Union[datetime, None] = None + identities: Union[List[UserIdentity], None] = None -class BaseModelFromResponse(BaseModel): - @classmethod - def parse_response(cls: Type[T], response: Response) -> T: - check_response(response) - return cls.parse_obj(response.json()) +class UserAttributes(TypedDict, total=False): + email: str + phone: str + password: str + data: Any -class CookieOptions(BaseModelFromResponse): - name: str - """The name of the cookie. Defaults to `sb:token`.""" - lifetime: int - """The cookie lifetime (expiration) in seconds. Set to 8 hours by default.""" - domain: str - """The cookie domain this should run on. - Leave it blank to restrict it to your domain.""" - path: str - same_site: str - """SameSite configuration for the session cookie. - Defaults to 'lax', but can be changed to 'strict' or 'none'. - Set it to false if you want to disable the SameSite setting.""" +class AdminUserAttributes(UserAttributes, TypedDict, total=False): + user_metadata: Any + app_metadata: Any + email_confirm: bool + phone_confirm: bool + ban_duration: Union[str, Literal["none"]] -class Identity(BaseModelFromResponse): + +class Subscription(BaseModel): id: str - user_id: UUID - provider: str - created_at: datetime - updated_at: datetime - identity_data: Optional[Dict[str, Any]] = None - last_sign_in_at: Optional[datetime] = None + """ + The subscriber UUID. This will be set by the client. + """ + callback: Callable[[AuthChangeEvent, Union[Session, None]], None] + """ + The function to call every time there is an event. + """ + unsubscribe: Callable[[], None] + """ + Call this to remove the listener. + """ -class User(BaseModelFromResponse): - app_metadata: Dict[str, Any] - aud: str - """The user's audience. Use audiences to group users.""" - created_at: datetime - id: UUID - user_metadata: Dict[str, Any] - identities: Optional[List[Identity]] = None - confirmation_sent_at: Optional[datetime] = None - action_link: Optional[str] = None - last_sign_in_at: Optional[datetime] = None - phone: Optional[str] = None - phone_confirmed_at: Optional[datetime] = None - recovery_sent_at: Optional[datetime] = None - role: Optional[str] = None - updated_at: Optional[datetime] = None - email_confirmed_at: Optional[datetime] = None - confirmed_at: Optional[datetime] = None - invited_at: Optional[datetime] = None - email: Optional[str] = None - new_email: Optional[str] = None - email_change_sent_at: Optional[datetime] = None - new_phone: Optional[str] = None - phone_change_sent_at: Optional[datetime] = None - - -class UserAttributes(BaseModelFromResponse): - email: Optional[str] = None - """The user's email.""" - password: Optional[str] = None - """The user's password.""" - email_change_token: Optional[str] = None - """An email change token.""" - data: Optional[Any] = None - """A custom data object. Can be any JSON.""" - - -class Session(BaseModelFromResponse): - access_token: str - token_type: str - expires_at: Optional[int] = None - """A timestamp of when the token will expire. Returned when a login is confirmed.""" - expires_in: Optional[int] = None - """The number of seconds until the token expires (since it was issued). - Returned when a login is confirmed.""" - provider_token: Optional[str] = None - refresh_token: Optional[str] = None - user: Optional[User] = None +class SignUpWithEmailAndPasswordCredentialsOptions(TypedDict, total=False): + email_redirect_to: str + data: Any + captcha_token: str - @root_validator - def validator(cls, values: dict) -> dict: - expires_in = values.get("expires_in") - if expires_in and not values.get("expires_at"): - values["expires_at"] = round(time()) + expires_in - return values +class SignUpWithEmailAndPasswordCredentials(TypedDict): + email: str + password: str + options: NotRequired[SignUpWithEmailAndPasswordCredentialsOptions] -class AuthChangeEvent(str, Enum): - PASSWORD_RECOVERY = "PASSWORD_RECOVERY" - SIGNED_IN = "SIGNED_IN" - SIGNED_OUT = "SIGNED_OUT" - TOKEN_REFRESHED = "TOKEN_REFRESHED" - USER_UPDATED = "USER_UPDATED" - USER_DELETED = "USER_DELETED" +class SignUpWithPhoneAndPasswordCredentialsOptions(TypedDict, total=False): + data: Any + captcha_token: str -class Subscription(BaseModelFromResponse): - id: UUID - """The subscriber UUID. This will be set by the client.""" - callback: Callable[[AuthChangeEvent, Optional[Session]], None] - """The function to call every time there is an event.""" - unsubscribe: Callable[[], None] - """Call this to remove the listener.""" + +class SignUpWithPhoneAndPasswordCredentials(TypedDict): + phone: str + password: str + options: NotRequired[SignUpWithPhoneAndPasswordCredentialsOptions] + + +SignUpWithPasswordCredentials = Union[ + SignUpWithEmailAndPasswordCredentials, + SignUpWithPhoneAndPasswordCredentials, +] + + +class SignInWithPasswordCredentialsOptions(TypedDict, total=False): + captcha_token: str + + +class SignInWithEmailAndPasswordCredentials(TypedDict): + email: str + password: str + options: NotRequired[SignInWithPasswordCredentialsOptions] + + +class SignInWithPhoneAndPasswordCredentials(TypedDict): + phone: str + password: str + options: NotRequired[SignInWithPasswordCredentialsOptions] + + +SignInWithPasswordCredentials = Union[ + SignInWithEmailAndPasswordCredentials, + SignInWithPhoneAndPasswordCredentials, +] + + +class SignInWithEmailAndPasswordlessCredentialsOptions(TypedDict, total=False): + email_redirect_to: str + should_create_user: bool + data: Any + captcha_token: str + + +class SignInWithEmailAndPasswordlessCredentials(TypedDict): + email: str + options: NotRequired[SignInWithEmailAndPasswordlessCredentialsOptions] + + +class SignInWithPhoneAndPasswordlessCredentialsOptions(TypedDict, total=False): + should_create_user: bool + data: Any + captcha_token: str + + +class SignInWithPhoneAndPasswordlessCredentials(TypedDict): + phone: str + options: NotRequired[SignInWithPhoneAndPasswordlessCredentialsOptions] + + +SignInWithPasswordlessCredentials = Union[ + SignInWithEmailAndPasswordlessCredentials, + SignInWithPhoneAndPasswordlessCredentials, +] + + +class SignInWithOAuthCredentialsOptions(TypedDict, total=False): + redirect_to: str + scopes: str + query_params: Dict[str, str] + + +class SignInWithOAuthCredentials(TypedDict): + provider: Provider + options: NotRequired[SignInWithOAuthCredentialsOptions] + + +class VerifyOtpParamsOptions(TypedDict, total=False): + redirect_to: str + captcha_token: str + + +class VerifyEmailOtpParams(TypedDict): + email: str + token: str + type: Literal[ + "signup", + "invite", + "magiclink", + "recovery", + "email_change", + ] + options: NotRequired[VerifyOtpParamsOptions] + + +class VerifyMobileOtpParams(TypedDict): + phone: str + token: str + type: Literal[ + "sms", + "phone_change", + ] + options: NotRequired[VerifyOtpParamsOptions] + + +VerifyOtpParams = Union[ + VerifyEmailOtpParams, + VerifyMobileOtpParams, +] + + +class GenerateLinkParamsOptions(TypedDict, total=False): + redirect_to: str + + +class GenerateLinkParamsWithDataOptions( + GenerateLinkParamsOptions, + TypedDict, + total=False, +): + data: Any + + +class GenerateSignupLinkParams(TypedDict): + type: Literal["signup"] + email: str + password: str + options: NotRequired[GenerateLinkParamsWithDataOptions] -class Provider(str, Enum): - apple = "apple" - azure = "azure" - bitbucket = "bitbucket" - discord = "discord" - facebook = "facebook" - github = "github" - gitlab = "gitlab" - google = "google" - notion = "notion" - slack = "slack" - spotify = "spotify" - twitter = "twitter" - twitch = "twitch" +class GenerateInviteOrMagiclinkParams(TypedDict): + type: Literal["invite", "magiclink"] + email: str + options: NotRequired[GenerateLinkParamsWithDataOptions] -class LinkType(str, Enum): - """The type of link.""" +class GenerateRecoveryLinkParams(TypedDict): + type: Literal["recovery"] + email: str + options: NotRequired[GenerateLinkParamsOptions] - signup = "signup" - magiclink = "magiclink" - recovery = "recovery" - invite = "invite" +class GenerateEmailChangeLinkParams(TypedDict): + type: Literal["email_change"] + email: str + new_email: str + options: NotRequired[GenerateLinkParamsOptions] -class UserAttributesDict(TypedDict, total=False): - """Dict version of `UserAttributes`""" - email: Optional[str] - password: Optional[str] - email_change_token: Optional[str] - data: Optional[Any] +GenerateLinkParams = Union[ + GenerateSignupLinkParams, + GenerateInviteOrMagiclinkParams, + GenerateRecoveryLinkParams, + GenerateEmailChangeLinkParams, +] + +GenerateLinkType = Literal[ + "signup", + "invite", + "magiclink", + "recovery", + "email_change_current", + "email_change_new", +] + + +class GenerateLinkProperties(BaseModel): + """ + The properties related to the email link generated. + """ + + action_link: str + """ + The email link to send to the user. The action_link follows the following format: + + auth/v1/verify?type={verification_type}&token={hashed_token}&redirect_to={redirect_to} + """ + email_otp: str + """ + The raw email OTP. + You should send this in the email if you want your users to verify using an + OTP instead of the action link. + """ + hashed_token: str + """ + The hashed token appended to the action link. + """ + redirect_to: str + """ + The URL appended to the action link. + """ + verification_type: GenerateLinkType + """ + The verification type that the email link is associated to. + """ + + +class GenerateLinkResponse(BaseModel): + properties: GenerateLinkProperties + user: User diff --git a/tests/_async/test_api_with_auto_confirm_disabled.py b/tests/_async/test_api_with_auto_confirm_disabled.py index 3cd458dd..c780d5bf 100644 --- a/tests/_async/test_api_with_auto_confirm_disabled.py +++ b/tests/_async/test_api_with_auto_confirm_disabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI -from gotrue.constants import COOKIE_OPTIONS -from gotrue.types import CookieOptions, LinkType, User +from ...gotrue import AsyncGoTrueAPI +from ...gotrue.constants import COOKIE_OPTIONS +from ...gotrue.types import CookieOptions, GenerateLinkType, User GOTRUE_URL = "http://localhost:9999" TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwicm9sZSI6InN1cGFiYXNlX2FkbWluIiwiaWF0IjoxNTE2MjM5MDIyfQ.0sOtTSTfPv5oPZxsjvBO249FI4S4p0ymHoIZ6H6z9Y8" # noqa: E501 @@ -47,7 +47,7 @@ async def test_sign_up_with_email_and_password(api: AsyncGoTrueAPI): async def test_generate_sign_up_link(api: AsyncGoTrueAPI): try: response = await api.generate_link( - type=LinkType.signup, + type=GenerateLinkType.signup, email=email2, password=password2, redirect_to="http://localhost:9999/welcome", @@ -64,7 +64,7 @@ async def test_generate_sign_up_link(api: AsyncGoTrueAPI): async def test_generate_magic_link(api: AsyncGoTrueAPI): try: response = await api.generate_link( - type=LinkType.magiclink, + type=GenerateLinkType.magiclink, email=email3, redirect_to="http://localhost:9999/welcome", ) @@ -76,7 +76,7 @@ async def test_generate_magic_link(api: AsyncGoTrueAPI): async def test_generate_invite_link(api: AsyncGoTrueAPI): try: response = await api.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=email3, redirect_to="http://localhost:9999/welcome", ) @@ -89,7 +89,7 @@ async def test_generate_invite_link(api: AsyncGoTrueAPI): async def test_generate_recovery_link(api: AsyncGoTrueAPI): try: response = await api.generate_link( - type=LinkType.recovery, + type=GenerateLinkType.recovery, email=email, redirect_to="http://localhost:9999/welcome", ) diff --git a/tests/_async/test_api_with_auto_confirm_enabled.py b/tests/_async/test_api_with_auto_confirm_enabled.py index 13ff0411..cf7e6faa 100644 --- a/tests/_async/test_api_with_auto_confirm_enabled.py +++ b/tests/_async/test_api_with_auto_confirm_enabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI -from gotrue.constants import COOKIE_OPTIONS -from gotrue.types import CookieOptions, Session, User +from ...gotrue import AsyncGoTrueAPI +from ...gotrue.constants import COOKIE_OPTIONS +from ...gotrue.types import CookieOptions, Session, User GOTRUE_URL = "http://localhost:9998" TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoic2VydmljZV9yb2xlIiwiaWF0IjoxNjQyMjMyNzUwfQ.TUR8Zu05TtNR25L42soA2trZpc4oBR8-9Pv5r5bvls8" # noqa: E501 diff --git a/tests/_async/test_client_with_auto_confirm_disabled.py b/tests/_async/test_client_with_auto_confirm_disabled.py index c48d597d..4ba87a4f 100644 --- a/tests/_async/test_client_with_auto_confirm_disabled.py +++ b/tests/_async/test_client_with_auto_confirm_disabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueClient -from gotrue.exceptions import APIError -from gotrue.types import User +from ...gotrue import AsyncGoTrueClient +from ...gotrue.errors import APIError +from ...gotrue.types import User GOTRUE_URL = "http://localhost:9999" TEST_TWILIO = False diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index be4eb104..e8dc566d 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueClient -from gotrue.exceptions import APIError -from gotrue.types import Session, User, UserAttributes +from ...gotrue import AsyncGoTrueClient +from ...gotrue.errors import APIError +from ...gotrue.types import Session, User, UserAttributes GOTRUE_URL = "http://localhost:9998" TEST_TWILIO = False diff --git a/tests/_async/test_client_with_sign_ups_disabled.py b/tests/_async/test_client_with_sign_ups_disabled.py index cae4bbd8..1557bb13 100644 --- a/tests/_async/test_client_with_sign_ups_disabled.py +++ b/tests/_async/test_client_with_sign_ups_disabled.py @@ -3,10 +3,10 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI, AsyncGoTrueClient -from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS -from gotrue.exceptions import APIError -from gotrue.types import CookieOptions, LinkType, User, UserAttributes +from ...gotrue import AsyncGoTrueAPI, AsyncGoTrueClient +from ...gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS +from ...gotrue.errors import APIError +from ...gotrue.types import CookieOptions, GenerateLinkType, User, UserAttributes GOTRUE_URL = "http://localhost:9997" AUTH_ADMIN_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwicm9sZSI6InN1cGFiYXNlX2FkbWluIiwiaWF0IjoxNTE2MjM5MDIyfQ.0sOtTSTfPv5oPZxsjvBO249FI4S4p0ymHoIZ6H6z9Y8" # noqa: E501 @@ -57,7 +57,7 @@ async def test_generate_link_should_be_able_to_generate_multiple_links( ): try: response = await auth_admin.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=invited_user, redirect_to="http://localhost:9997", ) @@ -77,7 +77,7 @@ async def test_generate_link_should_be_able_to_generate_multiple_links( assert response.identities == [] user = response response = await auth_admin.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=invited_user, ) assert isinstance(response, User) diff --git a/tests/_async/test_provider.py b/tests/_async/test_provider.py index 9f0df26f..ac7c9866 100644 --- a/tests/_async/test_provider.py +++ b/tests/_async/test_provider.py @@ -2,8 +2,8 @@ import pytest -from gotrue import AsyncGoTrueClient -from gotrue.types import Provider +from ...gotrue import AsyncGoTrueClient +from ...gotrue.types import Provider GOTRUE_URL = "http://localhost:9999" diff --git a/tests/_async/test_subscriptions.py b/tests/_async/test_subscriptions.py index 33b9533e..c4780294 100644 --- a/tests/_async/test_subscriptions.py +++ b/tests/_async/test_subscriptions.py @@ -1,7 +1,7 @@ import pytest -from gotrue import AsyncGoTrueClient -from gotrue.types import Subscription +from ...gotrue import AsyncGoTrueClient +from ...gotrue.types import Subscription GOTRUE_URL = "http://localhost:9999" diff --git a/tests/_sync/test_api_with_auto_confirm_disabled.py b/tests/_sync/test_api_with_auto_confirm_disabled.py index b87f489c..2f4a07c5 100644 --- a/tests/_sync/test_api_with_auto_confirm_disabled.py +++ b/tests/_sync/test_api_with_auto_confirm_disabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI -from gotrue.constants import COOKIE_OPTIONS -from gotrue.types import CookieOptions, LinkType, User +from ...gotrue import SyncGoTrueAPI +from ...gotrue.constants import COOKIE_OPTIONS +from ...gotrue.types import CookieOptions, GenerateLinkType, User GOTRUE_URL = "http://localhost:9999" TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwicm9sZSI6InN1cGFiYXNlX2FkbWluIiwiaWF0IjoxNTE2MjM5MDIyfQ.0sOtTSTfPv5oPZxsjvBO249FI4S4p0ymHoIZ6H6z9Y8" # noqa: E501 @@ -47,7 +47,7 @@ def test_sign_up_with_email_and_password(api: SyncGoTrueAPI): def test_generate_sign_up_link(api: SyncGoTrueAPI): try: response = api.generate_link( - type=LinkType.signup, + type=GenerateLinkType.signup, email=email2, password=password2, redirect_to="http://localhost:9999/welcome", @@ -64,7 +64,7 @@ def test_generate_sign_up_link(api: SyncGoTrueAPI): def test_generate_magic_link(api: SyncGoTrueAPI): try: response = api.generate_link( - type=LinkType.magiclink, + type=GenerateLinkType.magiclink, email=email3, redirect_to="http://localhost:9999/welcome", ) @@ -76,7 +76,7 @@ def test_generate_magic_link(api: SyncGoTrueAPI): def test_generate_invite_link(api: SyncGoTrueAPI): try: response = api.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=email3, redirect_to="http://localhost:9999/welcome", ) @@ -89,7 +89,7 @@ def test_generate_invite_link(api: SyncGoTrueAPI): def test_generate_recovery_link(api: SyncGoTrueAPI): try: response = api.generate_link( - type=LinkType.recovery, + type=GenerateLinkType.recovery, email=email, redirect_to="http://localhost:9999/welcome", ) diff --git a/tests/_sync/test_api_with_auto_confirm_enabled.py b/tests/_sync/test_api_with_auto_confirm_enabled.py index 578646a8..67edaa81 100644 --- a/tests/_sync/test_api_with_auto_confirm_enabled.py +++ b/tests/_sync/test_api_with_auto_confirm_enabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI -from gotrue.constants import COOKIE_OPTIONS -from gotrue.types import CookieOptions, Session, User +from ...gotrue import SyncGoTrueAPI +from ...gotrue.constants import COOKIE_OPTIONS +from ...gotrue.types import CookieOptions, Session, User GOTRUE_URL = "http://localhost:9998" TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoic2VydmljZV9yb2xlIiwiaWF0IjoxNjQyMjMyNzUwfQ.TUR8Zu05TtNR25L42soA2trZpc4oBR8-9Pv5r5bvls8" # noqa: E501 diff --git a/tests/_sync/test_client_with_auto_confirm_disabled.py b/tests/_sync/test_client_with_auto_confirm_disabled.py index b9a0a4d0..587b5d87 100644 --- a/tests/_sync/test_client_with_auto_confirm_disabled.py +++ b/tests/_sync/test_client_with_auto_confirm_disabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueClient -from gotrue.exceptions import APIError -from gotrue.types import User +from ...gotrue import SyncGoTrueClient +from ...gotrue.errors import APIError +from ...gotrue.types import User GOTRUE_URL = "http://localhost:9999" TEST_TWILIO = False diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index e3a96f81..72407be2 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -3,9 +3,9 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueClient -from gotrue.exceptions import APIError -from gotrue.types import Session, User, UserAttributes +from ...gotrue import SyncGoTrueClient +from ...gotrue.errors import APIError +from ...gotrue.types import Session, User, UserAttributes GOTRUE_URL = "http://localhost:9998" TEST_TWILIO = False diff --git a/tests/_sync/test_client_with_sign_ups_disabled.py b/tests/_sync/test_client_with_sign_ups_disabled.py index 0e0cb9c8..b5b59302 100644 --- a/tests/_sync/test_client_with_sign_ups_disabled.py +++ b/tests/_sync/test_client_with_sign_ups_disabled.py @@ -3,10 +3,10 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI, SyncGoTrueClient -from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS -from gotrue.exceptions import APIError -from gotrue.types import CookieOptions, LinkType, User, UserAttributes +from ...gotrue import SyncGoTrueAPI, SyncGoTrueClient +from ...gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS +from ...gotrue.errors import APIError +from ...gotrue.types import CookieOptions, GenerateLinkType, User, UserAttributes GOTRUE_URL = "http://localhost:9997" AUTH_ADMIN_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwicm9sZSI6InN1cGFiYXNlX2FkbWluIiwiaWF0IjoxNTE2MjM5MDIyfQ.0sOtTSTfPv5oPZxsjvBO249FI4S4p0ymHoIZ6H6z9Y8" # noqa: E501 @@ -57,7 +57,7 @@ def test_generate_link_should_be_able_to_generate_multiple_links( ): try: response = auth_admin.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=invited_user, redirect_to="http://localhost:9997", ) @@ -77,7 +77,7 @@ def test_generate_link_should_be_able_to_generate_multiple_links( assert response.identities == [] user = response response = auth_admin.generate_link( - type=LinkType.invite, + type=GenerateLinkType.invite, email=invited_user, ) assert isinstance(response, User) diff --git a/tests/_sync/test_provider.py b/tests/_sync/test_provider.py index bb6dcc38..cc233bf2 100644 --- a/tests/_sync/test_provider.py +++ b/tests/_sync/test_provider.py @@ -2,8 +2,8 @@ import pytest -from gotrue import SyncGoTrueClient -from gotrue.types import Provider +from ...gotrue import SyncGoTrueClient +from ...gotrue.types import Provider GOTRUE_URL = "http://localhost:9999" diff --git a/tests/_sync/test_subscriptions.py b/tests/_sync/test_subscriptions.py index c9df5808..3c0f7238 100644 --- a/tests/_sync/test_subscriptions.py +++ b/tests/_sync/test_subscriptions.py @@ -1,7 +1,7 @@ import pytest -from gotrue import SyncGoTrueClient -from gotrue.types import Subscription +from ...gotrue import SyncGoTrueClient +from ...gotrue.types import Subscription GOTRUE_URL = "http://localhost:9999"