Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT - Needs more feedback] Add support for user specified token lifetimes #829

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions nmdc_runtime/api/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ class TokenExpires(BaseModel):
days: Optional[int] = 1
hours: Optional[int] = 0
minutes: Optional[int] = 0
seconds: Optional[int] = 0


ACCESS_TOKEN_EXPIRES = TokenExpires(days=1, hours=0, minutes=0)
ACCESS_TOKEN_EXPIRES = TokenExpires(days=1, hours=0, minutes=0, seconds=0)


class Token(BaseModel):
Expand Down Expand Up @@ -174,14 +175,31 @@ def __init__(
bearer_creds: Optional[HTTPAuthorizationCredentials] = Depends(
bearer_credentials
),
grant_type: str = Form(None, pattern="^password$|^client_credentials$"),
username: Optional[str] = Form(None),
password: Optional[str] = Form(None),
grant_type: str = Form(
None,
pattern="^password$|^client_credentials$",
description="Select type of login credentials - either `password` or `client_credentials`",
),
username: Optional[str] = Form(
None, description="Username for grant_type `password`"
),
password: Optional[str] = Form(
None, description="Password for grant_type `password`"
),
scope: str = Form(""),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
client_id: Optional[str] = Form(
None, description="Client ID for grant_type `client_credentials`"
),
client_secret: Optional[str] = Form(
None, description="Client secret for grant_type `client_credentials`"
),
expires: Optional[str] = Form(
None, description="Seconds until token expires (Default 1 day)"
),
):
if bearer_creds:
# TODO: This is never being used since it gets overwritten later on.
# Should we remove this?
self.grant_type = "client_credentials"
self.username, self.password = None, None
self.scopes = scope.split()
Expand All @@ -207,3 +225,4 @@ def __init__(
self.scopes = scope.split()
self.client_id = client_id
self.client_secret = client_secret
self.expires = expires
30 changes: 26 additions & 4 deletions nmdc_runtime/api/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from nmdc_runtime.api.core.auth import (
OAuth2PasswordOrClientCredentialsRequestForm,
Token,
TokenExpires,
ACCESS_TOKEN_EXPIRES,
create_access_token,
ORCID_NMDC_CLIENT_ID,
Expand Down Expand Up @@ -74,6 +75,21 @@ async def login_for_access_token(
form_data: OAuth2PasswordOrClientCredentialsRequestForm = Depends(),
mdb: pymongo.database.Database = Depends(get_mongo_db),
):

if form_data.expires:
expires = int(form_data.expires)
if timedelta(**ACCESS_TOKEN_EXPIRES.model_dump()) - timedelta(
seconds=expires
) < timedelta(seconds=0):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="expires must be less than 86400",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(seconds=expires)
else:
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())

if form_data.grant_type == "password":
user = authenticate_user(mdb, form_data.username, form_data.password)
if not user:
Expand All @@ -82,7 +98,6 @@ async def login_for_access_token(
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())
access_token = create_access_token(
data={"sub": f"user:{user.username}"}, expires_delta=access_token_expires
)
Expand Down Expand Up @@ -112,7 +127,6 @@ async def login_for_access_token(
)
user = get_user(mdb, subject)
assert user is not None, "failed to create orcid user"
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())
access_token = create_access_token(
data={"sub": f"user:{user.username}"},
expires_delta=access_token_expires,
Expand All @@ -131,15 +145,23 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"},
)
# TODO make below an absolute time
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())
access_token = create_access_token(
data={"sub": f"client:{form_data.client_id}"},
expires_delta=access_token_expires,
)
days, remainder = divmod(access_token_expires.total_seconds(), 86400)
hours, remainder = divmod(remainder, 3600)
minutes, seconds = divmod(remainder, 60)
token_expires = TokenExpires(
days=int(days),
hours=int(hours),
minutes=int(minutes),
seconds=int(seconds),
)
return {
"access_token": access_token,
"token_type": "bearer",
"expires": ACCESS_TOKEN_EXPIRES.model_dump(),
"expires": token_expires.model_dump(),
}


Expand Down
55 changes: 55 additions & 0 deletions tests/test_api/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette import status
from tenacity import wait_random_exponential, stop_after_attempt, retry
from toolz import get_in
from datetime import datetime, timezone

from nmdc_runtime.api.core.auth import get_password_hash
from nmdc_runtime.api.core.metadata import df_from_sheet_in, _validate_changesheet
Expand All @@ -27,6 +28,7 @@
RuntimeApiUserClient,
)
from nmdc_runtime.util import REPO_ROOT_DIR, ensure_unique_id_indexes
from jose import jwt


def ensure_schema_collections_and_alldocs():
Expand Down Expand Up @@ -139,6 +141,59 @@ def test_update_operation():
JobOperationMetadata
)

def test_token():
mdb = get_mongo(run_config_frozen__normal_env).db
rs = ensure_test_resources(mdb)
base_url = os.getenv("API_HOST")

@retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
def get_token_response(expires=None):
"""
Fetch an auth token from the Runtime API.
"""
data={
"grant_type": "password",
"username": rs["user"]["username"],
"password": rs["user"]["password"],
}
if expires:
data["expires"] = expires
_rv = requests.post(
base_url + "/token",
data=data,
)
return _rv

# Test default expiration
token = get_token_response().json()
assert token["token_type"] == "bearer"
assert "access_token" in token
assert "expires" in token
assert token["expires"] == {
"days": 1,
"hours": 0,
"minutes": 0,
"seconds": 0
}

# Test custom expiration
token = get_token_response(7382) .json()
assert token["expires"] == {
"days": 0,
"hours": 2,
"minutes": 3,
"seconds": 2
}
access_token = token["access_token"]
# Decode the JWT access token
decoded_token = jwt.get_unverified_claims(access_token)
delta = (datetime.fromtimestamp(decoded_token['exp'], timezone.utc) - datetime.now(timezone.utc)).total_seconds()
# give it margin for error since the operation could take a few seconds
assert delta > 7200 and delta <= 7382

# Test expiration over 24 hours
response = get_token_response(86401)
assert response.status_code == 401

def test_create_user():
mdb = get_mongo(run_config_frozen__normal_env).db
Expand Down