Skip to content

Commit

Permalink
H3: HttpBody implementation on the receive side
Browse files Browse the repository at this point in the history
  • Loading branch information
stammw committed Apr 25, 2020
1 parent 625772b commit 04482ba
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 74 deletions.
1 change: 1 addition & 0 deletions interop/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ anyhow = "1.0.22"
bytes = "0.5.2"
futures = "0.3.1"
http = "0.2"
http-body = "0.3"
hyper = "0.13"
hyper-rustls = "0.20"
lazy_static = "1"
Expand Down
17 changes: 7 additions & 10 deletions interop/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use std::{

use anyhow::{anyhow, bail, Context as _, Result};
use bytes::Bytes;
use futures::{ready, AsyncReadExt, Future, StreamExt, TryFutureExt};
use futures::{ready, AsyncReadExt, Future, StreamExt, TryFutureExt, future};
use http::{Response, StatusCode};
use hyper::service::{make_service_fn, service_fn};
use http_body::Body as _;
use hyper::{body::HttpBody, service::{make_service_fn, service_fn}};
use structopt::{self, StructOpt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
Expand Down Expand Up @@ -137,17 +138,13 @@ async fn h3_handle_connection(connecting: quinn::Connecting) -> Result<()> {
}

async fn h3_handle_request(recv_request: RecvRequest) -> Result<()> {
let (request, mut recv_body, sender) = recv_request.await?;
let (mut request, sender) = recv_request.await?;
println!("received request: {:?}", request);

let mut body = Vec::with_capacity(1024);
recv_body
.read_to_end(&mut body)
.await
.map_err(|e| anyhow!("failed to send response headers: {:?}", e))?;

let body = request.body_mut().read_to_end().await?;
println!("received body: {}", String::from_utf8_lossy(&body));
if let Some(trailers) = recv_body.trailers().await {

if let Some(trailers) = request.body_mut().trailers().await? {
println!("received trailers: {:?}", trailers);
}

Expand Down
9 changes: 3 additions & 6 deletions quinn-h3/examples/h3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::{fs, io, net::ToSocketAddrs, path::PathBuf};
use structopt::{self, StructOpt};

use anyhow::Result;
use futures::AsyncReadExt;
use http::{Request, Uri};
use tracing::{error, info};
use tracing_subscriber::filter::LevelFilter;
Expand Down Expand Up @@ -56,18 +55,16 @@ async fn main() -> Result<()> {
let (send_data, recv_response) = conn.send_request(request);
send_data.await?;
// Wait for the response
let (response, mut recv_body) = recv_response.await?;
let mut response = recv_response.await?;

info!("received response: {:?}", response);

// Stream the response body into a vec
let mut body = Vec::with_capacity(1024);
recv_body.read_to_end(&mut body).await?;

let body = response.body_mut().read_to_end().await?;
info!("received body: {}", String::from_utf8_lossy(&body));

// Get the trailers if any
if let Some(trailers) = recv_body.trailers().await {
if let Some(trailers) = response.body_mut().trailers().await? {
info!("received trailers: {:?}", trailers);
}

Expand Down
2 changes: 1 addition & 1 deletion quinn-h3/examples/h3_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async fn main() -> Result<()> {

async fn handle_request(recv_request: RecvRequest) -> Result<()> {
// Receive the request's headers
let (request, _body_reader, sender) = recv_request.await?;
let (request, sender) = recv_request.await?;
info!("received request: {:?}", request);

let response = Response::builder()
Expand Down
105 changes: 102 additions & 3 deletions quinn-h3/src/body.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use std::{
cmp,
cmp, fmt,
future::Future,
io::{self, ErrorKind},
mem,
pin::Pin,
task::{Context, Poll},
};

use bytes::{Bytes, BytesMut};
use bytes::{Buf, Bytes, BytesMut};
use futures::{
future,
io::{AsyncRead, AsyncWrite},
ready,
stream::Stream,
FutureExt,
};
use http::HeaderMap;
use http_body::Body as HttpBody;
use quinn::SendStream;
use quinn_proto::StreamId;
use std::future::Future;

use crate::{
connection::ConnectionRef,
Expand Down Expand Up @@ -554,3 +556,100 @@ impl HttpBody for SimpleBody<Bytes> {
Poll::Ready(Ok(None))
}
}

pub struct RecvBody {
conn: ConnectionRef,
stream_id: StreamId,
recv: FrameStream,
trailers: Option<HeadersFrame>,
}

impl RecvBody {
pub(crate) fn new(conn: ConnectionRef, stream_id: StreamId, recv: FrameStream) -> Self {
Self {
conn,
stream_id,
recv,
trailers: None,
}
}

pub async fn read_to_end(&mut self) -> Result<Bytes, Error> {
let mut body = BytesMut::with_capacity(10_240);

let mut me = self;
let res: Result<(), Error> = future::poll_fn(|cx| {
while let Some(d) = ready!(Pin::new(&mut me).poll_data(cx)) {
body.extend(d?.bytes());
}
Poll::Ready(Ok(()))
})
.await;
res?;

Ok(body.freeze())
}

pub async fn trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
let mut me = self;
Ok(future::poll_fn(|cx| Pin::new(&mut me).poll_trailers(cx)).await?)
}
}

impl HttpBody for RecvBody {
type Data = bytes::Bytes;
type Error = Error;

fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
loop {
return match ready!(Pin::new(&mut self.recv).poll_next(cx)) {
None => Poll::Ready(None),
Some(Ok(HttpFrame::Reserved)) => continue,
Some(Ok(HttpFrame::Data(d))) => Poll::Ready(Some(Ok(d.payload))),
Some(Ok(HttpFrame::Headers(t))) => {
self.trailers = Some(t);
Poll::Ready(None)
}
Some(Err(e)) => {
self.recv.reset(e.code());
Poll::Ready(Some(Err(e.into())))
}
Some(Ok(f)) => {
self.recv.reset(ErrorCode::FRAME_UNEXPECTED);
Poll::Ready(Some(Err(Error::Peer(format!(
"Invalid frame type in body: {:?}",
f
)))))
}
};
}
}

fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
if self.trailers.is_none() {
return Poll::Ready(Ok(None));
}

let header = {
let mut conn = self.conn.h3.lock().unwrap();
ready!(conn.poll_decode(cx, self.stream_id, self.trailers.as_ref().unwrap()))?
};
self.trailers = None;

Poll::Ready(Ok(Some(header.into_fields())))
}
}

impl fmt::Debug for RecvBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RecvBody")
.field("stream", &self.stream_id)
.finish()
}
}
53 changes: 24 additions & 29 deletions quinn-h3/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ use quinn_proto::{Side, StreamId};
use tracing::trace;

use crate::{
body::BodyReader,
body::RecvBody,
connection::{ConnectionDriver, ConnectionRef},
frame::{FrameDecoder, FrameStream},
headers::DecodeHeaders,
Expand Down Expand Up @@ -681,10 +681,29 @@ impl RecvResponse {
.cancel_request(self.stream_id.unwrap());
recv.reset(ErrorCode::REQUEST_CANCELLED);
}

fn build_response(
&self,
header: Header,
recv: FrameStream,
) -> Result<Response<RecvBody>, Error> {
let (status, headers) = header.into_response_parts()?;
let mut response = Response::builder()
.status(status)
.version(http::version::Version::HTTP_3)
.body(RecvBody::new(
self.conn.clone(),
self.stream_id.unwrap(),
recv,
))
.unwrap();
*response.headers_mut() = headers;
Ok(response)
}
}

impl Future for RecvResponse {
type Output = Result<(Response<()>, BodyReader), crate::Error>;
type Output = Result<Response<RecvBody>, crate::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
loop {
Expand Down Expand Up @@ -737,39 +756,15 @@ impl Future for RecvResponse {
}
RecvResponseState::Decoding(ref mut decode) => {
let headers = ready!(Pin::new(decode).poll(cx))?;
let response = build_response(headers);
match response {
Err(e) => return Poll::Ready(Err(e)),
Ok(r) => {
self.state = RecvResponseState::Finished;
return Poll::Ready(Ok((
r,
BodyReader::new(
self.recv.take().unwrap(),
self.conn.clone(),
self.stream_id.unwrap(),
true,
),
)));
}
}
let recv = self.recv.take().unwrap();
let response = self.build_response(headers, recv)?;
return Poll::Ready(Ok(response));
}
}
}
}
}

fn build_response(header: Header) -> Result<Response<()>, Error> {
let (status, headers) = header.into_response_parts()?;
let mut response = Response::builder()
.status(status)
.version(http::version::Version::HTTP_3)
.body(())
.unwrap();
*response.headers_mut() = headers;
Ok(response)
}

#[cfg(test)]
impl Connection {
pub(crate) fn inner(&self) -> &ConnectionRef {
Expand Down
Loading

0 comments on commit 04482ba

Please sign in to comment.