Skip to content

Commit

Permalink
Add gRPC Context propagation
Browse files Browse the repository at this point in the history
It grabs the Context from the same thread as the thread used for gRPC,
and then forks and attaches it to the thread used when calling into
the gRPC stub implementation (async boundary).

Close typelevel#7
  • Loading branch information
Tapped committed Sep 17, 2019
1 parent 2281ce9 commit 4953a3e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
10 changes: 8 additions & 2 deletions java-runtime/src/main/scala/server/Fs2ServerCallListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import cats.effect._
import cats.effect.concurrent.Deferred
import cats.implicits._
import fs2.Stream
import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}
import io.grpc.{Context, Metadata, Status, StatusException, StatusRuntimeException}

private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
def source: G[Request]
Expand Down Expand Up @@ -38,8 +38,14 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
case ExitCase.Error(t) => reportError(t)
}

// It's important that we call 'Context.current()' at this point,
// since Context is stored in thread local storage,
// so we have to grab it while we are in the callback thread of gRPC
val initialCtx = Context.current().fork()
val context = Resource.make(F.delay(initialCtx.attach()))(previous => F.delay(initialCtx.detach(previous)))

// Exceptions are reported by closing the call
F.runAsync(F.race(bracketed, isCancelled.get))(_ => IO.unit).unsafeRunSync()
F.runAsync(F.race(context.use(_ => bracketed), isCancelled.get))(_ => IO.unit).unsafeRunSync()
}

def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(
Expand Down
34 changes: 23 additions & 11 deletions java-runtime/src/test/scala/server/ServerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ object ServerSuite extends SimpleTestSuite {
val dummy = new DummyServerCall
val listener = Fs2UnaryServerCallListener[IO](dummy, options).unsafeRunSync()

listener.unsafeUnaryResponse(new Metadata(), _.map(_.length))
val testKey = Context.key[Int]("test-key")
val testKeyValue = 3123
Context.current().withValue(testKey, testKeyValue).run(new Runnable {
override def run(): Unit = listener.unsafeUnaryResponse(new Metadata(), _.map(_.length + testKey.get()))
})
listener.onMessage("123")
listener.onHalfClose()

ec.tick()

assertEquals(dummy.messages.size, 1)
assertEquals(dummy.messages(0), 3)
assertEquals(dummy.messages(0), 3 + testKeyValue)
assertEquals(dummy.currentStatus.isDefined, true)
assertEquals(dummy.currentStatus.get.isOk, true)
}
Expand Down Expand Up @@ -189,18 +193,22 @@ object ServerSuite extends SimpleTestSuite {
val dummy = new DummyServerCall
val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy).unsafeRunSync()

listener.unsafeStreamResponse(
new Metadata(),
_.map(_.length) ++ Stream.emit(0) ++ Stream.raiseError[IO](new RuntimeException("hello")))
val testKey = Context.key[Int]("test-key")
val testKeyValue = 3123
Context.current().withValue(testKey, testKeyValue).run(new Runnable {
override def run(): Unit = listener.unsafeStreamResponse(
new Metadata(),
_.map(_.length) ++ Stream.emit(0) ++ Stream.eval(IO(testKey.get())) ++ Stream.raiseError[IO](new RuntimeException("hello")))
})
listener.onMessage("a")
listener.onMessage("ab")
listener.onHalfClose()
listener.onMessage("abc")

ec.tick()

assertEquals(dummy.messages.length, 3)
assertEquals(dummy.messages.toList, List(1, 2, 0))
assertEquals(dummy.messages.length, 4)
assertEquals(dummy.messages.toList, List(1, 2, 0, testKeyValue))
assertEquals(dummy.currentStatus.isDefined, true)
assertEquals(dummy.currentStatus.get.isOk, false)
}
Expand All @@ -214,21 +222,25 @@ object ServerSuite extends SimpleTestSuite {
implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val implementation: Stream[IO, String] => IO[Int] =
_.compile.foldMonoid.map(_.length)
val testKey = Context.key[Int]("test-key")
val testKeyValue = 3123
val implementation: Stream[IO, String] => IO[Int] = stream =>
IO(testKey.get) >>= (value => stream.compile.foldMonoid.map(_.length + value))

val dummy = new DummyServerCall
val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, options).unsafeRunSync()

listener.unsafeUnaryResponse(new Metadata(), implementation)
Context.current().withValue(testKey, testKeyValue).run(new Runnable {
override def run(): Unit = listener.unsafeUnaryResponse(new Metadata(), implementation)
})
listener.onMessage("ab")
listener.onMessage("abc")
listener.onHalfClose()

ec.tick()

assertEquals(dummy.messages.length, 1)
assertEquals(dummy.messages(0), 5)
assertEquals(dummy.messages(0), 5 + testKeyValue)
assertEquals(dummy.currentStatus.isDefined, true)
assertEquals(dummy.currentStatus.get.isOk, true)
}
Expand Down

0 comments on commit 4953a3e

Please sign in to comment.