Skip to content

Commit

Permalink
=htc refactor HttpClient stream setup, closes akka#16510
Browse files Browse the repository at this point in the history
  • Loading branch information
sirthias committed Dec 19, 2014
1 parent 735fdb4 commit 5538a31
Show file tree
Hide file tree
Showing 13 changed files with 687 additions and 127 deletions.
224 changes: 199 additions & 25 deletions akka-http-core/src/main/scala/akka/http/engine/client/HttpClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
package akka.http.engine.client

import java.net.InetSocketAddress
import scala.collection.immutable.Queue
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import akka.stream.stage._
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream.FlattenStrategy
import akka.stream.scaladsl._
import akka.stream.scaladsl.OperationAttributes._
import akka.http.model.{ HttpMethod, HttpRequest, HttpResponse }
import akka.http.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse }
import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
import akka.http.engine.parsing.{ HttpHeaderParser, HttpResponseParser }
import akka.http.engine.parsing.ParserOutput._
import akka.http.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
import akka.http.util._

/**
Expand All @@ -37,39 +38,212 @@ private[http] object HttpClient {
})

val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
val requestMethodByPass = new RequestMethodByPass(remoteAddress)

Flow[HttpRequest]
.map(requestMethodByPass)
/*
Basic Stream Setup
==================
requestIn +----------+
+-----------------------------------------------+--->| Termi- | requestRendering
| | nation +---------------------> |
+-------------------------------------->| Merge | |
| Termination Backchannel | +----------+ | TCP-
| | | level
| | Method | client
| +------------+ | Bypass | flow
responseOut | responsePrep | Response |<---+ |
<------------+----------------| Parsing | |
| Merge |<------------------------------------------ V
+------------+
*/

val requestIn = UndefinedSource[HttpRequest]
val responseOut = UndefinedSink[HttpResponse]

val methodBypassFanout = Broadcast[HttpRequest]
val responseParsingMerge = new ResponseParsingMerge(rootParser)

val terminationFanout = Broadcast[HttpResponse]
val terminationMerge = new TerminationMerge

val requestRendering = Flow[HttpRequest]
.map(RequestRenderingContext(_, remoteAddress))
.section(name("renderer"))(_.transform(() requestRendererFactory.newRenderer))
.flatten(FlattenStrategy.concat)

val transportFlow = Flow[ByteString]
.section(name("errorLogger"))(_.transform(() errorLogger(log, "Outgoing request stream error")))
.via(transport)
.section(name("rootParser"))(_.transform(()
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy(requestMethodByPass)))
.splitWhen(_.isInstanceOf[MessageStart])

val methodBypass = Flow[HttpRequest].map(_.method)

import ParserOutput._
val responsePrep = Flow[List[ResponseOutput]]
.transform(recover { case x: ResponseParsingError x.error :: Nil }) // FIXME after #16565
.mapConcat(identityFunc)
.splitWhen(x x.isInstanceOf[MessageStart] || x == MessageEnd)
.headAndTail
.collect {
case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts)
HttpResponse(statusCode, headers, createEntity(entityParts), protocol)
case (MessageStartError(_, info), _) throw IllegalResponseException(info)
}

import FlowGraphImplicits._

Flow() { implicit b
requestIn ~> methodBypassFanout ~> terminationMerge.requestInput ~> requestRendering ~> transportFlow ~>
responseParsingMerge.dataInput ~> responsePrep ~> terminationFanout ~> responseOut
methodBypassFanout ~> methodBypass ~> responseParsingMerge.methodBypassInput
terminationFanout ~> terminationMerge.terminationBackchannelInput

b.allowCycles()

requestIn -> responseOut
}
}

// FIXME: refactor to a pure-stream design that allows us to get rid of this ad-hoc queue here
class RequestMethodByPass(serverAddress: InetSocketAddress)
extends (HttpRequest RequestRenderingContext) with (() HttpMethod) {
private[this] var requestMethods = Queue.empty[HttpMethod]
def apply(request: HttpRequest) = {
requestMethods = requestMethods.enqueue(request.method)
RequestRenderingContext(request, serverAddress)
// a simple merge stage that simply forwards its first input and ignores its second input
// (the terminationBackchannelInput), but applies a special completion handling
class TerminationMerge extends FlexiMerge[HttpRequest] {
import FlexiMerge._
val requestInput = createInputPort[HttpRequest]()
val terminationBackchannelInput = createInputPort[HttpResponse]()

def createMergeLogic() = new MergeLogic[HttpRequest] {
override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"TerminationMerge must have 2 connected inputs, was $inputCount")
Vector(requestInput, terminationBackchannelInput)
}

override def initialState = State[Any](ReadAny(requestInput, terminationBackchannelInput)) {
case (ctx, _, request: HttpRequest) { ctx.emit(request); SameState }
case _ SameState // simply drop all responses, we are only interested in the completion of the response input
}

override def initialCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `requestInput`) SameState
case (ctx, `terminationBackchannelInput`)
ctx.complete()
SameState
},
onError = defaultCompletionHandling.onError)
}
}

import ParserOutput._

/**
* A FlexiMerge that follows this logic:
* 1. Wait on the methodBypass for the method of the request corresponding to the next response to be received
* 2. Read from the dataInput until exactly one response has been fully received
* 3. Go back to 1.
*/
class ResponseParsingMerge(rootParser: HttpResponseParser) extends FlexiMerge[List[ResponseOutput]] {
import FlexiMerge._
val dataInput = createInputPort[ByteString]()
val methodBypassInput = createInputPort[HttpMethod]()

def createMergeLogic() = new MergeLogic[List[ResponseOutput]] {
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
val parser = rootParser.createShallowCopy()
var methodBypassCompleted = false

override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"ResponseParsingMerge must have 2 connected inputs, was $inputCount")
Vector(dataInput, methodBypassInput)
}

override val initialState: State[HttpMethod] =
State(Read(methodBypassInput)) {
case (ctx, _, method)
parser.setRequestMethodForNextResponse(method)
drainParser(parser.onPush(ByteString.empty), ctx,
onNeedNextMethod = () SameState,
onNeedMoreData = () {
ctx.changeCompletionHandling(responseReadingCompletionHandling)
responseReadingState
})
}

val responseReadingState: State[ByteString] =
State(Read(dataInput)) {
case (ctx, _, bytes)
drainParser(parser.onPush(bytes), ctx,
onNeedNextMethod = () {
if (methodBypassCompleted) {
ctx.complete()
SameState
} else {
ctx.changeCompletionHandling(initialCompletionHandling)
initialState
}
},
onNeedMoreData = () SameState)
}

@tailrec def drainParser(current: ResponseOutput, ctx: MergeLogicContext,
onNeedNextMethod: () State[_], onNeedMoreData: () State[_],
b: ListBuffer[ResponseOutput] = ListBuffer.empty): State[_] = {
def emit(output: List[ResponseOutput]): Unit = if (output.nonEmpty) ctx.emit(output)
current match {
case NeedNextRequestMethod
emit(b.result())
onNeedNextMethod()
case StreamEnd
emit(b.result())
ctx.complete()
SameState
case NeedMoreData
emit(b.result())
onNeedMoreData()
case x drainParser(parser.onPull(), ctx, onNeedNextMethod, onNeedMoreData, b += x)
}
}

override val initialCompletionHandling = CompletionHandling(
onComplete = (ctx, _) { ctx.complete(); SameState },
onError = defaultCompletionHandling.onError)

val responseReadingCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `methodBypassInput`)
methodBypassCompleted = true
SameState
case (ctx, `dataInput`)
if (parser.onUpstreamFinish()) {
ctx.complete()
} else {
// not pretty but because the FlexiMerge doesn't let us emit from here (#16565)
// we need to funnel the error through the error channel
ctx.error(new ResponseParsingError(parser.onPull().asInstanceOf[ErrorOutput]))
}
SameState
},
onError = defaultCompletionHandling.onError)
}
}

private class ResponseParsingError(val error: ErrorOutput) extends RuntimeException

// TODO: remove after #16394 is cleared
def recover[A, B >: A](pf: PartialFunction[Throwable, B]): () PushPullStage[A, B] = {
val stage = new PushPullStage[A, B] {
var recovery: Option[B] = None
def onPush(elem: A, ctx: Context[B]): Directive = ctx.push(elem)
def onPull(ctx: Context[B]): Directive = recovery match {
case None ctx.pull()
case Some(x) { recovery = null; ctx.push(x) }
case null ctx.finish()
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[B]): TerminationDirective =
if (pf isDefinedAt cause) {
recovery = Some(pf(cause))
ctx.absorbTermination()
} else super.onUpstreamFailure(cause, ctx)
}
def apply(): HttpMethod =
if (requestMethods.nonEmpty) {
val method = requestMethods.head
requestMethods = requestMethods.tail
method
} else HttpResponseParser.NoMethod
() stage
}
}
Loading

0 comments on commit 5538a31

Please sign in to comment.