diff --git a/.golangci.yml b/.golangci.yml index 10a1a810..07d0c9eb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -146,6 +146,11 @@ linters: - whitespace issues: + exclude-files: + - internal/iso/internal/reverseproxy.go + - internal/iso/internal/reverseproxy_test.go + - internal/iso/internal/acsii.go + - internal/iso/internal/acsii_test.go # Excluding configuration per-path, per-linter, per-text and per-source exclude-rules: - path: _test\.go diff --git a/internal/iso/internal/LICENSE b/internal/iso/internal/LICENSE new file mode 100644 index 00000000..2a7cf70d --- /dev/null +++ b/internal/iso/internal/LICENSE @@ -0,0 +1,27 @@ +Copyright 2009 The Go Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google LLC nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/internal/iso/internal/acsii.go b/internal/iso/internal/acsii.go new file mode 100644 index 00000000..f4b95d0a --- /dev/null +++ b/internal/iso/internal/acsii.go @@ -0,0 +1,38 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +// EqualFold is [strings.EqualFold], ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} diff --git a/internal/iso/internal/acsii_test.go b/internal/iso/internal/acsii_test.go new file mode 100644 index 00000000..9481e00c --- /dev/null +++ b/internal/iso/internal/acsii_test.go @@ -0,0 +1,95 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import "testing" + +func TestEqualFold(t *testing.T) { + var tests = []struct { + name string + a, b string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "simple match", + a: "CHUNKED", + b: "chunked", + want: true, + }, + { + name: "same string", + a: "chunked", + b: "chunked", + want: true, + }, + { + name: "Unicode Kelvin symbol", + a: "chunKed", // This "K" is 'KELVIN SIGN' (\u212A) + b: "chunked", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := EqualFold(tt.a, tt.b); got != tt.want { + t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestIsPrint(t *testing.T) { + var tests = []struct { + name string + in string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "ASCII low", + in: "This is a space: ' '", + want: true, + }, + { + name: "ASCII high", + in: "This is a tilde: '~'", + want: true, + }, + { + name: "ASCII low non-print", + in: "This is a unit separator: \x1F", + want: false, + }, + { + name: "Ascii high non-print", + in: "This is a Delete: \x7F", + want: false, + }, + { + name: "Unicode letter", + in: "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A) + want: false, + }, + { + name: "Unicode emoji", + in: "Gophers like 🧀", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsPrint(tt.in); got != tt.want { + t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want) + } + }) + } +} diff --git a/internal/iso/internal/context.go b/internal/iso/internal/context.go new file mode 100644 index 00000000..741cdc23 --- /dev/null +++ b/internal/iso/internal/context.go @@ -0,0 +1,19 @@ +package internal + +import "context" + +type patchCtxKeyType string + +const isoPatchCtxKey patchCtxKeyType = "iso-patch" + +func WithPatch(ctx context.Context, patch []byte) context.Context { + return context.WithValue(ctx, isoPatchCtxKey, patch) +} + +func GetPatch(ctx context.Context) []byte { + patch, ok := ctx.Value(isoPatchCtxKey).([]byte) + if !ok { + return nil + } + return patch +} diff --git a/internal/iso/internal/reverseproxy.go b/internal/iso/internal/reverseproxy.go new file mode 100644 index 00000000..6a58d94a --- /dev/null +++ b/internal/iso/internal/reverseproxy.go @@ -0,0 +1,863 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package internal + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. +type ProxyRequest struct { + // In is the request received by the proxy. + // The Rewrite function must not modify In. + In *http.Request + + // Out is the request which will be sent by the proxy. + // The Rewrite function may modify or replace this request. + // Hop-by-hop headers are removed from this request + // before Rewrite is called. + Out *http.Request +} + +// SetURL routes the outbound request to the scheme, host, and base path +// provided in target. If the target's path is "/base" and the incoming +// request was for "/dir", the target request will be for "/base/dir". +// +// SetURL rewrites the outbound Host header to match the target's host. +// To preserve the inbound request's Host header (the default behavior +// of [NewSingleHostReverseProxy]): +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.SetURL(url) +// r.Out.Host = r.In.Host +// } +func (r *ProxyRequest) SetURL(target *url.URL) { + rewriteRequestURL(r.Out, target) + r.Out.Host = "" +} + +// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and +// X-Forwarded-Proto headers of the outbound request. +// +// - The X-Forwarded-For header is set to the client IP address. +// - The X-Forwarded-Host header is set to the host name requested +// by the client. +// - The X-Forwarded-Proto header is set to "http" or "https", depending +// on whether the inbound request was made on a TLS-enabled connection. +// +// If the outbound request contains an existing X-Forwarded-For header, +// SetXForwarded appends the client IP address to it. To append to the +// inbound request's X-Forwarded-For header (the default behavior of +// [ReverseProxy] when using a Director function), copy the header +// from the inbound request before calling SetXForwarded: +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] +// r.SetXForwarded() +// } +func (r *ProxyRequest) SetXForwarded() { + clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) + if err == nil { + prior := r.Out.Header["X-Forwarded-For"] + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + r.Out.Header.Set("X-Forwarded-For", clientIP) + } else { + r.Out.Header.Del("X-Forwarded-For") + } + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + if r.In.TLS == nil { + r.Out.Header.Set("X-Forwarded-Proto", "http") + } else { + r.Out.Header.Set("X-Forwarded-Proto", "https") + } +} + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. +type ReverseProxy struct { + // Rewrite must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Rewrite must not access the provided ProxyRequest + // or its contents after returning. + // + // The Forwarded, X-Forwarded, X-Forwarded-Host, + // and X-Forwarded-Proto headers are removed from the + // outbound request before Rewrite is called. See also + // the ProxyRequest.SetXForwarded method. + // + // Unparsable query parameters are removed from the + // outbound request before Rewrite is called. + // The Rewrite function may copy the inbound URL's + // RawQuery to the outbound URL to preserve the original + // parameter string. Note that this can lead to security + // issues if the proxy's interpretation of query parameters + // does not match that of the downstream server. + // + // At most one of Rewrite or Director may be set. + Rewrite func(*ProxyRequest) + + // Director is a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + // + // By default, the X-Forwarded-For header is set to the + // value of the client IP address. If an X-Forwarded-For + // header already exists, the client IP is appended to the + // existing values. As a special case, if the header + // exists in the Request.Header map but has a nil value + // (such as when set by the Director func), the X-Forwarded-For + // header is not modified. + // + // To prevent IP spoofing, be sure to delete any pre-existing + // X-Forwarded-For header coming from the client or + // an untrusted proxy. + // + // Hop-by-hop headers are removed from the request after + // Director returns, which can remove headers added by + // Director. Use a Rewrite function instead to ensure + // modifications to the request are preserved. + // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // + // At most one of Rewrite or Director may be set. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + // A negative value means to flush immediately + // after each write to the client. + // The FlushInterval is ignored when ReverseProxy + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) + + // CopyBuffer is an optional function for handling the copying of the + // response body. If nil, an internal implementation is used. + CopyBuffer CopyBuffer +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by [io.CopyBuffer]. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +type CopyBuffer interface { + Copy(ctx context.Context, dst io.Writer, src io.Reader, buf []byte) (int64, error) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +// NewSingleHostReverseProxy returns a new [ReverseProxy] that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// +// NewSingleHostReverseProxy does not rewrite the Host header. +// +// To customize the ReverseProxy behavior beyond what +// NewSingleHostReverseProxy provides, use ReverseProxy directly +// with a Rewrite function. The ProxyRequest SetURL method +// may be used to route the outbound request. (Note that SetURL, +// unlike NewSingleHostReverseProxy, rewrites the Host header +// of the outbound request by default.) +// +// proxy := &ReverseProxy{ +// Rewrite: func(r *ProxyRequest) { +// r.SetURL(target) +// r.Out.Host = r.In.Host // if desired +// }, +// } +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + director := func(req *http.Request) { + rewriteRequestURL(req, target) + } + return &ReverseProxy{Director: director} +} + +func rewriteRequestURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if ctx.Done() != nil { + // CloseNotifier predates context.Context, and has been + // entirely superseded by it. If the request contains + // a Context that carries a cancellation signal, don't + // bother spinning up a goroutine to watch the CloseNotify + // channel (if any). + // + // If the request Context has a nil Done channel (which + // means it is either context.Background, or a custom + // Context implementation with no cancellation signal), + // then consult the CloseNotifier if available. + } else if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + if (p.Director != nil) == (p.Rewrite != nil) { + p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set")) + return + } + + if p.Director != nil { + p.Director(outreq) + if outreq.Form != nil { + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + } + } + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + if !IsPrint(reqUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + removeHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if p.Rewrite != nil { + // Strip client-provided forwarding headers. + // The Rewrite func may use SetXForwarded to set new values + // for these or copy the previous values from the inbound request. + outreq.Header.Del("Forwarded") + outreq.Header.Del("X-Forwarded-For") + outreq.Header.Del("X-Forwarded-Host") + outreq.Header.Del("X-Forwarded-Proto") + + // Remove unparsable query parameters from the outbound request. + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + + pr := &ProxyRequest{ + In: req, + Out: outreq, + } + p.Rewrite(pr) + outreq = pr.Out + } else { + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } + } + } + + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + var ( + roundTripMutex sync.Mutex + roundTripDone bool + ) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMutex.Lock() + defer roundTripMutex.Unlock() + if roundTripDone { + // If RoundTrip has returned, don't try to further modify + // the ResponseWriter's header map. + return nil + } + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + clear(h) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + roundTripMutex.Lock() + roundTripDone = true + roundTripMutex.Unlock() + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeHopByHopHeaders(res.Header) + + if !p.modifyResponse(rw, res, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + var resContext context.Context + if res.Request != nil { + resContext = res.Request.Context() + } + err = p.copyResponse(resContext, rw, res.Body, p.flushInterval(res)) + if err != nil { + defer res.Body.Close() + // Since we're streaming the response, if we run into an error all we can do + // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // on read error while copying body. + if !shouldPanicOnCopyError(req) { + p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return + } + panic(http.ErrAbortHandler) + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + http.NewResponseController(rw).Flush() + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +var inOurTests bool // whether we're in our own tests + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +func shouldPanicOnCopyError(req *http.Request) bool { + if inOurTests { + // Our tests know to handle this panic. + return true + } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +// removeHopByHopHeaders removes hop-by-hop headers. +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { + return -1 // negative means immediately + } + + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + + return p.FlushInterval +} + +func (p *ReverseProxy) copyResponse(ctx context.Context, dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + + if flushInterval != 0 { + mlw := &maxLatencyWriter{ + dst: dst, + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + w = mlw + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + defer p.BufferPool.Put(buf) + } + + var err error + if p.CopyBuffer != nil { + _, err = p.CopyBuffer.Copy(ctx, w, src, buf) + } else { + _, err = p.copyBuffer(w, src, buf) + } + + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...any) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type maxLatencyWriter struct { + dst io.Writer + flush func() error + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) + } + if !EqualFold(reqUpType, resUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +func cleanQueryParams(s string) string { + reencode := func(s string) string { + v, _ := url.ParseQuery(s) + return v.Encode() + } + for i := 0; i < len(s); { + switch s[i] { + case ';': + return reencode(s) + case '%': + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + return reencode(s) + } + i += 3 + default: + i++ + } + } + return s +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} diff --git a/internal/iso/internal/reverseproxy_test.go b/internal/iso/internal/reverseproxy_test.go new file mode 100644 index 00000000..b46c66c6 --- /dev/null +++ b/internal/iso/internal/reverseproxy_test.go @@ -0,0 +1,1931 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package internal + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/textproto" + "net/url" + "os" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + inOurTests = true + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.FormValue("mode") == "hangup" { + c, _, _ := w.(http.Hijacker).Hijack() + c.Close() + return + } + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Te"); c != "trailers" { + t.Errorf("handler got Te header value %q; want 'trailers'", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if c := r.Header.Get("Proxy-Connection"); c != "" { + t.Errorf("handler got Proxy-Connection header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("Trailers", "not a special header field name") + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + w.Header().Set("X-Trailer", "trailer_value") + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close, TE") + getReq.Header.Add("Te", "foo") + getReq.Header.Add("Te", "bar, trailers") + getReq.Header.Set("Proxy-Connection", "should be deleted") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { + t.Errorf("header Trailers = %q; want %q", g, e) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { + t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { + t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) + } + if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) + } + + // Test that a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) + getReq.Close = true + res, err = frontendClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) + } + +} + +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") + } + c := r.Header["Connection"] + var cf []string + for _, f := range c { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + cf = append(cf, sf) + } + } + } + slices.Sort(cf) + expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} + slices.Sort(expectedValues) + if !reflect.DeepEqual(cf, expectedValues) { + t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + +func TestReverseProxyStripEmptyConnection(t *testing.T) { + // See Issue 46313. + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Values("Connection"); len(c) != 0 { + t.Errorf("handler got header %q = %v; want empty", "Connection", c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "") + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "") + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +// Issue 38079: don't append to X-Forwarded-For if it's present but nil +func TestXForwardedFor_Omit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + w.Write([]byte("hi")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + r.Header["X-Forwarded-For"] = nil + oldDirector(r) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +func TestReverseProxyRewriteStripsForwarded(t *testing.T) { + headers := []string{ + "Forwarded", + "X-Forwarded-For", + "X-Forwarded-Host", + "X-Forwarded-Proto", + } + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, h := range headers { + if v := r.Header.Get(h); v != "" { + t.Errorf("got %v header: %q", h, v) + } + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + for _, h := range headers { + getReq.Header.Set(h, "x") + } + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } +} + +type mockFlusher struct { + http.ResponseWriter + flushed bool +} + +func (m *mockFlusher) Flush() { + m.flushed = true +} + +type wrappedRW struct { + http.ResponseWriter +} + +func (w *wrappedRW) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + mf := &mockFlusher{} + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = -1 // flush immediately + proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mf.ResponseWriter = w + w = &wrappedRW{mf} + proxyHandler.ServeHTTP(w, r) + }) + + frontend := httptest.NewServer(proxyWithMiddleware) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + if !mf.flushed { + t.Errorf("response writer was not flushed") + } +} + +func TestReverseProxyFlushIntervalHeaders(t *testing.T) { + const expected = "hi" + stopCh := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("MyHeader", expected) + w.WriteHeader(200) + w.(http.Flusher).Flush() + <-stopCh + })) + defer backend.Close() + defer close(stopCh) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + + ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + if res.Header.Get("MyHeader") != expected { + t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) + } +} + +func TestReverseProxyCancellation(t *testing.T) { + const backendResponse = "I am the backend" + + reqInFlight := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(reqInFlight) // cause the client to cancel its request + + select { + case <-time.After(10 * time.Second): + // Note: this should only happen in broken implementations, and the + // closenotify case should be instantaneous. + t.Error("Handler never saw CloseNotify") + return + case <-w.(http.CloseNotifier).CloseNotify(): + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(backendResponse)) + })) + + defer backend.Close() + + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + + // Discards errors of the form: + // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + go func() { + <-reqInFlight + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) + }() + res, err := frontendClient.Do(getReq) + if res != nil { + t.Errorf("got response %v; want nil", res.Status) + } + if err == nil { + // This should be an error like: + // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: + // use of closed network connection + t.Error("Server.Client().Do() returned nil error; want non-nil error") + } +} + +func req(t *testing.T, v string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) + if err != nil { + t.Fatal(err) + } + return req +} + +// Issue 12344 +func TestNilBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + defer backend.Close() + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backURL, _ := url.Parse(backend.URL) + rp := NewSingleHostReverseProxy(backURL) + r := req(t, "GET / HTTP/1.0\r\n\r\n") + r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working + rp.ServeHTTP(w, r) + })) + defer frontend.Close() + + res, err := http.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "hi" { + t.Errorf("Got %q; want %q", slurp, "hi") + } +} + +// Issue 15524 +func TestUserAgentHeader(t *testing.T) { + var gotUA string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(req *http.Request) { + req.URL = backendURL + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + for _, sentUA := range []string{"explicit UA", ""} { + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", sentUA) + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + if got, want := gotUA, sentUA; got != want { + t.Errorf("got forwarded User-Agent %q, want %q", got, want) + } + } +} + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } + +func TestReverseProxyGetPutBuffer(t *testing.T) { + const msg = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + var ( + mu sync.Mutex + log []string + ) + addLog := func(event string) { + mu.Lock() + defer mu.Unlock() + log = append(log, event) + } + rp := NewSingleHostReverseProxy(backendURL) + const size = 1234 + rp.BufferPool = bufferPool{ + get: func() []byte { + addLog("getBuf") + return make([]byte, size) + }, + put: func(p []byte) { + addLog("putBuf-" + strconv.Itoa(len(p))) + }, + } + frontend := httptest.NewServer(rp) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if string(slurp) != msg { + t.Errorf("msg = %q; want %q", slurp, msg) + } + wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(log, wantLog) { + t.Errorf("Log events = %q; want %q", log, wantLog) + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurp, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Backend body read = %v", err) + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) + res, err := frontend.Client().Do(postReq) + if err != nil { + t.Fatalf("Do: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 33142: always allocate the request headers +func TestReverseProxy_AllocatedHeader(t *testing.T) { + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header == nil { + t.Error("Header == nil; want a non-nil Header") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + + proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, + Proto: "HTTP/1.0", + ProtoMajor: 1, + }) +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. +func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + resp.Body.Close() + }) + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + donec := make(chan bool, 1) + frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { donec <- true }() + rproxy.ServeHTTP(w, r) + })) + defer frontendProxy.Close() + + if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { + t.Fatalf("want non-nil error") + } + // The race detector complains about the proxyLog usage in logf in copyBuffer + // and our usage below with proxyLog.Bytes() so we're explicitly using a + // channel to ensure that the ReverseProxy's ServeHTTP is done before we + // continue after Get. + <-donec + + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} + +type staticTransport struct { + res *http.Response +} + +func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return t.res, nil +} + +func BenchmarkServeHTTP(b *testing.B) { + res := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + } + proxy := &ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &staticTransport{res}, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + proxy.ServeHTTP(w, r) + } +} + +func TestServeHTTPDeepCopy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello Gopher!")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + type result struct { + before, after string + } + + resultChan := make(chan result, 1) + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + before := r.URL.String() + proxyHandler.ServeHTTP(w, r) + after := r.URL.String() + resultChan <- result{before: before, after: after} + })) + defer frontend.Close() + + want := result{before: "/", after: "/"} + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Do: %v", err) + } + res.Body.Close() + + got := <-resultChan + if got != want { + t.Errorf("got = %+v; want = %+v", got, want) + } +} + +// Issue 18327: verify we always do a deep copy of the Request.Header map +// before any mutations. +func TestClonesRequestHeaders(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + rp := &ReverseProxy{ + Director: func(req *http.Request) { + req.Header.Set("From-Director", "1") + }, + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if v := req.Header.Get("From-Director"); v != "1" { + t.Errorf("From-Directory value = %q; want 1", v) + } + return nil, io.EOF + }), + } + rp.ServeHTTP(httptest.NewRecorder(), req) + + for _, h := range []string{ + "From-Director", + "X-Forwarded-For", + } { + if req.Header.Get(h) != "" { + t.Errorf("%v header mutation modified caller's request", h) + } + } +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestModifyResponseClosesBody(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + closeCheck := new(checkCloser) + logBuf := new(strings.Builder) + outErr := errors.New("ModifyResponse error") + rp := &ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &staticTransport{&http.Response{ + StatusCode: 200, + Body: closeCheck, + }}, + ErrorLog: log.New(logBuf, "", 0), + ModifyResponse: func(*http.Response) error { + return outErr + }, + } + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, req) + res := rec.Result() + if g, e := res.StatusCode, http.StatusBadGateway; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if !closeCheck.closed { + t.Errorf("body should have been closed") + } + if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { + t.Errorf("ErrorLog %q does not contain %q", g, e) + } +} + +type checkCloser struct { + closed bool +} + +func (cc *checkCloser) Close() error { + cc.closed = true + return nil +} + +func (cc *checkCloser) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 23643: panic on body copy error +func TestReverseProxy_PanicBodyError(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + rproxy := NewSingleHostReverseProxy(rpURL) + + // Ensure that the handler panics when the body read encounters an + // io.ErrUnexpectedEOF + defer func() { + err := recover() + if err == nil { + t.Fatal("handler should have panicked") + } + if err != http.ErrAbortHandler { + t.Fatal("expected ErrAbortHandler, got", err) + } + }() + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + rproxy.ServeHTTP(httptest.NewRecorder(), req) +} + +/* Commented out because `neverEnding` is not available and not something I want to copy in. +// Issue #46866: panic without closing incoming request body causes a panic +func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := frontendClient.Transport.RoundTrip(req) + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } + }() + } + wg.Wait() +} +*/ + +func TestSelectFlushInterval(t *testing.T) { + tests := []struct { + name string + p *ReverseProxy + res *http.Response + want time.Duration + }{ + { + name: "default", + res: &http.Response{}, + p: &ReverseProxy{FlushInterval: 123}, + want: 123, + }, + { + name: "server-sent events overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.flushInterval(tt.res) + if got != tt.want { + t.Errorf("flushLatency = %v; want %v", got, tt.want) + } + }) + } +} + +func TestReverseProxyWebSocket(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upgradeType(r.Header) != "websocket" { + t.Error("unexpected backend request") + http.Error(w, "unexpected request", 400) + return + } + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer c.Close() + io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") + bs := bufio.NewScanner(c) + if !bs.Scan() { + t.Errorf("backend failed to read line from client: %v", bs.Err()) + return + } + fmt.Fprintf(c, "backend got %q\n", bs.Text()) + })) + defer backendServer.Close() + + backURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + c := frontendProxy.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("status = %v; want 101", res.Status) + } + + got := res.Header.Get("X-Header") + want := "X-Value" + if got != want { + t.Errorf("Header(XHeader) = %q; want %q", got, want) + } + + if !EqualFold(upgradeType(res.Header), "websocket") { + t.Fatalf("not websocket upgrade; got %#v", res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) + } + defer rwc.Close() + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + io.WriteString(rwc, "Hello\n") + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("Scan: %v", bs.Err()) + } + got = bs.Text() + want = `backend got "Hello"` + if got != want { + t.Errorf("got %#q, want %#q", got, want) + } +} + +func TestReverseProxyWebSocketCancellation(t *testing.T) { + n := 5 + triggerCancelCh := make(chan bool, n) + nthResponse := func(i int) string { + return fmt.Sprintf("backend response #%d\n", i) + } + terminalMsg := "final message" + + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) + http.Error(w, "Unexpected request", 400) + return + } + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(conn, upgradeMsg); err != nil { + t.Error(err) + return + } + if _, _, err := bufrw.ReadLine(); err != nil { + t.Errorf("Failed to read line from client: %v", err) + return + } + + for i := 0; i < n; i++ { + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Writing response #%d failed: %v", i, err) + } + return + } + bufrw.Flush() + time.Sleep(time.Second) + } + if _, err := bufrw.WriteString(terminalMsg); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Failed to write terminal message: %v", err) + } + } + bufrw.Flush() + })) + defer cst.Close() + + backendURL, _ := url.Parse(cst.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-triggerCancelCh + cancel() + }() + rproxy.ServeHTTP(rw, req.WithContext(ctx)) + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + res, err := frontendProxy.Client().Do(req) + if err != nil { + t.Fatalf("Dialing to frontend proxy: %v", err) + } + defer res.Body.Close() + if g, w := res.StatusCode, 101; g != w { + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) + } + + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + if g, w := upgradeType(res.Header), "websocket"; !EqualFold(g, w) { + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) + } + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { + t.Fatalf("Failed to write first message: %v", err) + } + + // Read loop. + + br := bufio.NewReader(rwc) + for { + line, err := br.ReadString('\n') + switch { + case line == terminalMsg: // this case before "err == io.EOF" + t.Fatalf("The websocket request was not canceled, unfortunately!") + + case err == io.EOF: + return + + case err != nil: + t.Fatalf("Unexpected error: %v", err) + + case line == nthResponse(0): // We've gotten the first response back + // Let's trigger a cancel. + close(triggerCancelCh) + } + } +} + +func TestUnannouncedTrailer(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + + io.ReadAll(res.Body) + + if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) + } + +} + +func TestSetURL(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Host)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Reading body: %v", err) + } + + if got, want := string(body), backendURL.Host; got != want { + t.Errorf("backend got Host %q, want %q", got, want) + } +} + +func TestSingleJoinSlash(t *testing.T) { + tests := []struct { + slasha string + slashb string + expected string + }{ + {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "", "https://www.google.com/"}, + {"", "favicon.ico", "/favicon.ico"}, + } + for _, tt := range tests { + if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { + t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", + tt.slasha, + tt.slashb, + tt.expected, + got) + } + } +} + +func TestJoinURLPath(t *testing.T) { + tests := []struct { + a *url.URL + b *url.URL + wantPath string + wantRaw string + }{ + {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, + {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, + } + + for _, tt := range tests { + p, rp := joinURLPath(tt.a, tt.b) + if p != tt.wantPath || rp != tt.wantRaw { + t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", + tt.a.Path, tt.a.RawPath, + tt.b.Path, tt.b.RawPath, + tt.wantPath, tt.wantRaw, + p, rp) + } + } +} + +func TestReverseProxyRewriteReplacesOut(t *testing.T) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(content)) + })) + defer backend.Close() + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.Out, _ = http.NewRequest("GET", backend.URL, nil) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if got, want := string(body), content; got != want { + t.Errorf("got response %q, want %q", got, want) + } +} + +func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) { + // https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses + // and proxy them. httptrace handlers can execute after RoundTrip returns, in particular + // after experiencing connection errors. When this happens, we shouldn't modify the + // ResponseWriter headers after ReverseProxy.ServeHTTP returns. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < 5; i++ { + w.WriteHeader(103) + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + + rw := &testResponseWriter{} + func() { + // Cancel the request (and cause RoundTrip to return) immediately upon + // seeing a 1xx response. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + cancel() + return nil + }, + }) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil) + proxyHandler.ServeHTTP(rw, req) + }() + // Trigger data race while iterating over response headers. + // When run with -race, this causes the condition in https://go.dev/issue/65123 often + // enough to detect reliably. + for _ = range rw.Header() { + } +} + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} + +const ( + testWantsCleanQuery = true + testWantsRawQuery = false +) + +func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + // Parsing the form causes ReverseProxy to remove unparsable + // query parameters before forwarding. + r.FormValue("a") + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + }, + } + }) +} + +func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + r.Out.URL.RawQuery = r.In.URL.RawQuery + }, + } + }) +} + +func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.RawQuery)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := newProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + // Don't spam output with logs of queries containing semicolons. + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + frontend.Config.ErrorLog = log.New(io.Discard, "", 0) + + for _, test := range []struct { + rawQuery string + cleanQuery string + }{{ + rawQuery: "a=1&a=2;b=3", + cleanQuery: "a=1", + }, { + rawQuery: "a=1&a=%zz&b=3", + cleanQuery: "a=1&b=3", + }} { + res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + wantQuery := test.rawQuery + if wantCleanQuery { + wantQuery = test.cleanQuery + } + if got, want := string(body), wantQuery; got != want { + t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) + } + } +} + +type testResponseWriter struct { + h http.Header + writeHeader func(int) + write func([]byte) (int, error) +} + +func (rw *testResponseWriter) Header() http.Header { + if rw.h == nil { + rw.h = make(http.Header) + } + return rw.h +} + +func (rw *testResponseWriter) WriteHeader(statusCode int) { + if rw.writeHeader != nil { + rw.writeHeader(statusCode) + } +} + +func (rw *testResponseWriter) Write(p []byte) (int, error) { + if rw.write != nil { + return rw.write(p) + } + return len(p), nil +} diff --git a/internal/iso/ipam.go b/internal/iso/ipam.go new file mode 100644 index 00000000..d5e6f70f --- /dev/null +++ b/internal/iso/ipam.go @@ -0,0 +1,79 @@ +package iso + +import ( + "fmt" + "net" + "net/netip" + "strings" + + "github.com/tinkerbell/smee/internal/dhcp/data" +) + +func parseIPAM(d *data.DHCP) string { + if d == nil { + return "" + } + // return format is ipam=:::::::: + ipam := make([]string, 9) + ipam[0] = func() string { + m := d.MACAddress.String() + + return strings.ReplaceAll(m, ":", "-") + }() + ipam[1] = func() string { + if d.VLANID != "" { + return d.VLANID + } + return "" + }() + ipam[2] = func() string { + if d.IPAddress.Compare(netip.Addr{}) != 0 { + return d.IPAddress.String() + } + return "" + }() + ipam[3] = func() string { + if d.SubnetMask != nil { + return net.IP(d.SubnetMask).String() + } + return "" + }() + ipam[4] = func() string { + if d.DefaultGateway.Compare(netip.Addr{}) != 0 { + return d.DefaultGateway.String() + } + return "" + }() + ipam[5] = d.Hostname + ipam[6] = func() string { + var nameservers []string + for _, e := range d.NameServers { + nameservers = append(nameservers, e.String()) + } + if len(nameservers) > 0 { + return strings.Join(nameservers, ",") + } + + return "" + }() + ipam[7] = func() string { + if len(d.DomainSearch) > 0 { + return strings.Join(d.DomainSearch, ",") + } + + return "" + }() + ipam[8] = func() string { + var ntp []string + for _, e := range d.NTPServers { + ntp = append(ntp, e.String()) + } + if len(ntp) > 0 { + return strings.Join(ntp, ",") + } + + return "" + }() + + return fmt.Sprintf("ipam=%s", strings.Join(ipam, ":")) +} diff --git a/internal/iso/ipam_test.go b/internal/iso/ipam_test.go new file mode 100644 index 00000000..597203ba --- /dev/null +++ b/internal/iso/ipam_test.go @@ -0,0 +1,46 @@ +package iso + +import ( + "net" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/tinkerbell/smee/internal/dhcp/data" +) + +func TestParseIPAM(t *testing.T) { + tests := map[string]struct { + input *data.DHCP + want string + }{ + "empty": {}, + "only MAC": { + input: &data.DHCP{MACAddress: net.HardwareAddr{0xde, 0xed, 0xbe, 0xef, 0xfe, 0xed}}, + want: "ipam=de-ed-be-ef-fe-ed::::::::", + }, + "everything": { + input: &data.DHCP{ + MACAddress: net.HardwareAddr{0xde, 0xed, 0xbe, 0xef, 0xfe, 0xed}, + IPAddress: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + SubnetMask: net.IPv4Mask(255, 255, 255, 0), + DefaultGateway: netip.AddrFrom4([4]byte{127, 0, 0, 2}), + NameServers: []net.IP{{1, 1, 1, 1}, {4, 4, 4, 4}}, + Hostname: "myhost", + NTPServers: []net.IP{{129, 6, 15, 28}, {129, 6, 15, 29}}, + DomainSearch: []string{"example.com", "example.org"}, + VLANID: "400", + }, + want: "ipam=de-ed-be-ef-fe-ed:400:127.0.0.1:255.255.255.0:127.0.0.2:myhost:1.1.1.1,4.4.4.4:example.com,example.org:129.6.15.28,129.6.15.29", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + got := parseIPAM(tt.input) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("diff: %v", diff) + } + }) + } +} diff --git a/internal/iso/iso.go b/internal/iso/iso.go index c122e948..b14c5d7f 100644 --- a/internal/iso/iso.go +++ b/internal/iso/iso.go @@ -10,8 +10,6 @@ import ( "math/big" "net" "net/http" - "net/http/httputil" - "net/netip" "net/url" "path" "path/filepath" @@ -21,11 +19,11 @@ import ( "github.com/go-logr/logr" "github.com/tinkerbell/smee/internal/dhcp/data" + "github.com/tinkerbell/smee/internal/iso/internal" ) const ( - defaultConsoles = "console=ttyAMA0 console=ttyS0 console=tty0 console=tty1 console=ttyS1" - maxContentLength int64 = 500 * 1024 // 500Kb + defaultConsoles = "console=ttyAMA0 console=ttyS0 console=tty0 console=tty1 console=ttyS1" ) // BackendReader is an interface that defines the method to read data from a backend. @@ -55,7 +53,8 @@ type Handler struct { StaticIPAMEnabled bool // parsedURL derives a url.URL from the SourceISO field. // It needed for validation of SourceISO and easier modification. - parsedURL *url.URL + parsedURL *url.URL + magicStrPadding []byte } // HandlerFunc returns a reverse proxy HTTP handler function that performs ISO patching. @@ -65,28 +64,68 @@ func (h *Handler) HandlerFunc() (http.HandlerFunc, error) { return nil, err } h.parsedURL = target - proxy := httputil.NewSingleHostReverseProxy(target) + + proxy := internal.NewSingleHostReverseProxy(target) proxy.Transport = h proxy.FlushInterval = -1 + proxy.CopyBuffer = h + + h.magicStrPadding = bytes.Repeat([]byte{' '}, len(h.MagicString)) return proxy.ServeHTTP, nil } +// Copy implements the internal.CopyBuffer interface. +// This implementation allows us to inspect and patch content on its way to the client without buffering the entire response +// in memory. This allows memory use to be constant regardless of the size of the response. +func (h *Handler) Copy(ctx context.Context, dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { //nolint: errorlint // going to defer to the stdlib on this one. + h.Logger.Info("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + // This is the patching check and handling. + b := buf[:nr] + i := bytes.Index(b, []byte(h.MagicString)) + if i != -1 { + dup := make([]byte, len(b)) + copy(dup, b) + copy(dup[i:], h.magicStrPadding) + copy(dup[i:], internal.GetPatch(ctx)) + b = dup + } + nw, werr := dst.Write(b) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + // RoundTrip is a method on the Handler struct that implements the http.RoundTripper interface. -// This method is called by the httputil.NewSingleHostReverseProxy to handle the incoming request. -// The method is responsible for validating the incoming request, reading the source ISO, patching the ISO. +// This method is called by the internal.NewSingleHostReverseProxy to handle the incoming request. +// The method is responsible for validating the incoming request and getting the source ISO. func (h *Handler) RoundTrip(req *http.Request) (*http.Response, error) { - log := h.Logger.WithValues("method", req.Method, "urlPath", req.URL.Path, "remoteAddr", req.RemoteAddr, "fullURL", req.URL.String()) + log := h.Logger.WithValues("method", req.Method, "urlPath", req.URL.Path, "remoteAddr", req.RemoteAddr) log.V(1).Info("starting the ISO patching HTTP handler") - if req.Method != http.MethodHead && req.Method != http.MethodGet { - return &http.Response{ - Status: fmt.Sprintf("%d %s", http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)), - StatusCode: http.StatusNotImplemented, - Body: http.NoBody, - Request: req, - }, nil - } if filepath.Ext(req.URL.Path) != ".iso" { log.Info("extension not supported, only supported extension is '.iso'") @@ -132,8 +171,21 @@ func (h *Handler) RoundTrip(req *http.Request) (*http.Response, error) { Request: req, }, nil } + // The hardware object doesn't contain a dedicated field for consoles right now and + // historically the facility is used as a way to define consoles on a per Hardware basis. + var consoles string + switch { + case fac != "" && strings.Contains(fac, "console="): + consoles = fmt.Sprintf("facility=%s", fac) + case fac != "": + consoles = fmt.Sprintf("facility=%s %s", fac, defaultConsoles) + default: + consoles = defaultConsoles + } + // The patch is added to the request context so that it can be used in the Copy method. + req = req.WithContext(internal.WithPatch(req.Context(), []byte(h.constructPatch(consoles, ha.String(), dhcpData)))) - // The httputil.NewSingleHostReverseProxy takes the incoming request url and adds the path to the target (h.SourceISO). + // The internal.NewSingleHostReverseProxy takes the incoming request url and adds the path to the target (h.SourceISO). // This function is more than a pass through proxy. The MAC address in the url path is required to do hardware lookups using the backend reader // and is not used when making http calls to the target (h.SourceISO). All valid requests are passed through to the target. req.URL.Path = h.parsedURL.Path @@ -149,104 +201,18 @@ func (h *Handler) RoundTrip(req *http.Request) (*http.Response, error) { // we do this because there are a lot of partial content requests and it allow this handler to take care of logging. resp.Header.Set("X-Global-Logging", "false") - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { - // This log line is not rate limited as we don't anticipate this to be a common occurrence or happen frequently when it does. - log.Info("the request to get the source ISO returned a status other than ok (200) or partial content (206)", "sourceIso", h.SourceISO, "status", resp.Status) - return resp, nil - } - - if req.Method == http.MethodHead { - // Fuse clients typically make a HEAD request before they start requesting content. This is not rate limited as the occurrence is expected to be low. - // This allows provides us some info on the progress of the client. - log.Info("HTTP HEAD method received", "status", resp.Status) - return resp, nil - } - - // At this point we only allow HTTP GET method with a 206 status code. - // Otherwise we are potentially reading the entire ISO file and patching it. - // This is not the intended use case for this handler. - // And this can cause memory issues, like OOM, if the ISO file is too large. - // By returning the `resp` here we allow clients to download the ISO file but without any patching. - // This is done so that there can be a minimal amount of troubleshooting for SourceISO issues. - if resp.StatusCode != http.StatusPartialContent { - log.Info("HTTP GET method received with a status code other than 206, source iso will be unpatched", "status", resp.Status, "respHeader", resp.Header, "reqHeaders", resp.Request.Header) - return resp, nil - } - // If the request is a partial content request, we need to validate the Content-Range header. - // Because we read the entire response body into memory for patching, we need to ensure that the - // Content-Range is within a reasonable size. For now, we are limiting the size to 500Kb (partialContentMax). - - // Content range RFC: https://tools.ietf.org/html/rfc7233#section-4.2 - // https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 - - // Get the content range from the response header - if resp.ContentLength > maxContentLength { - log.Info("content length is greater than max", "contentLengthBytes", resp.ContentLength, "maxAllowedBytes", maxContentLength) - return resp, nil - } - - // 0.002% of the time we log a 206 request message. - // In testing, it was observed that about 3000 HTTP 206 requests are made per ISO mount. - // 0.002% gives us about 5 - 10, log messages per ISO mount. - // We're optimizing for showing "enough" log messages so that progress can be observed. - if p := randomPercentage(100000); p < 0.002 { - log.Info("HTTP GET method received with a 206 status code") - } - - // this roundtripper func should only return error when there is no response from the server. - // for any other case we log the error and return a 500 response. See the http.RoundTripper interface code - // comments for more details. - var b []byte - respBuf := new(bytes.Buffer) - if _, err := io.CopyN(respBuf, resp.Body, resp.ContentLength); err != nil { - log.Info("unable to read response bytes", "error", err) - return &http.Response{ - Status: fmt.Sprintf("%d %s", http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)), - StatusCode: http.StatusInternalServerError, - Body: http.NoBody, - Request: req, - Header: resp.Header, - }, nil - } - b = respBuf.Bytes() - if err := resp.Body.Close(); err != nil { - log.Info("unable to close response body", "error", err) - return &http.Response{ - Status: fmt.Sprintf("%d %s", http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)), - StatusCode: http.StatusInternalServerError, - Body: http.NoBody, - Request: req, - Header: resp.Header, - }, nil - } - - // The hardware object doesn't contain a dedicated field for consoles right now and - // historically the facility is used as a way to define consoles on a per Hardware basis. - var consoles string - switch { - case fac != "" && strings.Contains(fac, "console="): - consoles = fmt.Sprintf("facility=%s", fac) - case fac != "": - consoles = fmt.Sprintf("facility=%s %s", fac, defaultConsoles) - default: - consoles = defaultConsoles - } - magicStringPadding := bytes.Repeat([]byte{' '}, len(h.MagicString)) - - // TODO: revisit later to handle the magic string potentially being spread across two chunks. - // In current implementation we will never patch the above case. Add logic to patch the case of - // magic string spread across multiple response bodies in the future. - i := bytes.Index(b, []byte(h.MagicString)) - if i != -1 { - log.Info("magic string found, patching the content", "contentRange", resp.Header.Get("Content-Range")) - dup := make([]byte, len(b)) - copy(dup, b) - copy(dup[i:], magicStringPadding) - copy(dup[i:], []byte(h.constructPatch(consoles, ha.String(), dhcpData))) - b = dup + if resp.StatusCode == http.StatusPartialContent { + // 0.002% of the time we log a 206 request message. + // In testing, it was observed that about 3000 HTTP 206 requests are made per ISO mount. + // 0.002% gives us about 5 - 10, log messages per ISO mount. + // We're optimizing for showing "enough" log messages so that progress can be observed. + if p := randomPercentage(100000); p < 0.002 { + log.Info("206 status code response", "sourceIso", h.SourceISO, "status", resp.Status) + } + } else { + log.Info("response received", "sourceIso", h.SourceISO, "status", resp.Status) } - resp.Body = io.NopCloser(bytes.NewReader(b)) log.V(1).Info("roundtrip complete") return resp, nil @@ -303,72 +269,3 @@ func randomPercentage(precision int64) float64 { return float64(random.Int64()) / float64(precision) } - -func parseIPAM(d *data.DHCP) string { - if d == nil { - return "" - } - // return format is ipam=:::::::: - ipam := make([]string, 9) - ipam[0] = func() string { - m := d.MACAddress.String() - - return strings.ReplaceAll(m, ":", "-") - }() - ipam[1] = func() string { - if d.VLANID != "" { - return d.VLANID - } - return "" - }() - ipam[2] = func() string { - if d.IPAddress.Compare(netip.Addr{}) != 0 { - return d.IPAddress.String() - } - return "" - }() - ipam[3] = func() string { - if d.SubnetMask != nil { - return net.IP(d.SubnetMask).String() - } - return "" - }() - ipam[4] = func() string { - if d.DefaultGateway.Compare(netip.Addr{}) != 0 { - return d.DefaultGateway.String() - } - return "" - }() - ipam[5] = d.Hostname - ipam[6] = func() string { - var nameservers []string - for _, e := range d.NameServers { - nameservers = append(nameservers, e.String()) - } - if len(nameservers) > 0 { - return strings.Join(nameservers, ",") - } - - return "" - }() - ipam[7] = func() string { - if len(d.DomainSearch) > 0 { - return strings.Join(d.DomainSearch, ",") - } - - return "" - }() - ipam[8] = func() string { - var ntp []string - for _, e := range d.NTPServers { - ntp = append(ntp, e.String()) - } - if len(ntp) > 0 { - return strings.Join(ntp, ",") - } - - return "" - }() - - return fmt.Sprintf("ipam=%s", strings.Join(ipam, ":")) -} diff --git a/internal/iso/iso_test.go b/internal/iso/iso_test.go index f5c7f4c3..b5816cfd 100644 --- a/internal/iso/iso_test.go +++ b/internal/iso/iso_test.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/http/httptest" - "net/netip" "net/url" "os" "testing" @@ -142,24 +141,22 @@ menuentry 'LinuxKit ISO Image' { parsedURL: parsedURL, MagicString: magicString, } + h.magicStrPadding = bytes.Repeat([]byte{' '}, len(h.MagicString)) // for debugging enable a logger // h.Logger = logr.FromSlogHandler(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{AddSource: true})) - rurl := hs.URL + "/iso/de:ed:be:ef:fe:ed/output.iso" - purl, _ := url.Parse(rurl) - req := http.Request{ - Header: http.Header{}, - Method: http.MethodGet, - URL: purl, - } - req.Header.Set("Range", "bytes=0-") - res, err := h.RoundTrip(&req) + hf, err := h.HandlerFunc() if err != nil { t.Fatal(err) } + + w := httptest.NewRecorder() + hf.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/iso/de:ed:be:ef:fe:ed/output.iso", nil)) + + res := w.Result() defer res.Body.Close() - if res.StatusCode != http.StatusPartialContent { - t.Fatalf("got status code: %d, want status code: %d", res.StatusCode, http.StatusPartialContent) + if res.StatusCode != http.StatusOK { + t.Fatalf("got status code: %d, want status code: %d", res.StatusCode, http.StatusOK) } isoContents, err := io.ReadAll(res.Body) @@ -195,39 +192,3 @@ func (m *mockBackend) GetByIP(context.Context, net.IP) (*data.DHCP, *data.Netboo } return d, n, nil } - -func TestParseIPAM(t *testing.T) { - tests := map[string]struct { - input *data.DHCP - want string - }{ - "empty": {}, - "only MAC": { - input: &data.DHCP{MACAddress: net.HardwareAddr{0xde, 0xed, 0xbe, 0xef, 0xfe, 0xed}}, - want: "ipam=de-ed-be-ef-fe-ed::::::::", - }, - "everything": { - input: &data.DHCP{ - MACAddress: net.HardwareAddr{0xde, 0xed, 0xbe, 0xef, 0xfe, 0xed}, - IPAddress: netip.AddrFrom4([4]byte{127, 0, 0, 1}), - SubnetMask: net.IPv4Mask(255, 255, 255, 0), - DefaultGateway: netip.AddrFrom4([4]byte{127, 0, 0, 2}), - NameServers: []net.IP{{1, 1, 1, 1}, {4, 4, 4, 4}}, - Hostname: "myhost", - NTPServers: []net.IP{{129, 6, 15, 28}, {129, 6, 15, 29}}, - DomainSearch: []string{"example.com", "example.org"}, - VLANID: "400", - }, - want: "ipam=de-ed-be-ef-fe-ed:400:127.0.0.1:255.255.255.0:127.0.0.2:myhost:1.1.1.1,4.4.4.4:example.com,example.org:129.6.15.28,129.6.15.29", - }, - } - - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - got := parseIPAM(tt.input) - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatalf("diff: %v", diff) - } - }) - } -}