From c571106a2fb793e11fcc9567708795293bb0f408 Mon Sep 17 00:00:00 2001 From: Michael Chouinard <46358556+chouinar@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:58:53 -0500 Subject: [PATCH] [Issue #3129] Validate that the nonce matches in login gov token response (#3211) ## Summary Fixes #3129 ### Time to review: __10 mins__ ## Changes proposed Validate the nonce value when receiving the token back from login.gov. Delete the login gov state value after we've used it once. Some restructuring to handle the multi-commit approach (two separate DB transactions) required here. ## Context for reviewers We delete the state object after processing a response that uses it to prevent any sort of replay attack (i.e., that state should be used for a single login session; it should never be used again). If a user navigates to the login page again later, we'd create a new state anyway, so the old data isn't needed. As for the nonce, this is a UUID we generated when we first redirected someone to login.gov - login.gov doesn't echo it back to us until we are parsing the response from the token endpoint. In short, by checking for this, we protect ourselves from replay attacks. 
## Additional information See: https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes for more details on the nonce https://en.wikipedia.org/wiki/Replay_attack --- api/src/api/users/user_routes.py | 12 +++- api/src/auth/login_gov_jwt_auth.py | 12 ++-- .../users/login_gov_callback_handler.py | 40 +++++++++--- .../src/api/users/test_user_route_login.py | 64 +++++++++++++++++-- api/tests/src/auth/test_login_gov_jwt_auth.py | 37 ++++++++--- 5 files changed, 136 insertions(+), 29 deletions(-) diff --git a/api/src/api/users/user_routes.py b/api/src/api/users/user_routes.py index 24f426635..6e4e18eb3 100644 --- a/api/src/api/users/user_routes.py +++ b/api/src/api/users/user_routes.py @@ -19,7 +19,10 @@ from src.auth.login_gov_jwt_auth import get_final_redirect_uri, get_login_gov_redirect_uri from src.db.models.user_models import UserTokenSession from src.services.users.get_user import get_user -from src.services.users.login_gov_callback_handler import handle_login_gov_callback +from src.services.users.login_gov_callback_handler import ( + handle_login_gov_callback_request, + handle_login_gov_token, +) logger = logging.getLogger(__name__) @@ -53,8 +56,13 @@ def user_login(db_session: db.Session) -> flask.Response: def user_login_callback(db_session: db.Session, query_data: dict) -> flask.Response: logger.info("GET /v1/users/login/callback") + # We process this in two separate DB transactions + # as we delete state at the end of the first handler + # even if it were to later error to avoid replay attacks + with db_session.begin(): + data = handle_login_gov_callback_request(query_data, db_session) with db_session.begin(): - result = handle_login_gov_callback(query_data, db_session) + result = handle_login_gov_token(db_session, data) # Redirect to the final location for the user return response.redirect_response( diff --git a/api/src/auth/login_gov_jwt_auth.py b/api/src/auth/login_gov_jwt_auth.py index 1804d17c3..0b52fa3a7 100644 --- 
a/api/src/auth/login_gov_jwt_auth.py +++ b/api/src/auth/login_gov_jwt_auth.py @@ -195,7 +195,7 @@ def get_final_redirect_uri( return f"{config.login_final_destination}?{encoded_params}" -def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGovUser: +def validate_token(token: str, nonce: str, config: LoginGovConfig | None = None) -> LoginGovUser: if not config: config = get_config() @@ -205,14 +205,14 @@ def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGov # Iterate over the public keys we have and check each # to determine if we have a valid key. for public_key in config.public_keys: - user = _validate_token_with_key(token, public_key, config) + user = _validate_token_with_key(token, nonce, public_key, config) if user is not None: return user _refresh_keys(config) for public_key in config.public_keys: - user = _validate_token_with_key(token, public_key, config) + user = _validate_token_with_key(token, nonce, public_key, config) if user is not None: return user @@ -220,7 +220,7 @@ def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGov def _validate_token_with_key( - token: str, public_key: jwt.PyJWK | str, config: LoginGovConfig + token: str, nonce: str, public_key: jwt.PyJWK | str, config: LoginGovConfig ) -> LoginGovUser | None: # We are processing the id_token as described on: # https://developers.login.gov/oidc/token/#token-response @@ -244,6 +244,10 @@ def _validate_token_with_key( ) payload = data.get("payload", {}) + payload_nonce = payload.get("nonce", None) + if payload_nonce != nonce: + raise JwtValidationError("Nonce does not match expected") + user_id = payload["sub"] email = payload["email"] diff --git a/api/src/services/users/login_gov_callback_handler.py b/api/src/services/users/login_gov_callback_handler.py index d4710516b..3e34f81d9 100644 --- a/api/src/services/users/login_gov_callback_handler.py +++ b/api/src/services/users/login_gov_callback_handler.py @@ -26,6 +26,14 
@@ class CallbackParams(BaseModel): error_description: str | None = None +@dataclass +class LoginGovDataContainer: + """Holds various login gov related fields we want to pass around""" + + code: str + nonce: str + + @dataclass class LoginGovCallbackResponse: token: str @@ -37,13 +45,14 @@ def get_login_gov_client() -> LoginGovOauthClient: return LoginGovOauthClient() -def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse: +def handle_login_gov_callback_request( + query_data: dict, db_session: db.Session +) -> LoginGovDataContainer: """Handle the callback from login.gov after calling the authenticate endpoint NOTE: Any errors thrown here will actually lead to a redirect due to the with_login_redirect_error_handler handler we have attached to the route """ - # Process the data coming back from login.gov via the redirect query params # see: https://developers.login.gov/oidc/authorization/#authorization-response callback_params = CallbackParams.model_validate(query_data) @@ -75,29 +84,44 @@ def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> Login if login_gov_state is None: raise_flask_error(404, "OAuth state not found") + # We do not want the login_gov_state to be reusable - so delete it + # even if we later error to avoid any replay attacks. 
+ db_session.delete(login_gov_state) + + return LoginGovDataContainer(code=callback_params.code, nonce=str(login_gov_state.nonce)) + + +def handle_login_gov_token( + db_session: db.Session, login_gov_data: LoginGovDataContainer +) -> LoginGovCallbackResponse: + """Fetch user info from login gov, and handle user creation + + NOTE: Any errors thrown here will actually lead to a redirect due to the + with_login_redirect_error_handler handler we have attached to the route + """ + # call the token endpoint (make a client) # https://developers.login.gov/oidc/token/ client = get_login_gov_client() response = client.get_token( OauthTokenRequest( - code=callback_params.code, client_assertion=get_login_gov_client_assertion() + code=login_gov_data.code, client_assertion=get_login_gov_client_assertion() ) ) # If this request failed, we'll assume we're the issue and 500 - # TODO - need to test with actual login.gov if there could be other scenarios - # the mock always returns something as long as the request is well-formatted if response.is_error_response(): raise_flask_error(500, response.error_description) # Process the token response from login.gov - return _process_token(db_session, response.id_token) + # which will create/update a user in the DB + return _process_token(db_session, response.id_token, login_gov_data.nonce) -def _process_token(db_session: db.Session, token: str) -> LoginGovCallbackResponse: +def _process_token(db_session: db.Session, token: str, nonce: str) -> LoginGovCallbackResponse: """Process the token from login.gov and generate our own token for auth""" try: - login_gov_user = validate_token(token) + login_gov_user = validate_token(token, nonce) except JwtValidationError as e: logger.info("Login.gov token validation failed", extra={"auth.issue": e.message}) raise_flask_error(401, e.message) diff --git a/api/tests/src/api/users/test_user_route_login.py b/api/tests/src/api/users/test_user_route_login.py index 36b152403..91241cdfb 100644 --- 
a/api/tests/src/api/users/test_user_route_login.py +++ b/api/tests/src/api/users/test_user_route_login.py @@ -8,7 +8,7 @@ from src.adapters.oauth.oauth_client_models import OauthTokenResponse from src.api.route_utils import raise_flask_error from src.auth.api_jwt_auth import parse_jwt_for_user -from src.db.models.user_models import LinkExternalUser +from src.db.models.user_models import LinkExternalUser, LoginGovState from src.util import datetime_util from tests.lib.auth_test_utils import create_jwt from tests.src.db.models.factories import LinkExternalUserFactory, LoginGovStateFactory @@ -232,6 +232,7 @@ def test_user_callback_new_user_302( code = str(uuid.uuid4()) id_token = create_jwt( user_id="bob-xyz", + nonce=str(login_gov_state.nonce), private_key=private_rsa_key, ) mock_oauth_client.add_token_response( @@ -267,6 +268,14 @@ def test_user_callback_new_user_302( ) assert external_user is not None + # Make sure the login gov state was deleted + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + assert db_state is None + def test_user_callback_existing_user_302( client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key @@ -282,6 +291,7 @@ def test_user_callback_existing_user_302( code = str(uuid.uuid4()) id_token = create_jwt( user_id=login_gov_id, + nonce=str(login_gov_state.nonce), private_key=private_rsa_key, ) mock_oauth_client.add_token_response( @@ -307,6 +317,14 @@ def test_user_callback_existing_user_302( assert user_token_session.is_valid is True assert user_token_session.user_id == external_user.user_id + # Make sure the login gov state was deleted + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + assert db_state is None + def test_user_callback_unknown_state_302(client, monkeypatch): """Test behavior when we get a redirect back 
from login.gov with an unknown state value""" @@ -374,13 +392,24 @@ def test_user_callback_error_in_token_302(client, enable_factory_create, caplog) ], ) def test_user_callback_token_fails_validation_302( - client, enable_factory_create, mock_oauth_client, private_rsa_key, jwt_params, error_description + client, + db_session, + enable_factory_create, + mock_oauth_client, + private_rsa_key, + jwt_params, + error_description, ): # Create state so the callback gets past the check login_gov_state = LoginGovStateFactory.create() code = str(uuid.uuid4()) - id_token = create_jwt(user_id=str(uuid.uuid4()), private_key=private_rsa_key, **jwt_params) + id_token = create_jwt( + user_id=str(uuid.uuid4()), + nonce=str(login_gov_state.nonce), + private_key=private_rsa_key, + **jwt_params, + ) mock_oauth_client.add_token_response( code, OauthTokenResponse( @@ -398,9 +427,17 @@ def test_user_callback_token_fails_validation_302( assert resp_json["message"] == "error" assert resp_json["error_description"] == error_description + # Make sure the login gov state was deleted even though it errored + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + assert db_state is None + def test_user_callback_token_fails_validation_bad_token_302( - client, enable_factory_create, mock_oauth_client, private_rsa_key + client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key ): # Create state so the callback gets past the check login_gov_state = LoginGovStateFactory.create() @@ -424,9 +461,17 @@ def test_user_callback_token_fails_validation_bad_token_302( assert resp_json["message"] == "error" assert resp_json["error_description"] == "Unable to process token" + # Make sure the login gov state was deleted even though it errored + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + 
assert db_state is None + def test_user_callback_token_fails_validation_no_valid_key_302( - client, enable_factory_create, mock_oauth_client, other_rsa_key_pair + client, db_session, enable_factory_create, mock_oauth_client, other_rsa_key_pair ): """Create the token with a different key than we check against""" # Create state so the callback gets past the check @@ -435,6 +480,7 @@ def test_user_callback_token_fails_validation_no_valid_key_302( code = str(uuid.uuid4()) id_token = create_jwt( user_id=str(uuid.uuid4()), + nonce=str(login_gov_state.nonce), private_key=other_rsa_key_pair[0], ) mock_oauth_client.add_token_response( @@ -456,3 +502,11 @@ def test_user_callback_token_fails_validation_no_valid_key_302( resp_json["error_description"] == "Token could not be validated against any public keys from login.gov" ) + + # Make sure the login gov state was deleted even though it errored + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + assert db_state is None diff --git a/api/tests/src/auth/test_login_gov_jwt_auth.py b/api/tests/src/auth/test_login_gov_jwt_auth.py index 3cb2189ec..930418282 100644 --- a/api/tests/src/auth/test_login_gov_jwt_auth.py +++ b/api/tests/src/auth/test_login_gov_jwt_auth.py @@ -10,6 +10,7 @@ DEFAULT_CLIENT_ID = "urn:gov:unit-test" DEFAULT_ISSUER = "http://localhost:3000" +DEFAULT_NONCE = "abc123" @pytest.fixture @@ -35,6 +36,7 @@ def create_jwt( issuer: str = DEFAULT_ISSUER, audience: str = DEFAULT_CLIENT_ID, acr: str = "urn:acr.login.gov:auth-only", + nonce: str = DEFAULT_NONCE, ): payload = { "sub": user_id, @@ -42,6 +44,7 @@ def create_jwt( "acr": acr, "aud": audience, "email": email, + "nonce": nonce, # The jwt encode function automatically turns these datetime # objects into a UTC timestamp integer "exp": expires_at, @@ -50,7 +53,6 @@ def create_jwt( # These values aren't checked by anything at the moment # but are a part of the token 
from login.gov "jti": "abc123", - "nonce": "abc123", "at_hash": "abc123", "c_hash": "abc123", } @@ -71,7 +73,7 @@ def test_validate_token_happy_path(login_gov_config, private_rsa_key): not_before=datetime.now(tz=timezone.utc) - timedelta(days=1), ) - login_gov_user = validate_token(token, login_gov_config) + login_gov_user = validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) assert login_gov_user.user_id == user_id assert login_gov_user.email == email @@ -88,7 +90,7 @@ def test_validate_token_expired(login_gov_config, private_rsa_key): ) with pytest.raises(JwtValidationError, match="Expired Token"): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) def test_validate_token_issued_at_future(login_gov_config, private_rsa_key): @@ -102,7 +104,7 @@ def test_validate_token_issued_at_future(login_gov_config, private_rsa_key): ) with pytest.raises(JwtValidationError, match="Token not yet valid"): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) def test_validate_token_not_before_future(login_gov_config, private_rsa_key): @@ -116,7 +118,7 @@ def test_validate_token_not_before_future(login_gov_config, private_rsa_key): ) with pytest.raises(JwtValidationError, match="Token not yet valid"): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) def test_validate_token_unknown_issuer(login_gov_config, private_rsa_key): @@ -131,7 +133,7 @@ def test_validate_token_unknown_issuer(login_gov_config, private_rsa_key): ) with pytest.raises(JwtValidationError, match="Unknown Issuer"): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) def test_validate_token_unknown_audience(login_gov_config, private_rsa_key): @@ -146,7 +148,7 @@ def test_validate_token_unknown_audience(login_gov_config, private_rsa_key): ) with 
pytest.raises(JwtValidationError, match="Unknown Audience"): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) def test_validate_token_invalid_signature(login_gov_config, other_rsa_key_pair, monkeypatch): @@ -170,10 +172,10 @@ def override_method(config): JwtValidationError, match="Token could not be validated against any public keys from login.gov", ): - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) -def test_something_with_the_refresh(login_gov_config, other_rsa_key_pair, monkeypatch): +def test_validate_token_key_found_on_refresh(login_gov_config, other_rsa_key_pair, monkeypatch): token = create_jwt( user_id="abc123", email="mail@fake.com", @@ -188,7 +190,7 @@ def override_method(config): monkeypatch.setattr(login_gov_jwt_auth, "_refresh_keys", override_method) - validate_token(token, login_gov_config) + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config) @freezegun.freeze_time("2024-11-14 12:00:00", tz_offset=0) @@ -213,3 +215,18 @@ def test_get_login_gov_client_assertion(login_gov_config, public_rsa_key): assert decoded_jwt["exp"] == timegm( datetime.fromisoformat("2024-11-14 12:05:00+00:00").utctimetuple() ) + + +def test_validate_token_invalid_nonce(login_gov_config, private_rsa_key): + token = create_jwt( + user_id="abc123", + email="mail@fake.com", + nonce="something_else", + private_key=private_rsa_key, + expires_at=datetime.now(tz=timezone.utc) + timedelta(days=30), + issued_at=datetime.now(tz=timezone.utc) - timedelta(days=1), + not_before=datetime.now(tz=timezone.utc) - timedelta(days=1), + ) + + with pytest.raises(JwtValidationError, match="Nonce does not match expected"): + validate_token(token, nonce=DEFAULT_NONCE, config=login_gov_config)