diff --git a/synapse/logging/context.py b/synapse/logging/context.py index b456c31f7071..60af53ae093c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -611,10 +611,20 @@ def make_deferred_yieldable(deferred): # immediately. We may as well optimise out the logcontext faffery. return deferred - # ok, we can't be sure that a yield won't block, so let's reset the + if LoggingContext.current_context() is LoggingContext.sentinel: + # We're already in the sentinel context, so there's nothing to do. If we + # did attempt to "restore" the log context then we could easily clobber + # another log context that had been saved during the generation of this + # deferred. + # This makes it safe to call `make_deferred_yieldable` on deferreds that + # already have set the log contexts. + return deferred + + # Ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) - deferred.addBoth(_set_context_cb, prev_context) + if LoggingContext.current_context() is not LoggingContext.sentinel: + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) return deferred diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7b0..bdb4fb7eb05b 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -174,6 +174,29 @@ def test_make_deferred_yieldable_on_non_deferred(self): self.assertEqual(r, "bum") self._check_test_key("one") + @defer.inlineCallbacks + def test_make_deferred_yieldable_reentrant(self): + """Test that `make_deferred_yieldable` does the right thing on deferreds + that already follow log context rules, i.e. wrapping a deferred with + `make_deferred_yieldable` multiple times. + """ + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(_chained_deferred_function()) + d2 = make_deferred_yieldable(d1) + + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d2 + + # now it should be restored + self._check_test_key("one") + def test_nested_logging_context(self): with LoggingContext(request="foo"): nested_context = nested_logging_context(suffix="bar")