diff --git a/.env.example b/.env.example index 3c38ca77..b9505de9 100644 --- a/.env.example +++ b/.env.example @@ -16,7 +16,7 @@ DO_SPACES_SECRET=generateme JWT_SECRET_KEY=generateme API_HOST=http://fastapi:8000 -API_HOST_EXTERNAL=http://localhost:8000 +API_HOST_EXTERNAL=http://127.0.0.1:8000 API_ADMIN_USER=admin API_ADMIN_PASS=root API_SITE_ID=nmdc-runtime @@ -36,4 +36,5 @@ NMDC_PORTAL_API_BASE_URL=https://data-dev.microbiomedata.org/ NEON_API_TOKEN=y NEON_API_BASE_URL=https://data.neonscience.org/api/v0 -NERSC_USERNAME=replaceme \ No newline at end of file +NERSC_USERNAME=replaceme +ORCID_CLIENT_ID=replaceme \ No newline at end of file diff --git a/nmdc_runtime/api/core/auth.py b/nmdc_runtime/api/core/auth.py index 530727f1..5e4d7c1c 100644 --- a/nmdc_runtime/api/core/auth.py +++ b/nmdc_runtime/api/core/auth.py @@ -16,6 +16,18 @@ SECRET_KEY = os.getenv("JWT_SECRET_KEY") ALGORITHM = "HS256" +ORCID_CLIENT_ID = os.getenv("ORCID_CLIENT_ID") + +# https://orcid.org/.well-known/openid-configuration +# XXX do we want to live-load this? +ORCID_JWK = { # https://orcid.org/oauth/jwks + "e": "AQAB", + "kid": "production-orcid-org-7hdmdswarosg3gjujo8agwtazgkp1ojs", + "kty": "RSA", + "n": "jxTIntA7YvdfnYkLSN4wk__E2zf_wbb0SV_HLHFvh6a9ENVRD1_rHK0EijlBzikb-1rgDQihJETcgBLsMoZVQqGj8fDUUuxnVHsuGav_bf41PA7E_58HXKPrB2C0cON41f7K3o9TStKpVJOSXBrRWURmNQ64qnSSryn1nCxMzXpaw7VUo409ohybbvN6ngxVy4QR2NCC7Fr0QVdtapxD7zdlwx6lEwGemuqs_oG5oDtrRuRgeOHmRps2R6gG5oc-JqVMrVRv6F9h4ja3UgxCDBQjOVT1BFPWmMHnHCsVYLqbbXkZUfvP2sO1dJiYd_zrQhi-FtNth9qrLLv3gkgtwQ", + "use": "sig", +} +ORCID_JWS_VERITY_ALGORITHM = "RS256" class ClientCredentials(BaseModel): @@ -105,11 +117,14 @@ async def __call__(self, request: Request) -> Optional[str]: headers={"WWW-Authenticate": "Bearer"}, ) else: + print(request.url) return None return param -oauth2_scheme = OAuth2PasswordOrClientCredentialsBearer(tokenUrl="token") +oauth2_scheme = OAuth2PasswordOrClientCredentialsBearer( + tokenUrl="token", auto_error=False +) optional_oauth2_scheme = OAuth2PasswordOrClientCredentialsBearer( tokenUrl="token", auto_error=False ) diff --git a/nmdc_runtime/api/endpoints/users.py b/nmdc_runtime/api/endpoints/users.py index 5799ca3c..4f79e752 100644 --- a/nmdc_runtime/api/endpoints/users.py +++ b/nmdc_runtime/api/endpoints/users.py @@ -1,18 +1,28 @@ +import json from datetime import timedelta import pymongo.database from fastapi import Depends, APIRouter, HTTPException, status +from jose import jws, JWTError +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse from nmdc_runtime.api.core.auth import ( OAuth2PasswordOrClientCredentialsRequestForm, Token, ACCESS_TOKEN_EXPIRES, create_access_token, + ORCID_CLIENT_ID, + ORCID_JWK, + ORCID_JWS_VERITY_ALGORITHM, + credentials_exception, ) from nmdc_runtime.api.core.auth import get_password_hash +from nmdc_runtime.api.core.util import generate_secret from nmdc_runtime.api.db.mongo import get_mongo_db +from nmdc_runtime.api.endpoints.util import BASE_URL_EXTERNAL from nmdc_runtime.api.models.site import authenticate_site_client -from nmdc_runtime.api.models.user import UserInDB, UserIn +from nmdc_runtime.api.models.user import UserInDB, UserIn, get_user from nmdc_runtime.api.models.user import ( authenticate_user, User, @@ -22,6 +32,45 @@ router = APIRouter() +@router.get("/orcid_authorize") +async def orcid_authorize(): + """NOTE: You want to load /orcid_authorize directly in your web browser to initiate the login redirect flow.""" + return RedirectResponse( + f"https://orcid.org/oauth/authorize?client_id={ORCID_CLIENT_ID}" + "&response_type=token&scope=openid&" + f"redirect_uri={BASE_URL_EXTERNAL}/orcid_token" + ) + + +@router.get("/orcid_token") +async def redirect_uri_for_orcid_token(req: Request): + """ + Returns a web page that will display a user's orcid jwt token for copy/paste. + + This route is loaded by orcid.org after a successful orcid user login. + """ + return HTMLResponse( + """ +
+ + + + + + + """ + ) + + @router.post("/token", response_model=Token) async def login_for_access_token( form_data: OAuth2PasswordOrClientCredentialsRequestForm = Depends(), @@ -40,21 +89,55 @@ async def login_for_access_token( data={"sub": f"user:{user.username}"}, expires_delta=access_token_expires ) else: # form_data.grant_type == "client_credentials" - site = authenticate_site_client( - mdb, form_data.client_id, form_data.client_secret - ) - if not site: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect client_id or client_secret", - headers={"WWW-Authenticate": "Bearer"}, + # If the HTTP request didn't include a Client Secret, we validate the Client ID as an ORCID JWT. + # We get a username from that ORCID JWT and fetch the corresponding user record from our database, + # creating that user record if it doesn't already exist. + if not form_data.client_secret: + try: + payload = jws.verify( + form_data.client_id, + ORCID_JWK, + algorithms=[ORCID_JWS_VERITY_ALGORITHM], + ) + payload = json.loads(payload.decode()) + issuer: str = payload.get("iss") + if issuer != "https://orcid.org": + raise credentials_exception + subject: str = payload.get("sub") + user = get_user(mdb, subject) + if user is None: + mdb.users.insert_one( + UserInDB( + username=subject, + hashed_password=get_password_hash(generate_secret()), + ).model_dump(exclude_unset=True) + ) + user = get_user(mdb, subject) + assert user is not None, "failed to create orcid user" + access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump()) + access_token = create_access_token( + data={"sub": f"user:{user.username}"}, + expires_delta=access_token_expires, + ) + + except JWTError: + raise credentials_exception + else: # form_data.client_secret + site = authenticate_site_client( + mdb, form_data.client_id, form_data.client_secret + ) + if not site: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect client_id or client_secret", + headers={"WWW-Authenticate": "Bearer"}, + ) + # TODO make below an absolute time + access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump()) + access_token = create_access_token( + data={"sub": f"client:{form_data.client_id}"}, + expires_delta=access_token_expires, ) - # TODO make below an absolute time - access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump()) - access_token = create_access_token( - data={"sub": f"client:{form_data.client_id}"}, - expires_delta=access_token_expires, - ) return { "access_token": access_token, "token_type": "bearer", diff --git a/nmdc_runtime/api/models/user.py b/nmdc_runtime/api/models/user.py index f5323b6d..0a96e2eb 100644 --- a/nmdc_runtime/api/models/user.py +++ b/nmdc_runtime/api/models/user.py @@ -62,7 +62,8 @@ async def get_current_user( raise credentials_exception username = subject.split("user:", 1)[1] token_data = TokenData(subject=username) - except JWTError: + except JWTError as e: + print(f"jwt error: {e}") raise credentials_exception user = get_user(mdb, username=token_data.subject) if user is None: