Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Batch up storing state groups when creating new room #14918

Merged
merged 15 commits into from
Feb 24, 2023
Merged
1 change: 1 addition & 0 deletions changelog.d/14918.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Batch up storing state groups when creating a new room.
49 changes: 49 additions & 0 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter

Expand Down Expand Up @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None

@classmethod
async def batch_persist_unpersisted_contexts(
cls,
events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
room_id: str,
last_known_state_group: int,
datastore: "StateGroupDataStore",
) -> List[Tuple[EventBase, EventContext]]:
"""
Takes a list of events and their associated unpersisted contexts and persists
the unpersisted contexts, returning a list of events and persisted contexts.
Note that all the events must be in a linear chain (ie a <- b <- c).

Args:
events_and_context: A list of events and their unpersisted contexts
room_id: the room_id for the events
last_known_state_group: the last persisted state group
datastore: a state datastore
"""
amended_events_and_context = await datastore.store_state_deltas_for_batched(
events_and_context, room_id, last_known_state_group
)

events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context:
if event.is_state():
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.state_group_before_event,
delta_ids=unpersisted_context.state_delta_due_to_event,
)
else:
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
)
events_and_persisted_context.append((event, context))
return events_and_persisted_context

async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
Expand Down
16 changes: 7 additions & 9 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ async def create_event(
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
Expand Down Expand Up @@ -721,8 +721,6 @@ async def create_event(
current_state_group=current_state_group,
)

context = await unpersisted_context.persist(event)

# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
Expand All @@ -739,7 +737,7 @@ async def create_event(
assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender))
else:
prev_state_ids = await context.get_prev_state_ids(
prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
Expand All @@ -764,8 +762,7 @@ async def create_event(
)

self.validator.validate_new(event, self.config)

return event, context
return event, unpersisted_context

async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
Expand Down Expand Up @@ -1005,7 +1002,7 @@ async def create_and_send_nonmember_event(
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.create_event(
event, unpersisted_context = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
Expand All @@ -1016,6 +1013,7 @@ async def create_and_send_nonmember_event(
historical=historical,
depth=depth,
)
context = await unpersisted_context.persist(event)

assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
Expand Down Expand Up @@ -1190,7 +1188,6 @@ async def create_new_client_event(
if for_batch:
assert prev_event_ids is not None
assert state_map is not None
assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
Expand Down Expand Up @@ -2042,7 +2039,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool:
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.create_event(
event, unpersisted_context = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
Expand All @@ -2051,6 +2048,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool:
"sender": user_id,
},
)
context = await unpersisted_context.persist(event)

event.internal_metadata.proactively_send = False

Expand Down
37 changes: 22 additions & 15 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
Expand Down Expand Up @@ -211,7 +212,7 @@ async def upgrade_room(
# the required power level to send the tombstone event.
(
tombstone_event,
tombstone_context,
tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
Expand All @@ -225,6 +226,9 @@ async def upgrade_room(
},
},
)
tombstone_context = await tombstone_unpersisted_context.persist(
tombstone_event
)
validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event
Expand Down Expand Up @@ -1091,7 +1095,7 @@ async def create_event(
content: JsonDict,
for_batch: bool,
**kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
"""
Creates an event and associated event context.
Args:
Expand All @@ -1110,20 +1114,23 @@ async def create_event(

event_dict = create_event_dict(etype, content, **kwargs)

new_event, new_context = await self.event_creation_handler.create_event(
(
new_event,
new_unpersisted_context,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
depth=depth,
state_map=state_map,
for_batch=for_batch,
current_state_group=current_state_group,
)

depth += 1
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id

return new_event, new_context
return new_event, new_unpersisted_context

try:
config = self._presets_dict[preset_config]
Expand All @@ -1133,10 +1140,10 @@ async def create_event(
)

creation_content.update({"creator": creator_id})
creation_event, creation_context = await create_event(
creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)

creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
Expand Down Expand Up @@ -1180,7 +1187,6 @@ async def create_event(
power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True
)
current_state_group = power_context._state_group
events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
Expand Down Expand Up @@ -1229,14 +1235,12 @@ async def create_event(
power_level_content,
True,
)
current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context))

if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))

if (EventTypes.JoinRules, "") not in initial_state:
Expand All @@ -1245,7 +1249,6 @@ async def create_event(
{"join_rule": config["join_rules"]},
True,
)
current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))

if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
Expand All @@ -1254,7 +1257,6 @@ async def create_event(
{"history_visibility": config["history_visibility"]},
True,
)
current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))

if config["guest_can_join"]:
Expand All @@ -1264,14 +1266,12 @@ async def create_event(
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))

for (etype, state_key), content in initial_state.items():
event, context = await create_event(
etype, content, True, state_key=state_key
)
current_state_group = context._state_group
events_to_send.append((event, context))

if config["encrypted"]:
Expand All @@ -1283,9 +1283,16 @@ async def create_event(
)
events_to_send.append((encryption_event, encryption_context))

datastore = self.hs.get_datastores().state
events_and_context = (
await UnpersistedEventContext.batch_persist_unpersisted_contexts(
events_to_send, room_id, current_state_group, datastore
)
)

last_event = await self.event_creation_handler.handle_new_client_event(
creator,
events_to_send,
events_and_context,
ignore_shadow_ban=True,
ratelimit=False,
)
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/room_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async def persist_historical_events(
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True

event, context = await self.event_creation_handler.create_event(
event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
Expand All @@ -345,7 +345,7 @@ async def persist_historical_events(
historical=True,
depth=inherited_depth,
)

context = await unpersisted_context.persist(event)
assert context._state_group

# Normally this is done when persisting the event but we have to
Expand Down
13 changes: 10 additions & 3 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ async def _local_membership_update(
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.event_creation_handler.create_event(
(
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
Expand All @@ -435,7 +438,7 @@ async def _local_membership_update(
outlier=outlier,
historical=historical,
)

context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
Expand Down Expand Up @@ -1944,14 +1947,18 @@ async def _generate_local_out_of_band_leave(
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.event_creation_handler.create_event(
(
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
outlier=True,
)
context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True

result_event = (
Expand Down
Loading