Skip to content

Commit

Permalink
[Issue #3129] Validate that the nonce matches in login gov token resp…
Browse files Browse the repository at this point in the history
…onse (#3211)

## Summary
Fixes #3129

### Time to review: __10 mins__

## Changes proposed
Validate the nonce value when receiving the token back from login.gov

Delete the login gov state value after we've used it once

Some restructuring to handle the multi-commit approach required here

## Context for reviewers
We delete the state object after processing a response that uses it to
prevent any sort of replay attack (ie. that state should be used for a
single session of logging in, it shouldn't ever be used again). If a
user navigates to the login page again later, we'd make a new state
anyways, so the data isn't needed anyways.

As for the nonce, this is a UUID we generated when we first redirected
someone to login.gov - login.gov doesn't echo it back to us until we are
parsing the response from the token endpoint. In short, by checking for
this, we protect ourselves from replay attacks.

## Additional information
See: https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes
for more details on the nonce
https://en.wikipedia.org/wiki/Replay_attack
  • Loading branch information
chouinar authored Dec 16, 2024
1 parent eb45d02 commit c571106
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 29 deletions.
12 changes: 10 additions & 2 deletions api/src/api/users/user_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from src.auth.login_gov_jwt_auth import get_final_redirect_uri, get_login_gov_redirect_uri
from src.db.models.user_models import UserTokenSession
from src.services.users.get_user import get_user
from src.services.users.login_gov_callback_handler import handle_login_gov_callback
from src.services.users.login_gov_callback_handler import (
handle_login_gov_callback_request,
handle_login_gov_token,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,8 +56,13 @@ def user_login(db_session: db.Session) -> flask.Response:
def user_login_callback(db_session: db.Session, query_data: dict) -> flask.Response:
logger.info("GET /v1/users/login/callback")

# We process this in two separate DB transactions
# as we delete state at the end of the first handler
# even if it were to later error to avoid replay attacks
with db_session.begin():
data = handle_login_gov_callback_request(query_data, db_session)
with db_session.begin():
result = handle_login_gov_callback(query_data, db_session)
result = handle_login_gov_token(db_session, data)

# Redirect to the final location for the user
return response.redirect_response(
Expand Down
12 changes: 8 additions & 4 deletions api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_final_redirect_uri(
return f"{config.login_final_destination}?{encoded_params}"


def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGovUser:
def validate_token(token: str, nonce: str, config: LoginGovConfig | None = None) -> LoginGovUser:
if not config:
config = get_config()

Expand All @@ -205,22 +205,22 @@ def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGov
# Iterate over the public keys we have and check each
# to determine if we have a valid key.
for public_key in config.public_keys:
user = _validate_token_with_key(token, public_key, config)
user = _validate_token_with_key(token, nonce, public_key, config)
if user is not None:
return user

_refresh_keys(config)

for public_key in config.public_keys:
user = _validate_token_with_key(token, public_key, config)
user = _validate_token_with_key(token, nonce, public_key, config)
if user is not None:
return user

raise JwtValidationError("Token could not be validated against any public keys from login.gov")


def _validate_token_with_key(
token: str, public_key: jwt.PyJWK | str, config: LoginGovConfig
token: str, nonce: str, public_key: jwt.PyJWK | str, config: LoginGovConfig
) -> LoginGovUser | None:
# We are processing the id_token as described on:
# https://developers.login.gov/oidc/token/#token-response
Expand All @@ -244,6 +244,10 @@ def _validate_token_with_key(
)
payload = data.get("payload", {})

payload_nonce = payload.get("nonce", None)
if payload_nonce != nonce:
raise JwtValidationError("Nonce does not match expected")

user_id = payload["sub"]
email = payload["email"]

Expand Down
40 changes: 32 additions & 8 deletions api/src/services/users/login_gov_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ class CallbackParams(BaseModel):
error_description: str | None = None


@dataclass
class LoginGovDataContainer:
"""Holds various login gov related fields we want to pass around"""

code: str
nonce: str


@dataclass
class LoginGovCallbackResponse:
token: str
Expand All @@ -37,13 +45,14 @@ def get_login_gov_client() -> LoginGovOauthClient:
return LoginGovOauthClient()


def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse:
def handle_login_gov_callback_request(
query_data: dict, db_session: db.Session
) -> LoginGovDataContainer:
"""Handle the callback from login.gov after calling the authenticate endpoint
NOTE: Any errors thrown here will actually lead to a redirect due to the
with_login_redirect_error_handler handler we have attached to the route
"""

# Process the data coming back from login.gov via the redirect query params
# see: https://developers.login.gov/oidc/authorization/#authorization-response
callback_params = CallbackParams.model_validate(query_data)
Expand Down Expand Up @@ -75,29 +84,44 @@ def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> Login
if login_gov_state is None:
raise_flask_error(404, "OAuth state not found")

# We do not want the login_gov_state to be reusable - so delete it
# even if we later error to avoid any replay attacks.
db_session.delete(login_gov_state)

return LoginGovDataContainer(code=callback_params.code, nonce=str(login_gov_state.nonce))


def handle_login_gov_token(
db_session: db.Session, login_gov_data: LoginGovDataContainer
) -> LoginGovCallbackResponse:
"""Fetch user info from login gov, and handle user creation
NOTE: Any errors thrown here will actually lead to a redirect due to the
with_login_redirect_error_handler handler we have attached to the route
"""

# call the token endpoint (make a client)
# https://developers.login.gov/oidc/token/
client = get_login_gov_client()
response = client.get_token(
OauthTokenRequest(
code=callback_params.code, client_assertion=get_login_gov_client_assertion()
code=login_gov_data.code, client_assertion=get_login_gov_client_assertion()
)
)

# If this request failed, we'll assume we're the issue and 500
# TODO - need to test with actual login.gov if there could be other scenarios
# the mock always returns something as long as the request is well-formatted
if response.is_error_response():
raise_flask_error(500, response.error_description)

# Process the token response from login.gov
return _process_token(db_session, response.id_token)
# which will create/update a user in the DB
return _process_token(db_session, response.id_token, login_gov_data.nonce)


def _process_token(db_session: db.Session, token: str) -> LoginGovCallbackResponse:
def _process_token(db_session: db.Session, token: str, nonce: str) -> LoginGovCallbackResponse:
"""Process the token from login.gov and generate our own token for auth"""
try:
login_gov_user = validate_token(token)
login_gov_user = validate_token(token, nonce)
except JwtValidationError as e:
logger.info("Login.gov token validation failed", extra={"auth.issue": e.message})
raise_flask_error(401, e.message)
Expand Down
64 changes: 59 additions & 5 deletions api/tests/src/api/users/test_user_route_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from src.adapters.oauth.oauth_client_models import OauthTokenResponse
from src.api.route_utils import raise_flask_error
from src.auth.api_jwt_auth import parse_jwt_for_user
from src.db.models.user_models import LinkExternalUser
from src.db.models.user_models import LinkExternalUser, LoginGovState
from src.util import datetime_util
from tests.lib.auth_test_utils import create_jwt
from tests.src.db.models.factories import LinkExternalUserFactory, LoginGovStateFactory
Expand Down Expand Up @@ -232,6 +232,7 @@ def test_user_callback_new_user_302(
code = str(uuid.uuid4())
id_token = create_jwt(
user_id="bob-xyz",
nonce=str(login_gov_state.nonce),
private_key=private_rsa_key,
)
mock_oauth_client.add_token_response(
Expand Down Expand Up @@ -267,6 +268,14 @@ def test_user_callback_new_user_302(
)
assert external_user is not None

# Make sure the login gov state was deleted
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None


def test_user_callback_existing_user_302(
client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key
Expand All @@ -282,6 +291,7 @@ def test_user_callback_existing_user_302(
code = str(uuid.uuid4())
id_token = create_jwt(
user_id=login_gov_id,
nonce=str(login_gov_state.nonce),
private_key=private_rsa_key,
)
mock_oauth_client.add_token_response(
Expand All @@ -307,6 +317,14 @@ def test_user_callback_existing_user_302(
assert user_token_session.is_valid is True
assert user_token_session.user_id == external_user.user_id

# Make sure the login gov state was deleted
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None


def test_user_callback_unknown_state_302(client, monkeypatch):
"""Test behavior when we get a redirect back from login.gov with an unknown state value"""
Expand Down Expand Up @@ -374,13 +392,24 @@ def test_user_callback_error_in_token_302(client, enable_factory_create, caplog)
],
)
def test_user_callback_token_fails_validation_302(
client, enable_factory_create, mock_oauth_client, private_rsa_key, jwt_params, error_description
client,
db_session,
enable_factory_create,
mock_oauth_client,
private_rsa_key,
jwt_params,
error_description,
):
# Create state so the callback gets past the check
login_gov_state = LoginGovStateFactory.create()

code = str(uuid.uuid4())
id_token = create_jwt(user_id=str(uuid.uuid4()), private_key=private_rsa_key, **jwt_params)
id_token = create_jwt(
user_id=str(uuid.uuid4()),
nonce=str(login_gov_state.nonce),
private_key=private_rsa_key,
**jwt_params,
)
mock_oauth_client.add_token_response(
code,
OauthTokenResponse(
Expand All @@ -398,9 +427,17 @@ def test_user_callback_token_fails_validation_302(
assert resp_json["message"] == "error"
assert resp_json["error_description"] == error_description

# Make sure the login gov state was deleted even though it errored
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None


def test_user_callback_token_fails_validation_bad_token_302(
client, enable_factory_create, mock_oauth_client, private_rsa_key
client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key
):
# Create state so the callback gets past the check
login_gov_state = LoginGovStateFactory.create()
Expand All @@ -424,9 +461,17 @@ def test_user_callback_token_fails_validation_bad_token_302(
assert resp_json["message"] == "error"
assert resp_json["error_description"] == "Unable to process token"

# Make sure the login gov state was deleted even though it errored
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None


def test_user_callback_token_fails_validation_no_valid_key_302(
client, enable_factory_create, mock_oauth_client, other_rsa_key_pair
client, db_session, enable_factory_create, mock_oauth_client, other_rsa_key_pair
):
"""Create the token with a different key than we check against"""
# Create state so the callback gets past the check
Expand All @@ -435,6 +480,7 @@ def test_user_callback_token_fails_validation_no_valid_key_302(
code = str(uuid.uuid4())
id_token = create_jwt(
user_id=str(uuid.uuid4()),
nonce=str(login_gov_state.nonce),
private_key=other_rsa_key_pair[0],
)
mock_oauth_client.add_token_response(
Expand All @@ -456,3 +502,11 @@ def test_user_callback_token_fails_validation_no_valid_key_302(
resp_json["error_description"]
== "Token could not be validated against any public keys from login.gov"
)

# Make sure the login gov state was deleted even though it errored
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None
Loading

0 comments on commit c571106

Please sign in to comment.