Skip to content

Commit

Permalink
Add gRPC Context propagation
Browse files Browse the repository at this point in the history
It grabs the Context from the same thread as the thread used for gRPC,
and then forks and attaches it to the thread used when calling into
the gRPC stub implementation (async boundary).

Close typelevel#7
  • Loading branch information
Tapped committed Aug 1, 2019
1 parent 9c526a2 commit c966c66
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions java-runtime/src/main/scala/server/Fs2ServerCallListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import cats.effect._
import cats.effect.concurrent.Deferred
import cats.implicits._
import fs2.Stream
import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}
import io.grpc.{Context, Metadata, Status, StatusException, StatusRuntimeException}

private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
def source: G[Request]
Expand Down Expand Up @@ -38,8 +38,14 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
case ExitCase.Error(t) => reportError(t)
}

// It's important that we call 'Context.current()' at this point,
// since Context is stored in thread local storage,
// so we have to grab it while we are in the callback thread of gRPC
val initialCtx = Context.current().fork()
val context = Resource.make(F.delay(initialCtx.attach()))(previous => F.delay(initialCtx.detach(previous)))

// Exceptions are reported by closing the call
F.runAsync(F.race(bracketed, isCancelled.get))(_ => IO.unit).unsafeRunSync()
F.runAsync(F.race(context.use(_ => bracketed), isCancelled.get))(_ => IO.unit).unsafeRunSync()
}

def unsafeUnaryResponse(headers: Metadata, implementation: G[Request] => F[Response])(
Expand Down
22 changes: 22 additions & 0 deletions java-runtime/src/test/scala/server/ServerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ object ServerSuite extends SimpleTestSuite {
assertEquals(dummy.currentStatus.get.isOk, true)
}

test("single message to unaryToUnary with context propagation") {
implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyServerCall
val listener = Fs2UnaryServerCallListener[IO](dummy, ServerCallOptions.default).unsafeRunSync()

val testKey = Context.key[Int]("test-key")
Context.current().withValue(testKey, 3123).run(() =>
listener.unsafeUnaryResponse(new Metadata(), _.map(_ => testKey.get()))
)
listener.onMessage("123")
listener.onHalfClose()

ec.tick()

assertEquals(dummy.messages.size, 1)
assertEquals(dummy.messages(0), 3123)
assertEquals(dummy.currentStatus.isDefined, true)
assertEquals(dummy.currentStatus.get.isOk, true)
}

test("cancellation for unaryToUnary") {

implicit val ec: TestContext = TestContext()
Expand Down

0 comments on commit c966c66

Please sign in to comment.