diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 46a1c0746079..a6c6676f8d40 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ from synapse.api.errors import StoreError from synapse.events import FrozenEvent from synapse.events.utils import prune_event +from synapse.util import unwrap_deferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -28,7 +29,6 @@ from collections import namedtuple, OrderedDict import functools -import itertools import simplejson as json import sys import time @@ -870,35 +870,43 @@ def func(txn): @defer.inlineCallbacks def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, desc="_get_events"): - N = 50 # Only fetch 100 events at a time. + get_prev_content=False, allow_rejected=False, txn=None): + if not event_ids: + defer.returnValue([]) - ds = [ - self._fetch_events( - event_ids[i*N:(i+1)*N], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) - for i in range(1 + len(event_ids) / N) - ] + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) - res = yield defer.gatherResults(ds, consumeErrors=True) + missing_events = [e for e in event_ids if e not in event_map] - defer.returnValue( - list(itertools.chain(*res)) + missing_events = yield self._fetch_events( + txn, + missing_events, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, ) + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False): - N = 50 # Only fetch 100 events at a time. - return list(itertools.chain(*[ - self._fetch_events_txn( - txn, event_ids[i*N:(i+1)*N], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) - for i in range(1 + len(event_ids) / N) - ])) + get_prev_content=False, allow_rejected=False): + return unwrap_deferred(self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + txn=txn, + )) def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): @@ -909,68 +917,24 @@ def _invalidate_get_event_cache(self, event_id): def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - try: - ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) - - if allow_rejected or not ret.rejected_reason: - return ret - else: - return None - except KeyError: - pass - finally: - start_time = update_counter("event_cache", start_time) - - sql = ( - "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " - "FROM event_json as e " - "LEFT JOIN rejections as rej USING (event_id) " - "LEFT JOIN redactions as r ON e.event_id = r.redacts " - "WHERE e.event_id = ? " - "LIMIT 1 " - ) - - txn.execute(sql, (event_id,)) - - res = txn.fetchone() - - if not res: - return None - - internal_metadata, js, redacted, rejected_reason = res - - start_time = update_counter("select_event", start_time) - - result = self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, + events = self._get_events_txn( + txn, [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, - rejected_reason=rejected_reason, + allow_rejected=allow_rejected, ) - self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result) - if allow_rejected or not rejected_reason: - return result - else: - return None - - def _fetch_events_txn(self, txn, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not events: - return [] + return events[0] if events else None + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): event_map = {} for event_id in events: try: - ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) if allow_rejected or not ret.rejected_reason: event_map[event_id] = ret @@ -979,200 +943,81 @@ def _fetch_events_txn(self, txn, events, check_redacted=True, except KeyError: pass - missing_events = [ - e for e in events - if e not in event_map - ] - - if missing_events: - sql = ( - "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"]*len(missing_events)),) - - txn.execute(sql, missing_events) - rows = txn.fetchall() - - res = [ - self._get_event_from_row_txn( - txn, row[0], row[1], row[2], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=row[3], - ) - for row in rows - ] - - event_map.update({ - e.event_id: e - for e in res if e - }) - - for e in res: - self._get_event_cache.prefill( - e.event_id, check_redacted, get_prev_content, e - ) - - return [ - event_map[e_id] for e_id in events - if e_id in event_map and event_map[e_id] - ] + return event_map @defer.inlineCallbacks - def _fetch_events(self, events, check_redacted=True, + def _fetch_events(self, txn, events, check_redacted=True, get_prev_content=False, allow_rejected=False): if not events: - defer.returnValue([]) - - event_map = {} + defer.returnValue({}) - for event_id in events: - try: - ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) - - if allow_rejected or not ret.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - except KeyError: - pass + rows = [] + N = 2 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break - missing_events = [ - e for e in events - if e not in event_map - ] - - if missing_events: sql = ( "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " " FROM event_json as e" " LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN redactions as r ON e.event_id = r.redacts" " WHERE e.event_id IN (%s)" - ) % (",".join(["?"]*len(missing_events)),) - - rows = yield self._execute( - "_fetch_events", - None, - sql, - *missing_events - ) - - res_ds = [ - self._get_event_from_row( - row[0], row[1], row[2], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=row[3], - ) - for row in rows - ] + ) % (",".join(["?"]*len(evs)),) - res = yield defer.gatherResults(res_ds, consumeErrors=True) + if txn: + txn.execute(sql, evs) + rows.extend(txn.fetchall()) + else: + res = yield self._execute("_fetch_events", None, sql, *evs) + rows.extend(res) - event_map.update({ - e.event_id: e - for e in res if e - }) + res = [] + for row in rows: + e = yield self._get_event_from_row( + txn, + row[0], row[1], row[2], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row[3], + ) + res.append(e) - for e in res: - self._get_event_cache.prefill( - e.event_id, check_redacted, get_prev_content, e - ) + for e in res: + self._get_event_cache.prefill( + e.event_id, check_redacted, get_prev_content, e + ) - defer.returnValue([ - event_map[e_id] for e_id in events - if e_id in event_map and event_map[e_id] - ]) + defer.returnValue({ + e.event_id: e + for e in res if e + }) @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, + def _get_event_from_row(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False, rejected_reason=None): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - d = json.loads(js) - start_time = update_counter("decode_json", start_time) - internal_metadata = json.loads(internal_metadata) - start_time = update_counter("decode_internal", start_time) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - desc="_get_event_from_row", - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - ) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - start_time = update_counter("build_frozen_event", start_time) - - if check_redacted and redacted: - ev = prune_event(ev) - - redaction_id = yield self._simple_select_one_onecol( - desc="_get_event_from_row", - table="redactions", - keyvalues={"redacts": ev.event_id}, - retcol="event_id", - ) - - ev.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - because = yield self.get_event_txn( - redaction_id, - check_redacted=False - ) - - if because: - ev.unsigned["redacted_because"] = because - start_time = update_counter("redact_event", start_time) - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = yield self.get_event( - ev.unsigned["replaces_state"], - get_prev_content=False, - ) - if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] - start_time = update_counter("get_prev_content", start_time) - - defer.returnValue(ev) - - def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - d = json.loads(js) - start_time = update_counter("decode_json", start_time) + def select(txn, *args, **kwargs): + if txn: + return self._simple_select_one_onecol_txn(txn, *args, **kwargs) + else: + return self._simple_select_one_onecol( + *args, + desc="_get_event_from_row", **kwargs + ) - internal_metadata = json.loads(internal_metadata) - start_time = update_counter("decode_internal", start_time) + def get_event(txn, *args, **kwargs): + if txn: + return self._get_event_txn(txn, *args, **kwargs) + else: + return self.get_event(*args, **kwargs) if rejected_reason: - rejected_reason = self._simple_select_one_onecol_txn( + rejected_reason = yield select( txn, table="rejections", keyvalues={"event_id": rejected_reason}, @@ -1184,12 +1029,11 @@ def update_counter(desc, last_time): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) - start_time = update_counter("build_frozen_event", start_time) if check_redacted and redacted: ev = prune_event(ev) - redaction_id = self._simple_select_one_onecol_txn( + redaction_id = yield select( txn, table="redactions", keyvalues={"redacts": ev.event_id}, @@ -1199,7 +1043,7 @@ def update_counter(desc, last_time): ev.unsigned["redacted_by"] = redaction_id # Get the redaction event. - because = self._get_event_txn( + because = yield get_event( txn, redaction_id, check_redacted=False @@ -1207,19 +1051,17 @@ def update_counter(desc, last_time): if because: ev.unsigned["redacted_because"] = because - start_time = update_counter("redact_event", start_time) if get_prev_content and "replaces_state" in ev.unsigned: - prev = self._get_event_txn( + prev = yield get_event( txn, ev.unsigned["replaces_state"], get_prev_content=False, ) if prev: ev.unsigned["prev_content"] = prev.get_dict()["content"] - start_time = update_counter("get_prev_content", start_time) - return ev + defer.returnValue(ev) def _parse_events(self, rows): return self.runInteraction( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 483b316e9ff1..26fd3b3e6716 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -85,7 +85,7 @@ def f(txn): @defer.inlineCallbacks def c(vals): - vals[:] = yield self._fetch_events(vals, get_prev_content=False) + vals[:] = yield self._get_events(vals, get_prev_content=False) yield defer.gatherResults( [ diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index c1a16b639af4..b9afb3364df6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -29,6 +29,34 @@ def unwrapFirstError(failure): return failure.value.subFailure +def unwrap_deferred(d): + """Given a deferred that we know has completed, return its value or raise + the failure as an exception + """ + if not d.called: + raise RuntimeError("deferred has not finished") + + res = [] + + def f(r): + res.append(r) + return r + d.addCallback(f) + + if res: + return res[0] + + def f(r): + res.append(r) + return r + d.addErrback(f) + + if res: + res[0].raiseException() + else: + raise RuntimeError("deferred did not call callbacks") + + class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests.