Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix handling of redactions of redactions #5788

Merged
merged 5 commits into from
Aug 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5788.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correctly handle redactions of redactions.
213 changes: 112 additions & 101 deletions synapse/storage/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
Expand Down Expand Up @@ -342,13 +337,12 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Note that _enqueue_events is also responsible for turning db rows
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
# _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
missing_events = yield self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)

Expand Down Expand Up @@ -421,28 +415,28 @@ def _fetch_event_list(self, conn, event_list):
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.

The deferreds are callbacked with a dictionary mapping from event id
to event row. Note that it may well contain additional events that
were not part of this request.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this not be in a Returns?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well no, because it doesn't return it. It sends it into the Deferreds which get passed in.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, duh, sorry, yes.

"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = list(zip(*event_list))[0]
event_ids = [item for sublist in event_id_lists for item in sublist]
events_to_fetch = set(
event_id for events, _ in event_list for event_id in events
)

row_dict = self._new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)

# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([res[i] for i in ids if i in res])
except Exception:
logger.exception("Failed to callback")
def fire():
for _, d in event_list:
d.callback(row_dict)

with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
self.hs.get_reactor().callFromThread(fire)
except Exception as e:
logger.exception("do_fetch")

Expand All @@ -457,13 +451,98 @@ def fire(evs, exc):
self.hs.get_reactor().callFromThread(fire, event_list, e)

@defer.inlineCallbacks
def _enqueue_events(self, events, allow_rejected=False):
def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.

Returned events will be added to the cache for future lookups.

Args:
event_ids (Iterable[str]): The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events

Returns:
Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids

while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)

# we need to recursively fetch any redactions of those events
redaction_ids = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
redaction_ids.update(row["redactions"])

events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)

# build a map from event_id to EventBase
event_map = {}
for event_id, row in fetched_events.items():
if not row:
continue
assert row["event_id"] == event_id

rejected_reason = row["rejected_reason"]

if not allow_rejected and rejected_reason:
continue

d = json.loads(row["json"])
internal_metadata = json.loads(row["internal_metadata"])

format_version = row["format_version"]
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1

original_ev = event_type_from_format_version(format_version)(
event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)

event_map[event_id] = original_ev

# finally, we can decide whether each one nededs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id]["redactions"]
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)

cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)

self._get_event_cache.prefill((event_id,), cache_entry)
result_map[event_id] = cache_entry

return result_map

@defer.inlineCallbacks
def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.

Args:
events (Iterable[str]): events to be fetched.

Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
May contain events that weren't requested.
"""
if not events:
return {}

events_d = defer.Deferred()
with self._event_fetch_lock:
Expand All @@ -482,32 +561,12 @@ def _enqueue_events(self, events, allow_rejected=False):
"fetch_events", self.runWithConnection, self._do_fetch
)

logger.debug("Loading %d events", len(events))
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))

if not allow_rejected:
rows[:] = [r for r in rows if r["rejected_reason"] is None]

res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._get_event_from_row,
row["internal_metadata"],
row["json"],
row["redactions"],
rejected_reason=row["rejected_reason"],
format_version=row["format_version"],
)
for row in rows
],
consumeErrors=True,
)
)
row_map = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))

return {e.event.event_id: e for e in res if e}
return row_map

def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database
Expand Down Expand Up @@ -580,57 +639,16 @@ def _fetch_event_rows(self, txn, event_ids):

return event_dict

@defer.inlineCallbacks
def _get_event_from_row(
self, internal_metadata, js, redactions, format_version, rejected_reason=None
):
"""Parse an event row which has been read from the database

Args:
internal_metadata (str): json-encoded internal_metadata column
js (str): json-encoded event body from event_json
redactions (list[str]): a list of the events which claim to have redacted
this event, from the redactions table
format_version: (str): the 'format_version' column
rejected_reason (str|None): the reason this event was rejected, if any

Returns:
_EventCacheEntry
"""
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)

if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1

original_ev = event_type_from_format_version(format_version)(
event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)

redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)

cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)

self._get_event_cache.prefill((original_ev.event_id,), cache_entry)

return cache_entry

@defer.inlineCallbacks
def _maybe_redact_event_row(self, original_ev, redactions):
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.

Args:
original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events
event_map (dict[str, EventBase]): other events which have been fetched, in
which we can look up the redaaction events. Map from event id to event.

Returns:
Deferred[EventBase|None]: if the event should be redacted, a pruned
Expand All @@ -640,15 +658,9 @@ def _maybe_redact_event_row(self, original_ev, redactions):
# we choose to ignore redactions of m.room.create events.
return None

if original_ev.type == "m.room.redaction":
# ... and redaction events
return None

redaction_map = yield self._get_events_from_cache_or_db(redactions)

for redaction_id in redactions:
redaction_entry = redaction_map.get(redaction_id)
if not redaction_entry:
redaction_event = event_map.get(redaction_id)
if not redaction_event or redaction_event.rejected_reason:
# we don't have the redaction event, or the redaction event was not
# authorized.
logger.debug(
Expand All @@ -658,7 +670,6 @@ def _maybe_redact_event_row(self, original_ev, redactions):
)
continue

redaction_event = redaction_entry.event
if redaction_event.room_id != original_ev.room_id:
logger.debug(
"%s was redacted by %s but redaction was in a different room!",
Expand Down
70 changes: 70 additions & 0 deletions tests/storage/test_redaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from mock import Mock

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
Expand Down Expand Up @@ -216,3 +218,71 @@ def test_redact_join(self):
},
event.unsigned["redacted_because"],
)

def test_circular_redaction(self):
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"

class EventIdManglingBuilder:
def __init__(self, base_builder, event_id):
self._base_builder = base_builder
self._event_id = event_id

@defer.inlineCallbacks
def build(self, prev_event_ids):
built_event = yield self._base_builder.build(prev_event_ids)
built_event.event_id = self._event_id
built_event._event_dict["event_id"] = self._event_id
return built_event

@property
def room_id(self):
return self._base_builder.room_id

event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id2,
},
),
redaction_event_id1,
)
)
)

self.get_success(self.store.persist_event(event_1, context_1))

event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id1,
},
),
redaction_event_id2,
)
)
)
self.get_success(self.store.persist_event(event_2, context_2))

# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))

# it should have been redacted
self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
self.assertEqual(
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)