From 4ffe7a9ec769db4ab98e8954cbca13b0db41c1e9 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 6 May 2022 15:20:55 +0100
Subject: [PATCH 1/5] Ensure no presence is handled when presence is disabled.

Otherwise we can end up with memory leaks.
---
 synapse/handlers/presence.py | 42 +++++++++++++++++++++++-------------
 1 file changed, 27 insertions(+), 15 deletions(-)

diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d078162c2938..268481ec1963 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -659,27 +659,28 @@ def __init__(self, hs: "HomeServer"):
         )
 
         now = self.clock.time_msec()
-        for state in self.user_to_current_state.values():
-            self.wheel_timer.insert(
-                now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
-            )
-            self.wheel_timer.insert(
-                now=now,
-                obj=state.user_id,
-                then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
-            )
-            if self.is_mine_id(state.user_id):
+        if self._presence_enabled:
+            for state in self.user_to_current_state.values():
                 self.wheel_timer.insert(
-                    now=now,
-                    obj=state.user_id,
-                    then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL,
+                    now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
                 )
-            else:
                 self.wheel_timer.insert(
                     now=now,
                     obj=state.user_id,
-                    then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
+                    then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
                 )
+                if self.is_mine_id(state.user_id):
+                    self.wheel_timer.insert(
+                        now=now,
+                        obj=state.user_id,
+                        then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL,
+                    )
+                else:
+                    self.wheel_timer.insert(
+                        now=now,
+                        obj=state.user_id,
+                        then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
+                    )
 
         # Set of users who have presence in the `user_to_current_state` that
         # have not yet been persisted
@@ -804,6 +805,13 @@ async def _update_states(
                 This is currently used to bump the max presence stream ID without changing any
                 user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
         """
+        if not self._presence_enabled:
+            # We shouldn't get here if presence is disabled, but we check anyway
+            # to ensure that we don't a) send out presence federation and b)
+            # don't add things to the wheel timer that will never be handled.
+            logger.warning("Tried to update presence states when presence is disabled")
+            return
+
         now = self.clock.time_msec()
 
         with Measure(self.clock, "presence_update_states"):
@@ -1229,6 +1237,10 @@ async def set_state(
         ):
             raise SynapseError(400, "Invalid presence state")
 
+        # If presence is disabled, no-op
+        if not self.hs.config.server.use_presence:
+            return
+
         user_id = target_user.to_string()
 
         prev_state = await self.current_state_for_user(user_id)

From bae966518e2ba86f3c8d2525f036d1a5ff261ec0 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 6 May 2022 15:21:44 +0100
Subject: [PATCH 2/5] Use a set for `WheelTimer` to better handle duplicates.

---
 synapse/util/wheel_timer.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index e108adc4604f..ab5d5a27de77 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Generic, List, TypeVar
+from typing import Generic, Hashable, List, Set, TypeVar
 
-T = TypeVar("T")
+T = TypeVar("T", bound=Hashable)
 
 
 class _Entry(Generic[T]):
@@ -21,7 +21,10 @@ class _Entry(Generic[T]):
 
     def __init__(self, end_key: int) -> None:
         self.end_key: int = end_key
-        self.queue: List[T] = []
+
+        # We use a set here as otherwise we can end up with a lot of duplicate
+        # entries.
+        self.queue: Set[T] = set()
 
 
 class WheelTimer(Generic[T]):
@@ -55,7 +58,7 @@ def insert(self, now: int, obj: T, then: int) -> None:
 
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
-                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                self.entries[max(min_key, then_key) - min_key].queue.add(obj)
                 return
 
         next_key = int(now / self.bucket_size) + 1
@@ -71,7 +74,7 @@ def insert(self, now: int, obj: T, then: int) -> None:
         # to insert. This ensures there are no gaps.
         self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
 
-        self.entries[-1].queue.append(obj)
+        self.entries[-1].queue.add(obj)
 
     def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
@@ -84,7 +87,7 @@ def fetch(self, now: int) -> List[T]:
         """
         now_key = int(now / self.bucket_size)
 
-        ret = []
+        ret: List[T] = []
         while self.entries and self.entries[0].end_key <= now_key:
             ret.extend(self.entries.pop(0).queue)
 

From 769e141b1367ea9de09b24846032b18e40be7df4 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 6 May 2022 15:46:56 +0100
Subject: [PATCH 3/5] Warn if inserting to a wheel timer that hasn't been read
 recently

---
 synapse/util/wheel_timer.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index ab5d5a27de77..f36a2c9147c6 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -11,8 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import Generic, Hashable, List, Set, TypeVar
 
+logger = logging.getLogger(__name__)
+
 T = TypeVar("T", bound=Hashable)
 
 
@@ -51,17 +54,27 @@ def insert(self, now: int, obj: T, then: int) -> None:
             then: When to return the object strictly after.
         """
         then_key = int(then / self.bucket_size) + 1
+        now_key = int(now / self.bucket_size)
 
         if self.entries:
             min_key = self.entries[0].end_key
             max_key = self.entries[-1].end_key
 
+            if min_key < now_key - 10:
+                # If we have ten buckets that are due and still nothing has
+                # called `fetch()` then we likely have a bug that is causing a
+                # memory leak.
+                logger.warning(
+                    "Inserting into a wheel timer that hasn't been read from recently. Item: %s",
+                    obj,
+                )
+
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
                 self.entries[max(min_key, then_key) - min_key].queue.add(obj)
                 return
 
-        next_key = int(now / self.bucket_size) + 1
+        next_key = now_key + 1
         if self.entries:
             last_key = self.entries[-1].end_key
         else:

From 22413e8ac4a7215a31e0600eaad09842c1fee3fe Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 6 May 2022 15:47:47 +0100
Subject: [PATCH 4/5] Newsfile

---
 changelog.d/12656.misc | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 changelog.d/12656.misc

diff --git a/changelog.d/12656.misc b/changelog.d/12656.misc
new file mode 100644
index 000000000000..8a8743e614f3
--- /dev/null
+++ b/changelog.d/12656.misc
@@ -0,0 +1 @@
+Prevent memory leak from reoccurring when presence is disabled.

From 39fc179605837e8cbeb4e62ec2efbb8438083287 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 6 May 2022 17:15:34 +0100
Subject: [PATCH 5/5] Use attrs. Rename queue to elements

---
 synapse/util/wheel_timer.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index f36a2c9147c6..177e198e7e75 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -14,20 +14,17 @@
 import logging
 from typing import Generic, Hashable, List, Set, TypeVar
 
+import attr
+
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T", bound=Hashable)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class _Entry(Generic[T]):
-    __slots__ = ["end_key", "queue"]
-
-    def __init__(self, end_key: int) -> None:
-        self.end_key: int = end_key
-
-        # We use a set here as otherwise we can end up with a lot of duplicate
-        # entries.
-        self.queue: Set[T] = set()
+    end_key: int
+    elements: Set[T] = attr.Factory(set)
 
 
 class WheelTimer(Generic[T]):
@@ -71,7 +68,7 @@ def insert(self, now: int, obj: T, then: int) -> None:
 
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
-                self.entries[max(min_key, then_key) - min_key].queue.add(obj)
+                self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return
 
         next_key = now_key + 1
@@ -87,7 +84,7 @@ def insert(self, now: int, obj: T, then: int) -> None:
         # to insert. This ensures there are no gaps.
         self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
 
-        self.entries[-1].queue.add(obj)
+        self.entries[-1].elements.add(obj)
 
     def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
@@ -102,9 +99,9 @@ def fetch(self, now: int) -> List[T]:
 
         ret: List[T] = []
         while self.entries and self.entries[0].end_key <= now_key:
-            ret.extend(self.entries.pop(0).queue)
+            ret.extend(self.entries.pop(0).elements)
 
         return ret
 
     def __len__(self) -> int:
-        return sum(len(entry.queue) for entry in self.entries)
+        return sum(len(entry.elements) for entry in self.entries)