Skip to content

Commit

Permalink
✨ Invitations to register to product (#4739)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov authored Nov 16, 2023
1 parent 163d47f commit 1a4b614
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
ApiInvitationContentAndLink,
ApiInvitationInputs,
)
from models_library.users import GroupID
from pydantic import AnyHttpUrl, ValidationError, parse_obj_as
from servicelib.error_codes import create_error_code
from simcore_postgres_database.models.groups import user_to_groups
from simcore_postgres_database.models.users import users

from ..db.plugin import get_database_engine
from ..products.api import Product
from ._client import InvitationsServiceApi, get_invitations_service_api
from .errors import (
MSG_INVALID_INVITATION_URL,
Expand All @@ -26,14 +29,31 @@
_logger = logging.getLogger(__name__)


async def _is_user_registered(app: web.Application, email: str) -> bool:
async def _is_user_registered_in_platform(app: web.Application, email: str) -> bool:
pg_engine = get_database_engine(app=app)

async with pg_engine.acquire() as conn:
user_id = await conn.scalar(sa.select(users.c.id).where(users.c.email == email))
return user_id is not None


async def _is_user_registered_in_product(
app: web.Application, email: str, product_group_id: GroupID
) -> bool:
pg_engine = get_database_engine(app=app)

async with pg_engine.acquire() as conn:
user_id = await conn.scalar(
sa.select(users.c.id)
.select_from(
sa.join(user_to_groups, users, user_to_groups.c.uid == users.c.id)
)
.where(
(users.c.email == email) & (user_to_groups.c.gid == product_group_id)
)
)
return user_id is not None


@contextmanager
def _handle_exceptions_as_invitations_errors():
try:
Expand All @@ -56,6 +76,7 @@ def _handle_exceptions_as_invitations_errors():
raise InvitationsServiceUnavailable from err

except (ValidationError, ClientError) as err:
_logger.debug("Invitations error %s", f"{err}")
raise InvitationsServiceUnavailable from err

except InvitationsErrors:
Expand All @@ -80,12 +101,21 @@ def is_service_invitation_code(code: str):


async def validate_invitation_url(
app: web.Application, guest_email: str, invitation_url: str
app: web.Application,
*,
current_product: Product,
guest_email: str,
invitation_url: str,
) -> ApiInvitationContent:
"""Validates invitation and associated email/user and returns content upon success
raises InvitationsError
"""
if current_product.group_id is None:
raise InvitationsServiceUnavailable(
reason="Current product is not configured for invitations"
)

invitations_service: InvitationsServiceApi = get_invitations_service_api(app=app)

with _handle_exceptions_as_invitations_errors():
Expand All @@ -99,13 +129,29 @@ async def validate_invitation_url(
invitation_url=valid_url
)

# check email
if invitation.guest != guest_email:
raise InvalidInvitation(
reason="This invitation was issued for a different email"
)

# existing users cannot be re-invited
if await _is_user_registered(app=app, email=invitation.guest):
# check product
assert current_product.group_id is not None # nosec
if (
invitation.product is not None
and invitation.product != current_product.name
):
raise InvalidInvitation(
reason="This invitation was issued for a different product. "
f"Got '{invitation.product}', expected '{current_product.name}'"
)

# check invitation used
assert invitation.product == current_product.name # nosec
if await _is_user_registered_in_product(
app=app, email=invitation.guest, product_group_id=current_product.group_id
):
# NOTE: a user might be already registered but the invitation is for another product
raise InvalidInvitation(reason=MSG_INVITATION_ALREADY_USED)

return invitation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .._constants import APP_SETTINGS_KEY
from ..db.plugin import setup_db
from ..products.plugin import setup_products
from ._client import invitations_service_api_cleanup_ctx

_logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from datetime import datetime

from aiohttp import web
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON

from ..products.api import Product
from ..security.api import check_password, encrypt_password
from ._constants import MSG_UNKNOWN_EMAIL, MSG_WRONG_PASSWORD
from .storage import AsyncpgStorage, get_plugin_storage
from .utils import USER, get_user_name_from_email, validate_user_status


async def get_user_by_email(app: web.Application, *, email: str) -> dict:
db: AsyncpgStorage = get_plugin_storage(app)
user: dict = await db.get_user({"email": email})
return user


async def create_user(
app: web.Application,
*,
email: str,
password: str,
status: str,
expires_at: datetime | None
) -> dict:
db: AsyncpgStorage = get_plugin_storage(app)

user: dict = await db.create_user(
{
"name": get_user_name_from_email(email),
"email": email,
"password_hash": encrypt_password(password),
"status": status,
"role": USER,
"expires_at": expires_at,
}
)
return user


async def check_authorized_user_or_raise(
user: dict,
password: str,
product: Product,
) -> dict:

if not user:
raise web.HTTPUnauthorized(
reason=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON
)

validate_user_status(user=user, support_email=product.support_email)

if not check_password(password, user["password_hash"]):
raise web.HTTPUnauthorized(
reason=MSG_WRONG_PASSWORD, content_type=MIMETYPE_APPLICATION_JSON
)

return user
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiohttp import web
from models_library.basic_types import IdInt
from models_library.emails import LowerCaseEmailStr
from models_library.products import ProductName
from pydantic import BaseModel, Field, Json, PositiveInt, ValidationError, validator
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from simcore_postgres_database.models.confirmations import ConfirmationAction
Expand All @@ -23,6 +24,7 @@
validate_invitation_url,
)
from ..invitations.errors import InvalidInvitation, InvitationsServiceUnavailable
from ..products.api import Product
from ._confirmation import is_confirmation_expired, validate_confirmation_code
from ._constants import MSG_EMAIL_EXISTS, MSG_INVITATIONS_CONTACT_SUFFIX
from .settings import LoginOptions
Expand Down Expand Up @@ -51,6 +53,7 @@ class InvitationData(BaseModel):
"Sets the number of days from creation until the account expires",
)
extra_credits_in_usd: PositiveInt | None = None
product: ProductName | None = None


class _InvitationValidator(BaseModel):
Expand Down Expand Up @@ -190,6 +193,7 @@ async def extract_email_from_invitation(
async def check_and_consume_invitation(
invitation_code: str,
guest_email: str,
product: Product,
db: AsyncpgStorage,
cfg: LoginOptions,
app: web.Application,
Expand All @@ -207,6 +211,7 @@ async def check_and_consume_invitation(
with _invitations_request_context(invitation_code=invitation_code) as url:
content = await validate_invitation_url(
app,
current_product=product,
guest_email=guest_email,
invitation_url=f"{url}",
)
Expand All @@ -219,6 +224,7 @@ async def check_and_consume_invitation(
guest=content.guest,
trial_account_days=content.trial_account_days,
extra_credits_in_usd=content.extra_credits_in_usd,
product=content.product,
)

# database-type invitations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .._meta import API_VTAG
from ..products.api import Product, get_current_product
from ..security.api import check_password, forget
from ..security.api import forget
from ..session.access_policies import (
on_success_grant_session_access_to,
session_access_required,
Expand All @@ -27,6 +27,7 @@
mask_phone_number,
send_sms_code,
)
from ._auth_api import check_authorized_user_or_raise, get_user_by_email
from ._constants import (
CODE_2FA_CODE_REQUIRED,
CODE_PHONE_NUMBER_REQUIRED,
Expand All @@ -37,22 +38,14 @@
MSG_LOGGED_OUT,
MSG_PHONE_MISSING,
MSG_UNAUTHORIZED_LOGIN_2FA,
MSG_UNKNOWN_EMAIL,
MSG_WRONG_2FA_CODE,
MSG_WRONG_PASSWORD,
)
from ._models import InputSchema
from ._security import login_granted_response
from .decorators import login_required
from .settings import LoginSettingsForProduct, get_plugin_settings
from .storage import AsyncpgStorage, get_plugin_storage
from .utils import (
ACTIVE,
envelope_response,
flash_response,
notify_user_logout,
validate_user_status,
)
from .utils import envelope_response, flash_response, notify_user_logout

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,25 +91,13 @@ async def login(request: web.Request):
settings: LoginSettingsForProduct = get_plugin_settings(
request.app, product_name=product.name
)
db: AsyncpgStorage = get_plugin_storage(request.app)

login_ = await parse_request_body_as(LoginBody, request)

user = await db.get_user({"email": login_.email})
if not user:
raise web.HTTPUnauthorized(
reason=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON
)

validate_user_status(user=user, support_email=product.support_email)

if not check_password(login_.password.get_secret_value(), user["password_hash"]):
raise web.HTTPUnauthorized(
reason=MSG_WRONG_PASSWORD, content_type=MIMETYPE_APPLICATION_JSON
)

assert user["status"] == ACTIVE, "db corrupted. Invalid status" # nosec
assert user["email"] == login_.email, "db corrupted. Invalid email" # nosec
user = await check_authorized_user_or_raise(
user=await get_user_by_email(request.app, email=login_.email),
password=login_.password.get_secret_value(),
product=product,
)

# Some roles have login privileges
has_privileges: Final[bool] = UserRole(user["role"]) > UserRole.USER
Expand Down
Loading

0 comments on commit 1a4b614

Please sign in to comment.