From 8474613953c94ba2769939913b8c18e92897785c Mon Sep 17 00:00:00 2001 From: Tapped Date: Wed, 31 Jul 2019 17:25:40 +0200 Subject: [PATCH] Add gRPC Context propagation It grabs the Context from the same thread as the thread used for gRPC, and then forks and attaches it to the thread used when calling into the gRPC stub implementation (async boundary). Close #7 --- .../scala/server/Fs2ServerCallListener.scala | 10 ++++-- .../src/test/scala/server/ServerSuite.scala | 32 +++++++++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala index 2bbffdc0..95c21a8b 100644 --- a/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala @@ -6,7 +6,7 @@ import cats.effect._ import cats.effect.concurrent.Deferred import cats.implicits._ import fs2.Stream -import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException} +import io.grpc.{Context, Metadata, Status, StatusException, StatusRuntimeException} private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { def source: G[Request] @@ -38,8 +38,14 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { case ExitCase.Error(t) => reportError(t) } + // It's important that we call 'Context.current()' at this point, + // since Context is stored in thread local storage, + // so we have to grab it while we are in the callback thread of gRPC + val initialCtx = Context.current().fork() + val context = Resource.make(F.delay(initialCtx.attach()))(previous => F.delay(initialCtx.detach(previous))) + // Exceptions are reported by closing the call - F.runAsync(F.race(bracketed, isCancelled.get))(_ => IO.unit).unsafeRunSync() + F.runAsync(F.race(context.use(_ => bracketed), isCancelled.get))(_ => IO.unit).unsafeRunSync() } def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])( diff --git a/java-runtime/src/test/scala/server/ServerSuite.scala b/java-runtime/src/test/scala/server/ServerSuite.scala index 64d4fd0b..917413f8 100644 --- a/java-runtime/src/test/scala/server/ServerSuite.scala +++ b/java-runtime/src/test/scala/server/ServerSuite.scala @@ -25,14 +25,18 @@ object ServerSuite extends SimpleTestSuite { val dummy = new DummyServerCall val listener = Fs2UnaryServerCallListener[IO](dummy, options).unsafeRunSync() - listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) + val testKey = Context.key[Int]("test-key") + val testKeyValue = 3123 + Context.current().withValue(testKey, testKeyValue).run(new Runnable { + override def run(): Unit = listener.unsafeUnaryResponse(new Metadata(), _.map(_.length + testKey.get())) + }) listener.onMessage("123") listener.onHalfClose() ec.tick() assertEquals(dummy.messages.size, 1) - assertEquals(dummy.messages(0), 3) + assertEquals(dummy.messages(0), 3 + testKeyValue) assertEquals(dummy.currentStatus.isDefined, true) assertEquals(dummy.currentStatus.get.isOk, true) } @@ -189,9 +193,13 @@ object ServerSuite extends SimpleTestSuite { val dummy = new DummyServerCall val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy).unsafeRunSync() - listener.unsafeStreamResponse( - new Metadata(), - _.map(_.length) ++ Stream.emit(0) ++ Stream.raiseError[IO](new RuntimeException("hello"))) + val testKey = Context.key[Int]("test-key") + val testKeyValue = 3123 + Context.current().withValue(testKey, testKeyValue).run(new Runnable { + override def run(): Unit = listener.unsafeStreamResponse( + new Metadata(), + _.map(_.length) ++ Stream.emit(0) ++ Stream.emit(testKey.get()) ++ Stream.raiseError[IO](new RuntimeException("hello"))) + }) listener.onMessage("a") listener.onMessage("ab") listener.onHalfClose() @@ -200,7 +208,7 @@ object ServerSuite extends SimpleTestSuite { ec.tick() assertEquals(dummy.messages.length, 3) - assertEquals(dummy.messages.toList, List(1, 2, 0)) + assertEquals(dummy.messages.toList, List(1, 2, testKeyValue, 0)) assertEquals(dummy.currentStatus.isDefined, true) assertEquals(dummy.currentStatus.get.isOk, false) } @@ -214,13 +222,17 @@ object ServerSuite extends SimpleTestSuite { implicit val ec: TestContext = TestContext() implicit val cs: ContextShift[IO] = IO.contextShift(ec) - val implementation: Stream[IO, String] => IO[Int] = - _.compile.foldMonoid.map(_.length) + val testKey = Context.key[Int]("test-key") + val testKeyValue = 3123 + val implementation: Stream[IO, String] => IO[Int] = stream => + IO(testKey.get) >>= (value => stream.compile.foldMonoid.map(_.length + value)) val dummy = new DummyServerCall val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, options).unsafeRunSync() - listener.unsafeUnaryResponse(new Metadata(), implementation) + Context.current().withValue(testKey, testKeyValue).run(new Runnable { + override def run(): Unit = listener.unsafeUnaryResponse(new Metadata(), implementation) + }) listener.onMessage("ab") listener.onMessage("abc") listener.onHalfClose() @@ -228,7 +240,7 @@ object ServerSuite extends SimpleTestSuite { ec.tick() assertEquals(dummy.messages.length, 1) - assertEquals(dummy.messages(0), 5) + assertEquals(dummy.messages(0), 5 + testKeyValue) assertEquals(dummy.currentStatus.isDefined, true) assertEquals(dummy.currentStatus.get.isOk, true) }