From 145b2d7b6deef8ae469696157bcb974d045cfc05 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 24 Jan 2025 10:09:23 -0800 Subject: [PATCH] internal/http3: add RoundTrip Send request headers, receive response headers. For golang/go#70914 Change-Id: I78d4dcc69c253ed7ad1543dfc3c5d8f1c321ced9 Reviewed-on: https://go-review.googlesource.com/c/net/+/644118 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam Auto-Submit: Damien Neil --- internal/http3/roundtrip.go | 226 ++++++++++++++++++++++++++++++ internal/http3/roundtrip_test.go | 232 +++++++++++++++++++++++++++++++ internal/http3/transport.go | 11 +- internal/http3/transport_test.go | 167 ++++++++++++++++++++-- 4 files changed, 621 insertions(+), 15 deletions(-) create mode 100644 internal/http3/roundtrip.go create mode 100644 internal/http3/roundtrip_test.go diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go new file mode 100644 index 000000000..9042c15bf --- /dev/null +++ b/internal/http3/roundtrip.go @@ -0,0 +1,226 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package http3 + +import ( + "io" + "net/http" + "strconv" + + "golang.org/x/net/internal/httpcommon" +) + +// RoundTrip sends a request on the connection. +func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) { + // Each request gets its own QUIC stream. + st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest) + if err != nil { + return nil, err + } + defer func() { + switch e := err.(type) { + case nil: + case *connectionError: + cc.abort(e) + case *streamError: + st.stream.CloseRead() + st.stream.Reset(uint64(e.code)) + default: + st.stream.CloseRead() + st.stream.Reset(uint64(errH3NoError)) + } + }() + + // Cancel reads/writes on the stream when the request expires. + st.stream.SetReadContext(req.Context()) + st.stream.SetWriteContext(req.Context()) + + var encr httpcommon.EncodeHeadersResult + headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) { + encr, err = httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ + Request: req, + AddGzipHeader: false, // TODO: add when appropriate + PeerMaxHeaderListSize: 0, + DefaultUserAgent: "Go-http-client/3", + }, func(name, value string) { + // Issue #71374: Consider supporting never-indexed fields. + yield(mayIndex, name, value) + }) + }) + if err != nil { + return nil, err + } + + // Write the HEADERS frame. + st.writeVarint(int64(frameTypeHeaders)) + st.writeVarint(int64(len(headers))) + st.Write(headers) + if err := st.Flush(); err != nil { + return nil, err + } + + if encr.HasBody { + // TODO: Send the request body. + } + + // Read the response headers. + for { + ftype, err := st.readFrameHeader() + if err != nil { + return nil, err + } + switch ftype { + case frameTypeHeaders: + statusCode, h, err := cc.handleHeaders(st) + if err != nil { + return nil, err + } + + if statusCode >= 100 && statusCode < 199 { + // TODO: Handle 1xx responses. + continue + } + + // We have the response headers. + // Set up the response and return it to the caller. + contentLength, err := parseResponseContentLength(req.Method, statusCode, h) + if err != nil { + return nil, err + } + resp := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: h, + StatusCode: statusCode, + Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode), + ContentLength: contentLength, + Body: io.NopCloser(nil), // TODO: read the response body + } + // TODO: Automatic Content-Type: gzip decoding. + return resp, nil + case frameTypePushPromise: + if err := cc.handlePushPromise(st); err != nil { + return nil, err + } + default: + if err := st.discardUnknownFrame(ftype); err != nil { + return nil, err + } + } + } +} + +func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) { + clens := h["Content-Length"] + if len(clens) == 0 { + return -1, nil + } + + // We allow duplicate Content-Length headers, + // but only if they all have the same value. + for _, v := range clens[1:] { + if clens[0] != v { + return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"} + } + } + + // "A server MUST NOT send a Content-Length header field in any response + // with a status code of 1xx (Informational) or 204 (No Content). + // A server MUST NOT send a Content-Length header field in any 2xx (Successful) + // response to a CONNECT request [...]" + // https://www.rfc-editor.org/rfc/rfc9110#section-8.6-8 + if (statusCode >= 100 && statusCode < 200) || + statusCode == 204 || + (method == "CONNECT" && statusCode >= 200 && statusCode < 300) { + // This is a protocol violation, but a fairly harmless one. + // Just ignore the header. + return -1, nil + } + + contentLen, err := strconv.ParseUint(clens[0], 10, 63) + if err != nil { + return -1, &streamError{errH3MessageError, "invalid Content-Length header"} + } + return int64(contentLen), nil +} + +func (cc *ClientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) { + haveStatus := false + cookie := "" + // Issue #71374: Consider tracking the never-indexed status of headers + // with the N bit set in their QPACK encoding. + err = cc.dec.decode(st, func(_ indexType, name, value string) error { + switch { + case name == ":status": + if haveStatus { + return &streamError{errH3MessageError, "duplicate :status"} + } + haveStatus = true + statusCode, err = strconv.Atoi(value) + if err != nil { + return &streamError{errH3MessageError, "invalid :status"} + } + case name[0] == ':': + // "Endpoints MUST treat a request or response + // that contains undefined or invalid + // pseudo-header fields as malformed." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3-3 + return &streamError{errH3MessageError, "undefined pseudo-header"} + case name == "cookie": + // "If a decompressed field section contains multiple cookie field lines, + // these MUST be concatenated into a single byte string [...]" + // using the two-byte delimiter of "; "'' + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2 + if cookie == "" { + cookie = value + } else { + cookie += "; " + value + } + default: + if h == nil { + h = make(http.Header) + } + // TODO: Use a per-connection canonicalization cache as we do in HTTP/2. + // Maybe we could put this in the QPACK decoder and have it deliver + // pre-canonicalized headers to us here? + cname := httpcommon.CanonicalHeader(name) + // TODO: Consider using a single []string slice for all headers, + // as we do in the HTTP/1 and HTTP/2 cases. + // This is a bit tricky, since we don't know the number of headers + // at the start of decoding. Perhaps it's worth doing a two-pass decode, + // or perhaps we should just allocate header value slices in + // reasonably-sized chunks. + h[cname] = append(h[cname], value) + } + return nil + }) + if !haveStatus { + // "[The :status] pseudo-header field MUST be included in all responses [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3.2-1 + err = errH3MessageError + } + if cookie != "" { + if h == nil { + h = make(http.Header) + } + h["Cookie"] = []string{cookie} + } + if err := st.endFrame(); err != nil { + return 0, nil, err + } + return statusCode, h, err +} + +func (cc *ClientConn) handlePushPromise(st *stream) error { + // "A client MUST treat receipt of a PUSH_PROMISE frame that contains a + // larger push ID than the client has advertised as a connection error of H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5 + return &connectionError{ + code: errH3IDError, + message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent", + } +} diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go new file mode 100644 index 000000000..34397c07f --- /dev/null +++ b/internal/http3/roundtrip_test.go @@ -0,0 +1,232 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 && goexperiment.synctest + +package http3 + +import ( + "net/http" + "testing" +) + +func TestRoundTripSimple(t *testing.T) { + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + req.Header["User-Agent"] = nil + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(http.Header{ + ":authority": []string{"example.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + ":scheme": []string{"https"}, + }) + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + "x-some-header": []string{"value"}, + }) + rt.wantStatus(200) + rt.wantHeaders(http.Header{ + "X-Some-Header": []string{"value"}, + }) + }) +} + +func TestRoundTripWithBadHeaders(t *testing.T) { + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + req.Header["Invalid\nHeader"] = []string{"x"} + rt := tc.roundTrip(req) + rt.wantError("RoundTrip fails when request contains invalid headers") + }) +} + +func TestRoundTripWithUnknownFrame(t *testing.T) { + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + // Write an unknown frame type before the response HEADERS. + data := "frame content" + st.writeVarint(0x1f + 0x21) // reserved frame type + st.writeVarint(int64(len(data))) // size + st.Write([]byte(data)) + + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + }) + rt.wantStatus(200) + }) +} + +func TestRoundTripWithInvalidPushPromise(t *testing.T) { + // "A client MUST treat receipt of a PUSH_PROMISE frame that contains + // a larger push ID than the client has advertised as a connection error of H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5 + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + // Write a PUSH_PROMISE frame. + // Since the client hasn't indicated willingness to accept pushes, + // this is a connection error. + st.writePushPromise(0, http.Header{ + ":path": []string{"/foo"}, + }) + rt.wantError("RoundTrip fails after receiving invalid PUSH_PROMISE") + tc.wantClosed( + "push ID exceeds client's MAX_PUSH_ID", + errH3IDError, + ) + }) +} + +func TestRoundTripResponseContentLength(t *testing.T) { + for _, test := range []struct { + name string + respHeader http.Header + wantContentLength int64 + wantError bool + }{{ + name: "valid", + respHeader: http.Header{ + ":status": []string{"200"}, + "content-length": []string{"100"}, + }, + wantContentLength: 100, + }, { + name: "absent", + respHeader: http.Header{ + ":status": []string{"200"}, + }, + wantContentLength: -1, + }, { + name: "unparseable", + respHeader: http.Header{ + ":status": []string{"200"}, + "content-length": []string{"1 1"}, + }, + wantError: true, + }, { + name: "duplicated", + respHeader: http.Header{ + ":status": []string{"200"}, + "content-length": []string{"500", "500", "500"}, + }, + wantContentLength: 500, + }, { + name: "inconsistent", + respHeader: http.Header{ + ":status": []string{"200"}, + "content-length": []string{"1", "2"}, + }, + wantError: true, + }, { + // 204 responses aren't allowed to contain a Content-Length header. + // We just ignore it. + name: "204", + respHeader: http.Header{ + ":status": []string{"204"}, + "content-length": []string{"100"}, + }, + wantContentLength: -1, + }} { + runSynctestSubtest(t, test.name, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + st.writeHeaders(test.respHeader) + if test.wantError { + rt.wantError("invalid content-length in response") + return + } + if got, want := rt.response().ContentLength, test.wantContentLength; got != want { + t.Errorf("Response.ContentLength = %v, want %v", got, want) + } + }) + } +} + +func TestRoundTripMalformedResponses(t *testing.T) { + for _, test := range []struct { + name string + respHeader http.Header + }{{ + name: "duplicate :status", + respHeader: http.Header{ + ":status": []string{"200", "204"}, + }, + }, { + name: "unparseable :status", + respHeader: http.Header{ + ":status": []string{"frogpants"}, + }, + }, { + name: "undefined pseudo-header", + respHeader: http.Header{ + ":status": []string{"200"}, + ":unknown": []string{"x"}, + }, + }, { + name: "no :status", + respHeader: http.Header{}, + }} { + runSynctestSubtest(t, test.name, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + st.writeHeaders(test.respHeader) + rt.wantError("malformed response") + }) + } +} + +func TestRoundTripCrumbledCookiesInResponse(t *testing.T) { + // "If a decompressed field section contains multiple cookie field lines, + // these MUST be concatenated into a single byte string [...]" + // using the two-byte delimiter of "; "'' + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2 + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + "cookie": []string{"a=1", "b=2; c=3", "d=4"}, + }) + rt.wantStatus(200) + rt.wantHeaders(http.Header{ + "Cookie": []string{"a=1; b=2; c=3; d=4"}, + }) + }) +} diff --git a/internal/http3/transport.go b/internal/http3/transport.go index 7c465117f..2acf40f08 100644 --- a/internal/http3/transport.go +++ b/internal/http3/transport.go @@ -9,10 +9,8 @@ package http3 import ( "context" "crypto/tls" - "errors" "fmt" "io" - "net/http" "sync" "golang.org/x/net/quic" @@ -99,12 +97,16 @@ type ClientConn struct { // streamsCreated is a bitset of streams created so far. // Bits are 1 << streamType. streamsCreated uint8 + + enc qpackEncoder + dec qpackDecoder } func newClientConn(ctx context.Context, qconn *quic.Conn) (*ClientConn, error) { cc := &ClientConn{ qconn: qconn, } + cc.enc.init() // Create control stream and send SETTINGS frame. controlStream, err := newConnStream(ctx, cc.qconn, streamTypeControl) @@ -131,11 +133,6 @@ func (cc *ClientConn) Close() error { return cc.qconn.Wait(ctx) } -// RoundTrip sends a request on the connection. -func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { - return nil, errors.New("TODO") -} - func (cc *ClientConn) acceptStreams() { for { // Use context.Background: This blocks until a stream is accepted diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index b915d79c8..a61c9a661 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go @@ -9,18 +9,22 @@ package http3 import ( "context" "errors" + "fmt" + "maps" + "net/http" + "reflect" + "slices" "testing" "testing/synctest" + "golang.org/x/net/internal/quic/quicwire" "golang.org/x/net/quic" ) func TestTransportCreatesControlStream(t *testing.T) { runSynctest(t, func(t testing.TB) { tc := newTestClientConn(t) - controlStream := tc.wantStream( - "client creates control stream", - streamTypeControl) + controlStream := tc.wantStream(streamTypeControl) controlStream.wantFrameHeader( "client sends SETTINGS frame on control stream", frameTypeSettings) @@ -228,10 +232,11 @@ func (tq *testQUICConn) wantClosed(reason string, want error) { // wantStream asserts that a stream of a given type has been created, // and returns that stream. -func (tq *testQUICConn) wantStream(reason string, stype streamType) *testQUICStream { +func (tq *testQUICConn) wantStream(stype streamType) *testQUICStream { tq.t.Helper() + synctest.Wait() if len(tq.streams[stype]) == 0 { - tq.t.Fatalf("%v: stream not created", reason) + tq.t.Fatalf("expected a %v stream to be created, but none were", stype) } ts := tq.streams[stype][0] tq.streams[stype] = tq.streams[stype][1:] @@ -256,6 +261,7 @@ func newTestQUICStream(t testing.TB, st *stream) *testQUICStream { // wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type. func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) { ts.t.Helper() + synctest.Wait() gotType, err := ts.readFrameHeader() if err != nil { ts.t.Fatalf("%v: failed to read frame header: %v", reason, err) @@ -265,6 +271,83 @@ func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) { } } +// wantHeaders reads a HEADERS frame. +// If want is nil, the contents of the frame are ignored. +func (ts *testQUICStream) wantHeaders(want http.Header) { + ts.t.Helper() + ftype, err := ts.readFrameHeader() + if err != nil { + ts.t.Fatalf("want HEADERS frame, got error: %v", err) + } + if ftype != frameTypeHeaders { + ts.t.Fatalf("want HEADERS frame, got: %v", ftype) + } + + if want == nil { + return + } + + got := make(http.Header) + var dec qpackDecoder + err = dec.decode(ts.stream, func(_ indexType, name, value string) error { + got.Add(name, value) + return nil + }) + if diff := diffHeaders(got, want); diff != "" { + ts.t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +func (ts *testQUICStream) encodeHeaders(h http.Header) []byte { + ts.t.Helper() + var enc qpackEncoder + return enc.encode(func(yield func(itype indexType, name, value string)) { + names := slices.Collect(maps.Keys(h)) + slices.Sort(names) + for _, k := range names { + for _, v := range h[k] { + yield(mayIndex, k, v) + } + } + }) +} + +func (ts *testQUICStream) writeHeaders(h http.Header) { + ts.t.Helper() + headers := ts.encodeHeaders(h) + ts.writeVarint(int64(frameTypeHeaders)) + ts.writeVarint(int64(len(headers))) + ts.Write(headers) + if err := ts.Flush(); err != nil { + ts.t.Fatalf("flushing HEADERS frame: %v", err) + } +} + +func (ts *testQUICStream) writePushPromise(pushID int64, h http.Header) { + ts.t.Helper() + headers := ts.encodeHeaders(h) + ts.writeVarint(int64(frameTypePushPromise)) + ts.writeVarint(int64(quicwire.SizeVarint(uint64(pushID)) + len(headers))) + ts.writeVarint(pushID) + ts.Write(headers) + if err := ts.Flush(); err != nil { + ts.t.Fatalf("flushing PUSH_PROMISE frame: %v", err) + } +} + +func diffHeaders(got, want http.Header) string { + // nil and 0-length non-nil are equal. + if len(got) == 0 && len(want) == 0 { + return "" + } + // We could do a more sophisticated diff here. + // DeepEqual is good enough for now. + if reflect.DeepEqual(got, want) { + return "" + } + return fmt.Sprintf("got: %v\nwant: %v", got, want) +} + func (ts *testQUICStream) Flush() error { err := ts.stream.Flush() if err != nil { @@ -316,9 +399,7 @@ func newTestClientConn(t testing.TB) *testClientConn { // greet performs initial connection handshaking with the client. func (tc *testClientConn) greet() { // Client creates a control stream. - clientControlStream := tc.wantStream( - "client creates control stream", - streamTypeControl) + clientControlStream := tc.wantStream(streamTypeControl) clientControlStream.wantFrameHeader( "client sends SETTINGS frame on control stream", frameTypeSettings) @@ -333,6 +414,76 @@ func (tc *testClientConn) greet() { synctest.Wait() } +type testRoundTrip struct { + t testing.TB + resp *http.Response + respErr error +} + +func (rt *testRoundTrip) done() bool { + synctest.Wait() + return rt.resp != nil || rt.respErr != nil +} + +func (rt *testRoundTrip) result() (*http.Response, error) { + rt.t.Helper() + if !rt.done() { + rt.t.Fatal("RoundTrip is not done; want it to be") + } + return rt.resp, rt.respErr +} + +func (rt *testRoundTrip) response() *http.Response { + rt.t.Helper() + if !rt.done() { + rt.t.Fatal("RoundTrip is not done; want it to be") + } + if rt.respErr != nil { + rt.t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) + } + return rt.resp +} + +// err returns the (possibly nil) error result of RoundTrip. +func (rt *testRoundTrip) err() error { + rt.t.Helper() + _, err := rt.result() + return err +} + +func (rt *testRoundTrip) wantError(reason string) { + rt.t.Helper() + if !rt.done() { + rt.t.Fatalf("%v: RoundTrip is not done; want it to have returned an error", reason) + } + if rt.respErr == nil { + rt.t.Fatalf("%v: RoundTrip succeeded; want it to have returned an error", reason) + } +} + +// wantStatus indicates the expected response StatusCode. +func (rt *testRoundTrip) wantStatus(want int) { + rt.t.Helper() + if got := rt.response().StatusCode; got != want { + rt.t.Fatalf("got response status %v, want %v", got, want) + } +} + +func (rt *testRoundTrip) wantHeaders(want http.Header) { + rt.t.Helper() + if diff := diffHeaders(rt.response().Header, want); diff != "" { + rt.t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{t: tc.t} + go func() { + rt.resp, rt.respErr = tc.cc.RoundTrip(req) + }() + return rt +} + func (tc *testClientConn) newStream(stype streamType) *testQUICStream { tc.t.Helper() var qs *quic.Stream