Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Return attrs for more media repo APIs. #16611

Merged
merged 5 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16611.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
15 changes: 9 additions & 6 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from synapse.api.errors import (
AuthError,
Expand All @@ -23,6 +23,7 @@
StoreError,
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
Expand Down Expand Up @@ -306,7 +307,9 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
server_name = host

if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)

Expand All @@ -322,25 +325,25 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:

if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
if media_info["media_length"] > self.max_avatar_size:
if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
media_info["media_length"],
media_info.media_length,
)
return False

if self.allowed_avatar_mimetypes:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
and media_info["media_type"] not in self.allowed_avatar_mimetypes
and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
media_info["media_type"],
media_info.media_type,
)
return False

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def is_allowed_mime_type(content_type: str) -> bool:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True

Expand Down
70 changes: 40 additions & 30 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import attr
from matrix_common.types.mxc_uri import MXCUri

import twisted.internet.error
Expand Down Expand Up @@ -50,6 +51,7 @@
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
Expand Down Expand Up @@ -245,18 +247,18 @@ async def get_local_media(
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
if not media_info or media_info.quarantined_by:
respond_404(request)
return

self.mark_recently_accessed(None, media_id)

media_type = media_info["media_type"]
media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
media_length = media_info.media_length
upload_name = name if name else media_info.upload_name
url_cache = media_info.url_cache

file_info = FileInfo(None, media_id, url_cache=bool(url_cache))

Expand Down Expand Up @@ -310,16 +312,20 @@ async def get_remote_media(

# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
upload_name = name if name else media_info.upload_name
await respond_with_responder(
request, responder, media_type, media_length, upload_name
request,
responder,
media_info.media_type,
media_info.media_length,
upload_name,
)
else:
respond_404(request)

async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
async def get_remote_media_info(
self, server_name: str, media_id: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.

Expand Down Expand Up @@ -353,7 +359,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:

async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.

Expand All @@ -373,15 +379,17 @@ async def _get_remote_media_impl(

# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)

if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()

if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
if not media_info.media_type:
media_info = attr.evolve(
media_info, media_type="application/octet-stream"
)
clokep marked this conversation as resolved.
Show resolved Hide resolved

responder = await self.media_storage.fetch_media(file_info)
if responder:
Expand All @@ -403,9 +411,9 @@ async def _get_remote_media_impl(
if not media_info:
raise e

file_id = media_info["filesystem_id"]
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
file_id = media_info.filesystem_id
if not media_info.media_type:
media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)

# We generate thumbnails even if another process downloaded the media
Expand All @@ -415,7 +423,7 @@ async def _get_remote_media_impl(
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
server_name, media_id, file_id, media_info.media_type
)

responder = await self.media_storage.fetch_media(file_info)
Expand All @@ -425,7 +433,7 @@ async def _download_remote_file(
self,
server_name: str,
media_id: str,
) -> dict:
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.

Expand Down Expand Up @@ -518,23 +526,25 @@ async def _download_remote_file(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

logger.info("Stored remote media in file %r", fname)

media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}

return media_info
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)

def _get_thumbnail_requirements(
self, media_type: str
Expand Down
11 changes: 5 additions & 6 deletions synapse/media/url_previewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,14 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
and cache_result["response_code"] / 100 == 2
and cache_result.expires_ts > ts
and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, str):
og = og.encode("utf8")
return og
if isinstance(cache_result.og, str):
return cache_result.og.encode("utf8")
return cache_result.og

# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
Expand Down
16 changes: 8 additions & 8 deletions synapse/rest/media/thumbnail_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _respond_local_thumbnail(
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
Expand All @@ -134,7 +134,7 @@ async def _respond_local_thumbnail(
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
server_name=None,
)

Expand All @@ -152,7 +152,7 @@ async def _select_or_generate_local_thumbnail(
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
Expand All @@ -168,7 +168,7 @@ async def _select_or_generate_local_thumbnail(
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
url_cache=bool(media_info.url_cache),
thumbnail=info,
)

Expand All @@ -188,7 +188,7 @@ async def _select_or_generate_local_thumbnail(
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
)

if file_path:
Expand All @@ -213,7 +213,7 @@ async def _select_or_generate_remote_thumbnail(
server_name, media_id
)

file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id

for info in thumbnail_infos:
t_w = info.width == desired_width
Expand All @@ -224,7 +224,7 @@ async def _select_or_generate_remote_thumbnail(
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
file_id=file_id,
thumbnail=info,
)

Expand Down Expand Up @@ -280,7 +280,7 @@ async def _respond_remote_thumbnail(
m_type,
thumbnail_infos,
media_id,
media_info["filesystem_id"],
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
Expand Down
Loading
Loading