Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
MSC2918 Refresh tokens implementation (#9450)
Browse files Browse the repository at this point in the history
This implements refresh tokens, as defined by MSC2918

This MSC has been implemented client side in Hydrogen Web: element-hq/hydrogen-web#235

The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one.

Signed-off-by: Quentin Gliech <[email protected]>
  • Loading branch information
sandhose authored Jun 24, 2021
1 parent 763dba7 commit bd4919f
Show file tree
Hide file tree
Showing 15 changed files with 892 additions and 61 deletions.
1 change: 1 addition & 0 deletions changelog.d/9450.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918).
4 changes: 3 additions & 1 deletion scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ BOOLEAN_COLUMNS = {
"local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
}


Expand Down Expand Up @@ -307,7 +308,8 @@ class Porter(object):
information_schema.table_constraints AS tc
INNER JOIN information_schema.constraint_column_usage AS ccu
USING (table_schema, constraint_name)
WHERE tc.constraint_type = 'FOREIGN KEY';
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name != ccu.table_name;
"""
txn.execute(sql)

Expand Down
5 changes: 5 additions & 0 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ async def get_user_by_req(
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)

# Mark the token as used. This is used to invalidate old refresh
# tokens after some time.
if not user_info.token_used and token_id is not None:
await self.store.mark_access_token_as_used(token_id)

requester = create_requester(
user_info.user_id,
token_id,
Expand Down
21 changes: 21 additions & 0 deletions synapse/config/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,27 @@ def read_config(self, config, **kwargs):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime

# The `access_token_lifetime` applies for tokens that can be renewed
# using a refresh token, as per MSC2918. If it is `None`, the refresh
# token mechanism is disabled.
#
# Since it is incompatible with the `session_lifetime` mechanism, it is set to
# `None` by default if a `session_lifetime` is set.
access_token_lifetime = config.get(
"access_token_lifetime", "5m" if session_lifetime is None else None
)
if access_token_lifetime is not None:
access_token_lifetime = self.parse_duration(access_token_lifetime)
self.access_token_lifetime = access_token_lifetime

if session_lifetime is not None and access_token_lifetime is not None:
raise ConfigError(
"The refresh token mechanism is incompatible with the "
"`session_lifetime` option. Consider disabling the "
"`session_lifetime` option or disabling the refresh token "
"mechanism by removing the `access_token_lifetime` option."
)

# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")

Expand Down
132 changes: 127 additions & 5 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Optional,
Tuple,
Union,
cast,
)

import attr
Expand Down Expand Up @@ -72,6 +73,7 @@
from synapse.util.threepids import canonicalise_email

if TYPE_CHECKING:
from synapse.rest.client.v1.login import LoginResponse
from synapse.server import HomeServer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -777,13 +779,116 @@ def _auth_dict_for_flows(
"params": params,
}

async def refresh_token(
self,
refresh_token: str,
valid_until_ms: Optional[int],
) -> Tuple[str, str]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
Args:
refresh_token: The token to consume.
valid_until_ms: The expiration timestamp of the new access token.
Returns:
A tuple containing the new access token and refresh token
"""

# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)

existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)

if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
403, "refresh token isn't valid anymore", Codes.FORBIDDEN
)

(
new_refresh_token,
new_refresh_token_id,
) = await self.get_refresh_token_for_user_id(
user_id=existing_token.user_id, device_id=existing_token.device_id
)
access_token = await self.get_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
valid_until_ms=valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
return access_token, new_refresh_token

def _verify_refresh_token(self, token: str) -> bool:
"""
Verifies the shape of a refresh token.
Args:
token: The refresh token to verify
Returns:
Whether the token has the right shape
"""
parts = token.split("_", maxsplit=4)
if len(parts) != 4:
return False

type, localpart, rand, crc = parts

# Refresh tokens are prefixed by "syr_", let's check that
if type != "syr":
return False

# Check the CRC
base = f"{type}_{localpart}_{rand}"
expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
if crc != expected_crc:
return False

return True

async def get_refresh_token_for_user_id(
self,
user_id: str,
device_id: str,
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
Returns:
The newly created refresh token and its ID in the database
"""
refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
refresh_token_id = await self.store.add_refresh_token_to_user(
user_id=user_id,
token=refresh_token,
device_id=device_id,
)
return refresh_token, refresh_token_id

async def get_access_token_for_user_id(
self,
user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
refresh_token_id: Optional[int] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
Expand All @@ -801,6 +906,8 @@ async def get_access_token_for_user_id(
valid_until_ms: when the token is valid until. None for
no expiry.
is_appservice_ghost: Whether the user is an application ghost user
refresh_token_id: the refresh token ID that will be associated with
this access token.
Returns:
The access token for the user's session.
Raises:
Expand Down Expand Up @@ -836,6 +943,7 @@ async def get_access_token_for_user_id(
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
refresh_token_id=refresh_token_id,
)

# the device *should* have been registered before we got here; however,
Expand Down Expand Up @@ -928,7 +1036,7 @@ async def validate_login(
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
Expand Down Expand Up @@ -1073,7 +1181,7 @@ async def _validate_userid_login(
self,
username: str,
login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
Expand Down Expand Up @@ -1151,7 +1259,7 @@ async def _validate_userid_login(

async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
Expand Down Expand Up @@ -1215,6 +1323,19 @@ def generate_access_token(self, for_user: UserID) -> str:
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"

def generate_refresh_token(self, for_user: UserID) -> str:
"""Generates an opaque string, for use as a refresh token"""

# we use the following format for refresh tokens:
# syr_<base64 local part>_<random string>_<base62 crc check>

b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
random_string = stringutils.random_string(20)
base = f"syr_{b64local}_{random_string}"

crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"

async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
Expand Down Expand Up @@ -1563,7 +1684,7 @@ def _complete_sso_login(
)
respond_with_html(request, 200, html)

async def _sso_login_callback(self, login_result: JsonDict) -> None:
async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
"""
A login callback which might add additional attributes to the login response.
Expand All @@ -1577,7 +1698,8 @@ async def _sso_login_callback(self, login_result: JsonDict) -> None:

extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
login_result.update(extra_attributes.extra_attributes)
login_result_dict = cast(Dict[str, Any], login_result)
login_result_dict.update(extra_attributes.extra_attributes)

def _expire_sso_extra_attributes(self) -> None:
"""
Expand Down
Loading

0 comments on commit bd4919f

Please sign in to comment.