Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for authenticated media #69

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions src/matrix_content_scanner/scanner/file_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-only
# Please see LICENSE in the repository root for full details.
import copy
import json
import logging
import urllib.parse
Expand Down Expand Up @@ -33,6 +34,8 @@ class _PathNotFoundException(Exception):
class FileDownloader:
MEDIA_DOWNLOAD_PREFIX = "_matrix/media/%s/download"
MEDIA_THUMBNAIL_PREFIX = "_matrix/media/%s/thumbnail"
MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/download"
MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX = "_matrix/client/%s/media/thumbnail"

def __init__(self, mcs: "MatrixContentScanner"):
self._base_url = mcs.config.download.base_homeserver_url
Expand All @@ -44,6 +47,7 @@ async def download_file(
self,
media_path: str,
thumbnail_params: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Retrieve the file with the given `server_name/media_id` path, and stores it on
disk.
Expand All @@ -52,6 +56,8 @@ async def download_file(
media_path: The path identifying the media to retrieve.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header; this is
required for authenticated media endpoints.

Returns:
A description of the file (including its full content).
Expand All @@ -60,27 +66,45 @@ async def download_file(
ContentScannerRestError: The file was not found or could not be downloaded due
to an error on the remote homeserver's side.
"""

auth_media = True if auth_header is not None else False

prefix = (
self.MEDIA_DOWNLOAD_AUTHENTICATED_PREFIX
if auth_media
else self.MEDIA_DOWNLOAD_PREFIX
)
if thumbnail_params is not None:
prefix = (
self.MEDIA_THUMBNAIL_AUTHENTICATED_PREFIX
if auth_media
else self.MEDIA_THUMBNAIL_PREFIX
)

url = await self._build_https_url(
media_path, for_thumbnail=thumbnail_params is not None
media_path, prefix, "v1" if auth_media else "v3"
)

# Attempt to retrieve the file at the generated URL.
try:
file = await self._get_file_content(url, thumbnail_params)
file = await self._get_file_content(url, thumbnail_params, auth_header)
except _PathNotFoundException:
if auth_media:
raise ContentScannerRestError(
http_status=HTTPStatus.NOT_FOUND,
reason=ErrCode.NOT_FOUND,
info="File not found",
)

# If the file could not be found, it might be because the homeserver hasn't
# been upgraded to a version that supports Matrix v1.1 endpoints yet, so try
# again with an r0 endpoint.
logger.info("File not found, trying legacy r0 path")

url = await self._build_https_url(
media_path,
endpoint_version="r0",
for_thumbnail=thumbnail_params is not None,
)
url = await self._build_https_url(media_path, prefix, "r0")

try:
file = await self._get_file_content(url, thumbnail_params)
file = await self._get_file_content(url, thumbnail_params, auth_header)
except _PathNotFoundException:
# If that still failed, raise an error.
raise ContentScannerRestError(
Expand All @@ -94,9 +118,8 @@ async def download_file(
async def _build_https_url(
self,
media_path: str,
endpoint_version: str = "v3",
*,
for_thumbnail: bool,
prefix: str,
endpoint_version: str,
) -> str:
"""Turn a `server_name/media_id` path into an https:// one we can use to fetch
the media.
Expand All @@ -107,10 +130,8 @@ async def _build_https_url(
Args:
media_path: The media path to translate.
endpoint_version: The version of the download endpoint to use. As of Matrix
v1.1, this is either "v3" or "r0".
for_thumbnail: True if a server-side thumbnail is desired instead of the full
media. In that case, the URL for the `/thumbnail` endpoint is returned
instead of the `/download` endpoint.
v1.11, this is "v1" for authenticated media. For unauthenticated media
this is either "v3" or "r0".

Returns:
An https URL to use. If `base_homeserver_url` is set in the config, this
Expand Down Expand Up @@ -140,10 +161,6 @@ async def _build_https_url(
# didn't find a .well-known file.
base_url = "https://" + server_name

prefix = (
self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX
)

# Build the full URL.
path_prefix = prefix % endpoint_version
url = "%s/%s/%s/%s" % (
Expand All @@ -159,12 +176,15 @@ async def _get_file_content(
self,
url: str,
thumbnail_params: Optional[MultiMapping[str]],
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Retrieve the content of the file at a given URL.

Args:
url: The URL to query.
thumbnail_params: Query parameters used if the request is for a thumbnail.
auth_header: If present, we forward the given Authorization header; this is
required for authenticated media endpoints.

Returns:
A description of the file (including its full content).
Expand All @@ -178,7 +198,9 @@ async def _get_file_content(
ContentScannerRestError: the server returned a non-200 status which cannot
mean that the path wasn't understood.
"""
code, body, headers = await self._get(url, query=thumbnail_params)
code, body, headers = await self._get(
url, query=thumbnail_params, auth_header=auth_header
)

logger.info("Remote server responded with %d", code)

Expand Down Expand Up @@ -307,12 +329,15 @@ async def _get(
self,
url: str,
query: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, bytes, CIMultiDictProxy[str]]:
"""Sends a GET request to the provided URL.

Args:
url: The URL to send requests to.
query: Optional parameters to use in the request's query string.
auth_header: If present, we forward the given Authorization header; this is
required for authenticated media endpoints.

Returns:
The HTTP status code, body and headers the remote server responded with.
Expand All @@ -324,10 +349,19 @@ async def _get(
try:
logger.info("Sending GET request to %s", url)
async with aiohttp.ClientSession() as session:
# TODO: Test we don't persist auth token
devonh marked this conversation as resolved.
Show resolved Hide resolved
request_headers = copy.deepcopy(self._headers)
if auth_header is not None:
auth_dict = {"Authorization": auth_header}
if request_headers is None:
request_headers = auth_dict
else:
request_headers.update(auth_dict)
devonh marked this conversation as resolved.
Show resolved Hide resolved

async with session.get(
url,
proxy=self._proxy_url,
headers=self._headers,
headers=request_headers,
params=query,
) as resp:
return resp.status, await resp.read(), resp.headers
Expand Down
10 changes: 9 additions & 1 deletion src/matrix_content_scanner/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def scan_file(
media_path: str,
metadata: Optional[JsonDict] = None,
thumbnail_params: Optional["MultiMapping[str]"] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Download and scan the given media.

Expand All @@ -119,6 +120,8 @@ async def scan_file(
the file isn't encrypted.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header; this is
required for authenticated media endpoints.

Returns:
A description of the media.
Expand All @@ -141,7 +144,7 @@ async def scan_file(
# Try to download and scan the file.
try:
res = await self._scan_file(
cache_key, media_path, metadata, thumbnail_params
cache_key, media_path, metadata, thumbnail_params, auth_header
)
# Set the future's result, and mark it as done.
f.set_result(res)
Expand All @@ -168,6 +171,7 @@ async def _scan_file(
media_path: str,
metadata: Optional[JsonDict] = None,
thumbnail_params: Optional[MultiMapping[str]] = None,
auth_header: Optional[str] = None,
) -> MediaDescription:
"""Download and scan the given media.

Expand All @@ -185,6 +189,8 @@ async def _scan_file(
the file isn't encrypted.
thumbnail_params: If present, then we want to request and scan a thumbnail
generated with the provided parameters instead of the full media.
auth_header: If present, we forward the given Authorization header; this is
required for authenticated media endpoints.

Returns:
A description of the media.
Expand Down Expand Up @@ -218,6 +224,7 @@ async def _scan_file(
media = await self._file_downloader.download_file(
media_path=media_path,
thumbnail_params=thumbnail_params,
auth_header=auth_header,
)

# Compare the media's hash to ensure the server hasn't changed the file since
Expand Down Expand Up @@ -251,6 +258,7 @@ async def _scan_file(
media = await self._file_downloader.download_file(
media_path=media_path,
thumbnail_params=thumbnail_params,
auth_header=auth_header,
)

# Download and scan the file.
Expand Down
13 changes: 10 additions & 3 deletions src/matrix_content_scanner/servlets/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ async def _scan(
self,
media_path: str,
metadata: Optional[JsonDict] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, _BytesResponse]:
media = await self._scanner.scan_file(media_path, metadata)
media = await self._scanner.scan_file(
media_path, metadata, auth_header=auth_header
)

return 200, _BytesResponse(
headers=media.response_headers,
Expand All @@ -38,7 +41,9 @@ async def _scan(
async def handle_plain(self, request: web.Request) -> Tuple[int, _BytesResponse]:
"""Handles GET requests to ../download/serverName/mediaId"""
media_path = request.match_info["media_path"]
return await self._scan(media_path)
return await self._scan(
media_path, auth_header=request.headers.get("Authorization")
)

@web_handler
async def handle_encrypted(
Expand All @@ -49,4 +54,6 @@ async def handle_encrypted(
request, self._crypto_handler
)

return await self._scan(media_path, metadata)
return await self._scan(
media_path, metadata, auth_header=request.headers.get("Authorization")
)
11 changes: 8 additions & 3 deletions src/matrix_content_scanner/servlets/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ async def _scan_and_format(
self,
media_path: str,
metadata: Optional[JsonDict] = None,
auth_header: Optional[str] = None,
) -> Tuple[int, JsonDict]:
try:
await self._scanner.scan_file(media_path, metadata)
await self._scanner.scan_file(media_path, metadata, auth_header=auth_header)
except FileDirtyError as e:
res = {"clean": False, "info": e.info}
else:
Expand All @@ -37,12 +38,16 @@ async def _scan_and_format(
async def handle_plain(self, request: web.Request) -> Tuple[int, JsonDict]:
"""Handles GET requests to ../scan/serverName/mediaId"""
media_path = request.match_info["media_path"]
return await self._scan_and_format(media_path)
return await self._scan_and_format(
media_path, auth_header=request.headers.get("Authorization")
)

@web_handler
async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]:
"""Handles GET requests to ../scan_encrypted"""
media_path, metadata = await get_media_metadata_from_request(
request, self._crypto_handler
)
return await self._scan_and_format(media_path, metadata)
return await self._scan_and_format(
media_path, metadata, auth_header=request.headers.get("Authorization")
)
1 change: 1 addition & 0 deletions src/matrix_content_scanner/servlets/thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def handle_thumbnail(
media = await self._scanner.scan_file(
media_path=media_path,
thumbnail_params=request.query,
auth_header=request.headers.get("Authorization"),
)

return 200, _BytesResponse(
Expand Down
Loading
Loading