diff --git a/changelog.d/17488.feature b/changelog.d/17488.feature new file mode 100644 index 00000000000..15cccf3ac22 --- /dev/null +++ b/changelog.d/17488.feature @@ -0,0 +1 @@ +Implement [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) for custom profile fields. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 21989b6e0e8..5dd6e84289a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -132,6 +132,10 @@ class Codes(str, Enum): # connection. UNKNOWN_POS = "M_UNKNOWN_POS" + # Part of MSC4133 + PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE" + KEY_TOO_LARGE = "M_KEY_TOO_LARGE" + class CodeMessageException(RuntimeError): """An exception with integer code, a message string attributes and optional headers. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 90d19849ffd..94a25c7ee83 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -436,6 +436,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: ("experimental", "msc4108_delegation_endpoint"), ) + # MSC4133: Custom profile fields + self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False) + # MSC4210: Remove legacy mentions self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 22eedcb54f6..cdc388b4ab1 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,7 +32,7 @@ SynapseError, ) from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -43,6 +43,8 @@ MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 +# Field name length is specced at 255 bytes. 
+MAX_CUSTOM_FIELD_LEN = 255 class ProfileHandler: @@ -90,7 +92,15 @@ async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDi if self.hs.is_mine(target_user): profileinfo = await self.store.get_profileinfo(target_user) - if profileinfo.display_name is None and profileinfo.avatar_url is None: + extra_fields = {} + if self.hs.config.experimental.msc4133_enabled: + extra_fields = await self.store.get_profile_fields(target_user) + + if ( + profileinfo.display_name is None + and profileinfo.avatar_url is None + and not extra_fields + ): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) # Do not include display name or avatar if unset. @@ -99,6 +109,9 @@ async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDi ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name if profileinfo.avatar_url is not None: ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url + if extra_fields: + ret.update(extra_fields) + return ret else: try: @@ -403,6 +416,110 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool: return True + async def get_profile_field( + self, target_user: UserID, field_name: str + ) -> JsonValue: + """ + Fetch a user's profile from the database for local users and over federation + for remote users. + + Args: + target_user: The user ID to fetch the profile for. + field_name: The field to fetch the profile for. + + Returns: + The value for the profile field or None if the field does not exist. 
+ """ + if self.hs.is_mine(target_user): + try: + field_value = await self.store.get_profile_field( + target_user, field_name + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise + + return field_value + else: + try: + result = await self.federation.make_query( + destination=target_user.domain, + query_type="profile", + args={"user_id": target_user.to_string(), "field": field_name}, + ignore_backoff=True, + ) + except RequestSendFailed as e: + raise SynapseError(502, "Failed to fetch profile") from e + except HttpResponseException as e: + raise e.to_synapse_error() + + return result.get(field_name) + + async def set_profile_field( + self, + target_user: UserID, + requester: Requester, + field_name: str, + new_value: JsonValue, + by_admin: bool = False, + deactivation: bool = False, + ) -> None: + """Set a new profile field for a user. + + Args: + target_user: the user whose profile is to be changed. + requester: The user attempting to make this change. + field_name: The name of the profile field to update. + new_value: The new field value for this user. + by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. + """ + if not self.hs.is_mine(target_user): + raise SynapseError(400, "User is not hosted on this homeserver") + + if not by_admin and target_user != requester.user: + raise AuthError(403, "Cannot set another user's profile") + + await self.store.set_profile_field(target_user, field_name, new_value) + + # Custom fields do not propagate into the user directory *or* rooms. 
+ profile = await self.store.get_profileinfo(target_user) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + + async def delete_profile_field( + self, + target_user: UserID, + requester: Requester, + field_name: str, + by_admin: bool = False, + deactivation: bool = False, + ) -> None: + """Delete a field from a user's profile. + + Args: + target_user: the user whose profile is to be changed. + requester: The user attempting to make this change. + field_name: The name of the profile field to remove. + by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. + """ + if not self.hs.is_mine(target_user): + raise SynapseError(400, "User is not hosted on this homeserver") + + if not by_admin and target_user != requester.user: + raise AuthError(403, "Cannot set another user's profile") + + await self.store.delete_profile_field(target_user, field_name) + + # Custom fields do not propagate into the user directory *or* rooms. 
+ profile = await self.store.get_profileinfo(target_user) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + async def on_profile_query(self, args: JsonDict) -> JsonDict: """Handles federation profile query requests.""" @@ -419,13 +536,24 @@ async def on_profile_query(self, args: JsonDict) -> JsonDict: just_field = args.get("field", None) - response = {} + response: JsonDict = {} try: - if just_field is None or just_field == "displayname": + if just_field is None or just_field == ProfileFields.DISPLAYNAME: response["displayname"] = await self.store.get_profile_displayname(user) - if just_field is None or just_field == "avatar_url": + if just_field is None or just_field == ProfileFields.AVATAR_URL: response["avatar_url"] = await self.store.get_profile_avatar_url(user) + + if self.hs.config.experimental.msc4133_enabled: + if just_field is None: + response.update(await self.store.get_profile_fields(user)) + elif just_field not in ( + ProfileFields.DISPLAYNAME, + ProfileFields.AVATAR_URL, + ): + response[just_field] = await self.store.get_profile_field( + user, just_field + ) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 63b8a9364a4..ebd5a33ea5f 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -92,6 +92,23 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "enabled": self.config.experimental.msc3664_enabled, } + if self.config.experimental.msc4133_enabled: + response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = { + "enabled": True, + } + + # Ensure this is consistent with the legacy m.set_displayname and + # m.set_avatar_url. 
+ disallowed = [] + if not self.config.registration.enable_set_displayname: + disallowed.append("displayname") + if not self.config.registration.enable_set_avatar_url: + disallowed.append("avatar_url") + if disallowed: + response["capabilities"]["uk.tcpip.msc4133.profile_fields"][ + "disallowed" + ] = disallowed + return HTTPStatus.OK, response diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index ef59582865f..8326d8017c9 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -21,10 +21,13 @@ """This module contains REST servlets to do with profile: /profile/""" +import re from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError +from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -33,7 +36,8 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonValue, UserID +from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,6 +95,11 @@ async def on_GET( async def on_PUT( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -101,9 +110,7 @@ async def on_PUT( new_name = content["displayname"] except Exception: raise SynapseError( - code=400, - msg="Unable to parse name", - errcode=Codes.BAD_JSON, + 400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM ) propagate = _read_propagate(self.hs, request) @@ 
-166,6 +173,11 @@ async def on_GET( async def on_PUT( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -232,7 +244,180 @@ async def on_GET( return 200, ret +class UnstableProfileFieldRestServlet(RestServlet): + PATTERNS = [ + re.compile( + r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)" + ) + ] + CATEGORY = "Event sending requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET( + self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + requester_user = None + + if self.hs.config.server.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + user = UserID.from_string(user_id) + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + if field_name == ProfileFields.DISPLAYNAME: + field_value: JsonValue = await self.profile_handler.get_displayname(user) + elif field_name == ProfileFields.AVATAR_URL: + field_value = await 
self.profile_handler.get_avatar_url(user) + else: + field_value = await self.profile_handler.get_profile_field(user, field_name) + + return 200, {field_name: field_value} + + async def on_PUT( + self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + try: + new_value = content[field_name] + except KeyError: + raise SynapseError( + 400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM + ) + + propagate = _read_propagate(self.hs, request) + + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating profile while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + + if field_name == ProfileFields.DISPLAYNAME: + await self.profile_handler.set_displayname( + user, requester, new_value, is_admin, propagate=propagate + ) + elif field_name == ProfileFields.AVATAR_URL: + await self.profile_handler.set_avatar_url( + user, requester, new_value, is_admin, propagate=propagate + ) + else: + await self.profile_handler.set_profile_field( + user, requester, field_name, new_value, is_admin + ) + + return 200, {} + + async def on_DELETE( + 
self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + propagate = _read_propagate(self.hs, request) + + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating profile while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + + if field_name == ProfileFields.DISPLAYNAME: + await self.profile_handler.set_displayname( + user, requester, "", is_admin, propagate=propagate + ) + elif field_name == ProfileFields.AVATAR_URL: + await self.profile_handler.set_avatar_url( + user, requester, "", is_admin, propagate=propagate + ) + else: + await self.profile_handler.delete_profile_field( + user, requester, field_name, is_admin + ) + + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + # The specific displayname / avatar URL / custom field endpoints *must* appear + # before their corresponding generic profile endpoint. 
ProfileDisplaynameRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) + if hs.config.experimental.msc4133_enabled: + UnstableProfileFieldRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index a1d089ebac8..266a0b835b9 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -172,6 +172,8 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "org.matrix.msc4140": bool(self.config.server.max_event_delay_ms), # Simplified sliding sync "org.matrix.simplified_msc3575": msc3575_enabled, + # Arbitrary key-value profile fields. + "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled, }, }, ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 41cf08211f2..30d8a58d965 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -18,8 +18,13 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional +import json +from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from canonicaljson import encode_canonical_json + +from synapse.api.constants import ProfileFields +from synapse.api.errors import Codes, StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -27,13 +32,17 @@ LoggingTransaction, ) from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict, JsonValue, UserID if TYPE_CHECKING: from synapse.server import HomeServer +# The number of bytes that the serialized profile can have. 
+MAX_PROFILE_SIZE = 65536 + + class ProfileWorkerStore(SQLBaseStore): def __init__( self, @@ -201,6 +210,89 @@ async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: desc="get_profile_avatar_url", ) + async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue: + """ + Get a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The custom profile field name. + + Returns: + The string value if the field exists, otherwise raises 404. + """ + + def get_profile_field(txn: LoggingTransaction) -> JsonValue: + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + field_path = f'$."{field_name}"' + + if isinstance(self.database_engine, PostgresEngine): + sql = """ + SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_name, user_id.localpart), + ) + + # Test exists first since value being None is used for both + # missing and a null JSON value. + exists, value = cast(Tuple[bool, JsonValue], txn.fetchone()) + if not exists: + raise StoreError(404, "No row found") + return value + + else: + sql = """ + SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_path, user_id.localpart), + ) + + # If value_type is None, then the value did not exist. + value_type, value = cast( + Tuple[Optional[str], JsonValue], txn.fetchone() + ) + if not value_type: + raise StoreError(404, "No row found") + # If value_type is object or array, then need to deserialize the JSON. + # Scalar values are properly returned directly. 
+ if value_type in ("object", "array"): + assert isinstance(value, str) + return json.loads(value) + return value + + return await self.db_pool.runInteraction("get_profile_field", get_profile_field) + + async def get_profile_fields(self, user_id: UserID) -> Dict[str, JsonValue]: + """ + Get all custom profile fields for a user. + + Args: + user_id: The user's ID. + + Returns: + A dictionary of custom profile fields. + """ + result = await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcol="fields", + desc="get_profile_fields", + ) + # The SQLite driver doesn't automatically convert JSON to + # Python objects + if isinstance(self.database_engine, Sqlite3Engine) and result: + result = json.loads(result) + return result or {} + async def create_profile(self, user_id: UserID) -> None: """ Create a blank profile for a user. @@ -215,6 +307,71 @@ async def create_profile(self, user_id: UserID) -> None: desc="create_profile", ) + def _check_profile_size( + self, + txn: LoggingTransaction, + user_id: UserID, + new_field_name: str, + new_value: JsonValue, + ) -> None: + # For each entry there are 4 quotes (2 each for key and value), 1 colon, + # and 1 comma. + PER_VALUE_EXTRA = 6 + + # Add the size of the current custom profile fields, ignoring the entry + # which will be overwritten. + if isinstance(txn.database_engine, PostgresEngine): + size_sql = """ + SELECT + OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + (new_field_name, user_id.localpart), + ) + else: + size_sql = """ + SELECT + LENGTH(CAST(json_remove(fields, ?) AS BLOB)), LENGTH(CAST(displayname AS BLOB)), LENGTH(CAST(avatar_url AS BLOB)) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. 
+ (f'$."{new_field_name}"', user_id.localpart), + ) + row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone()) + + # The values return null if the column is null. + total_bytes = ( + # Discount the opening and closing braces to avoid double counting, + # but add one for a comma. + # -2 + 1 = -1 + (row[0] - 1 if row[0] else 0) + + ( + row[1] + len("displayname") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.DISPLAYNAME and row[1] + else 0 + ) + + ( + row[2] + len("avatar_url") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.AVATAR_URL and row[2] + else 0 + ) + ) + + # Add the length of the field being added + the braces. + total_bytes += len(encode_canonical_json({new_field_name: new_value})) + + if total_bytes > MAX_PROFILE_SIZE: + raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE) + async def set_profile_displayname( self, user_id: UserID, new_displayname: Optional[str] ) -> None: @@ -227,14 +384,25 @@ async def set_profile_displayname( name is removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={ - "displayname": new_displayname, - "full_user_id": user_id.to_string(), - }, - desc="set_profile_displayname", + + def set_profile_displayname(txn: LoggingTransaction) -> None: + if new_displayname is not None: + self._check_profile_size( + txn, user_id, ProfileFields.DISPLAYNAME, new_displayname + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "displayname": new_displayname, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_displayname", set_profile_displayname ) async def set_profile_avatar_url( @@ -249,13 +417,125 @@ async def set_profile_avatar_url( removed. 
""" user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, - desc="set_profile_avatar_url", + + def set_profile_avatar_url(txn: LoggingTransaction) -> None: + if new_avatar_url is not None: + self._check_profile_size( + txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "avatar_url": new_avatar_url, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_avatar_url", set_profile_avatar_url ) + async def set_profile_field( + self, user_id: UserID, field_name: str, new_value: JsonValue + ) -> None: + """ + Set a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + new_value: The value of the custom profile field. + """ + + # Encode to canonical JSON. + canonical_value = encode_canonical_json(new_value) + + def set_profile_field(txn: LoggingTransaction) -> None: + self._check_profile_size(txn, user_id, field_name, new_value) + + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import Json + + # Note that the || jsonb operator is not recursive, any duplicate + # keys will be taken from the second value. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb)) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields + """ + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + field_name, + # Pass as a JSON object since we have passing bytes disabled + # at the database driver. 
+ Json(json.loads(canonical_value)), + ), + ) + else: + # You may be tempted to use json_patch instead of providing the parameters + # twice, but that recursively merges objects instead of replacing. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?))) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?)) + """ + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + json_field_name = f'$."{field_name}"' + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + json_field_name, + canonical_value, + json_field_name, + canonical_value, + ), + ) + + await self.db_pool.runInteraction("set_profile_field", set_profile_field) + + async def delete_profile_field(self, user_id: UserID, field_name: str) -> None: + """ + Remove a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + """ + + def delete_profile_field(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + sql = """ + UPDATE profiles SET fields = fields - ? + WHERE user_id = ? + """ + txn.execute( + sql, + (field_name, user_id.localpart), + ) + else: + sql = """ + UPDATE profiles SET fields = json_remove(fields, ?) + WHERE user_id = ? + """ + txn.execute( + sql, + # This will error if field_name has double quotes in it. 
+ (f'$."{field_name}"', user_id.localpart), + ) + + await self.db_pool.runInteraction("delete_profile_field", delete_profile_field) + class ProfileStore(ProfileWorkerStore): pass diff --git a/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql new file mode 100644 index 00000000000..63cbd7ffa99 --- /dev/null +++ b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql @@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 Patrick Cloke +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Custom profile fields. +ALTER TABLE profiles ADD COLUMN fields JSONB; diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 13ff54b6692..32b5bc00c9d 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -43,6 +43,14 @@ # MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") +# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar +# +# At least one character, less than or equal to 255 characters. Must start with +# a-z, the rest is a-z, 0-9, -, _, or .. +# +# This doesn't check anything about validity of namespaces. +NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$") + def random_string(length: int) -> str: """Generate a cryptographically secure string of random letters. 
def is_namedspaced_grammar(s: str) -> bool:
    """Check whether a string follows the common namespaced identifier grammar.

    The grammar is the one encoded by ``NAMESPACED_GRAMMAR``: 1 to 255
    characters, starting with a-z, followed by a-z, 0-9, ``-``, ``_`` or ``.``.
    Nothing about the validity of the namespace itself is checked.

    Args:
        s: The candidate identifier.

    Returns:
        True if the whole string matches the grammar, False otherwise.
    """
    # Use fullmatch rather than match: the pattern ends in `$`, which also
    # matches just before a trailing newline, so `match` would wrongly accept
    # an identifier followed by a single newline character.
    return NAMESPACED_GRAMMAR.fullmatch(s) is not None
self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"]) + self.assertEqual( + capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"], + ["avatar_url"], + ) + @override_config({"enable_3pid_changes": False}) def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: """Test if change 3pid is disabled that the server responds it.""" diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index a92713d220e..708402b7929 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -25,16 +25,20 @@ from http import HTTPStatus from typing import Any, Dict, Optional +from canonicaljson import encode_canonical_json + from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room from synapse.server import HomeServer +from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE from synapse.types import UserID from synapse.util import Clock from tests import unittest +from tests.utils import USE_POSTGRES_FOR_TESTS class ProfileTestCase(unittest.HomeserverTestCase): @@ -480,6 +484,298 @@ def test_msc4069_inhibit_propagation_like_default(self) -> None: # The client requested ?propagate=true, so it should have happened. 
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field(self) -> None:
    """The owner can create, overwrite and then delete a custom profile field."""
    field_url = (
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field"
    )

    # Set the field, then overwrite it; every write must be readable back.
    for value in ("test", "new_Value"):
        put_channel = self.make_request(
            "PUT",
            field_url,
            content={"custom_field": value},
            access_token=self.owner_tok,
        )
        self.assertEqual(put_channel.code, 200, put_channel.result)

        get_channel = self.make_request("GET", field_url)
        self.assertEqual(get_channel.code, HTTPStatus.OK, get_channel.result)
        self.assertEqual(get_channel.json_body, {"custom_field": value})

    # Deleting the field should work.
    delete_channel = self.make_request(
        "DELETE",
        field_url,
        content={},
        access_token=self.owner_tok,
    )
    self.assertEqual(delete_channel.code, 200, delete_channel.result)

    # Once deleted, the field is reported as missing.
    missing_channel = self.make_request("GET", field_url)
    self.assertEqual(
        missing_channel.code, HTTPStatus.NOT_FOUND, missing_channel.result
    )
    self.assertEqual(missing_channel.json_body["errcode"], Codes.NOT_FOUND)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_noauth(self) -> None:
    """Setting a custom field without an access token is rejected."""
    field_url = (
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field"
    )
    body = {"custom_field": "test"}

    # No access_token is supplied on purpose.
    response = self.make_request("PUT", field_url, content=body)

    self.assertEqual(response.code, 401, response.result)
    self.assertEqual(response.json_body["errcode"], Codes.MISSING_TOKEN)
@unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
def test_set_custom_field_profile_too_long(self) -> None:
    """
    Attempts to set a custom field that would push the overall profile too large.
    """
    # Get right to the boundary (sizes are of the canonical JSON encoding):
    #   len("displayname") + len("owner") + 5 = 21 for the displayname
    #   1 + 65498 + 5 for key "a" = 65504
    #   2 braces, 1 comma = 3
    # 3 + 21 + 65504 = 65528 < 65536.
    key = "a"
    channel = self.make_request(
        "PUT",
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
        content={key: "a" * 65498},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 200, channel.result)

    # Get the entire profile.
    channel = self.make_request(
        "GET",
        f"/_matrix/client/v3/profile/{self.owner}",
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 200, channel.result)
    canonical_json = encode_canonical_json(channel.json_body)
    # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
    # Be one below that so we can prove we're at the boundary.
    # NOTE(review): the assertion below leaves 8 bytes of headroom, which does
    # not obviously match "one below 6" -- confirm which figure is intended.
    self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)

    # Postgres stores JSONB with whitespace, while SQLite doesn't.
    # The SQLite run therefore pads values by one extra character below.
    if USE_POSTGRES_FOR_TESTS:
        ADDITIONAL_CHARS = 0
    else:
        ADDITIONAL_CHARS = 1

    # The next one should fail, note the value has a (JSON) length of 2.
    key = "b"
    channel = self.make_request(
        "PUT",
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
        content={key: "1" + "a" * ADDITIONAL_CHARS},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 400, channel.result)
    self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)

    # Setting an avatar or (longer) display name should not work.
    channel = self.make_request(
        "PUT",
        f"/profile/{self.owner}/displayname",
        content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 400, channel.result)
    self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)

    channel = self.make_request(
        "PUT",
        f"/profile/{self.owner}/avatar_url",
        content={"avatar_url": "mxc://foo/bar"},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 400, channel.result)
    self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)

    # Removing a single byte should work.
    key = "b"
    channel = self.make_request(
        "PUT",
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
        content={key: "" + "a" * ADDITIONAL_CHARS},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 200, channel.result)

    # Finally, setting a field that already exists to a value that is <= in length should work.
    key = "a"
    channel = self.make_request(
        "PUT",
        f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
        content={key: ""},
        access_token=self.owner_tok,
    )
    self.assertEqual(channel.code, 200, channel.result)
def test_namespaced_identifier(self) -> None:
    """Check is_namedspaced_grammar against valid and invalid identifiers."""
    # Well-formed identifiers, with and without namespaces.
    self.assertTrue(is_namedspaced_grammar("test"))
    self.assertTrue(is_namedspaced_grammar("m.test"))
    self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
    self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
    self.assertTrue(is_namedspaced_grammar("t-e_s.t"))

    # Must not be empty.
    self.assertFalse(is_namedspaced_grammar(""))

    # Must start with letter.
    self.assertFalse(is_namedspaced_grammar("1test"))
    self.assertFalse(is_namedspaced_grammar("-test"))
    self.assertFalse(is_namedspaced_grammar("_test"))
    self.assertFalse(is_namedspaced_grammar(".test"))

    # Must contain only a-z, 0-9, -, _, ..
    self.assertFalse(is_namedspaced_grammar("test/"))
    self.assertFalse(is_namedspaced_grammar('test"'))
    self.assertFalse(is_namedspaced_grammar("testö"))
    # Upper-case letters are outside the grammar too.
    self.assertFalse(is_namedspaced_grammar("Test"))

    # Must be <= 255 characters: exactly 255 is accepted, 256 is rejected.
    self.assertTrue(is_namedspaced_grammar("t" * 255))
    self.assertFalse(is_namedspaced_grammar("t" * 256))