Skip to content

Commit

Permalink
add headerbp middleware to httpbp
Browse files Browse the repository at this point in the history
  • Loading branch information
pacejackson committed Jan 9, 2025
1 parent 1bc1961 commit 954f000
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 0 deletions.
29 changes: 29 additions & 0 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/reddit/baseplate.go/breakerbp"
"github.com/reddit/baseplate.go/internal/headerbp"
//lint:ignore SA1019 This library is internal only, not actually deprecated
"github.com/reddit/baseplate.go/internalv2compat"
"github.com/reddit/baseplate.go/retrybp"
Expand Down Expand Up @@ -88,6 +89,11 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien
if config.CircuitBreaker != nil {
defaults = append([]ClientMiddleware{CircuitBreaker(*config.CircuitBreaker)}, defaults...)
}

// only add the middleware to forward baseplate headers if the client is for internal calls
if config.InternalOnly {
defaults = append(defaults, ForwardBaseplateHeaders(config.Slug))
}
middleware = append(middleware, defaults...)

return &http.Client{
Expand Down Expand Up @@ -349,3 +355,26 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware {
})
}
}

// ForwardBaseplateHeaders is a middleware that forwards baseplate headers from the context to the outgoing request.
//
// If it detects any new baseplate headers set on the request, it will reject the request and return an error.
func ForwardBaseplateHeaders(client string) ClientMiddleware {
return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
for k := range req.Header {
if err := headerbp.CheckClientHeader(k,
headerbp.WithHTTPClient("", client, ""),
); err != nil {
return nil, err
}
}
headerbp.SetOutgoingHeaders(
req.Context(),
headerbp.WithHTTPClient("", client, ""),
headerbp.WithHeaderSetter(req.Header.Set),
)
return next.RoundTrip(req)
})
}
}
1 change: 1 addition & 0 deletions httpbp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ClientConfig struct {
MaxConnections int `yaml:"maxConnections"`
CircuitBreaker *breakerbp.Config `yaml:"circuitBreaker"`
RetryOptions []retry.Option
InternalOnly bool
}

// Validate checks ClientConfig for any missing or erroneous values.
Expand Down
27 changes: 27 additions & 0 deletions httpbp/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/errorsbp"
"github.com/reddit/baseplate.go/internal/headerbp"
//lint:ignore SA1019 This library is internal only, not actually deprecated
"github.com/reddit/baseplate.go/internalv2compat"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -517,3 +518,29 @@ func (rr *responseRecorder) WriteHeader(code int) {
rr.ResponseWriter.WriteHeader(code)
rr.responseCode = code
}

// ExtractBaseplateHeaders is a middleware that extracts baseplate headers from the incoming request and adds them to the context.
func ExtractBaseplateHeaders(service string) Middleware {
return func(name string, next HandlerFunc) HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) {
if r.Header.Get(headerbp.IsUntrustedRequestHeaderCanonicalHTTP) != "" {
for k := range r.Header {
if headerbp.IsBaseplateHeader(k) {
r.Header.Del(k)
}
}
return next(ctx, w, r)
}
headers := headerbp.NewIncomingHeaders(
headerbp.WithHTTPService(service, name),
)
for k, v := range r.Header {
if len(v) > 0 {
headers.RecordHeader(k, v[0])
}
}
ctx = headers.SetOnContext(ctx)
return next(ctx, w, r.WithContext(ctx))
}
}
}
274 changes: 274 additions & 0 deletions httpbp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/reddit/baseplate.go"
"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/httpbp"
"github.com/reddit/baseplate.go/internal/headerbp"
"github.com/reddit/baseplate.go/log"
"github.com/reddit/baseplate.go/secrets"
)

func TestEndpoint(t *testing.T) {
Expand Down Expand Up @@ -427,3 +433,271 @@ func TestPanicRecovery(t *testing.T) {
t.Fatalf("unexpected service code")
}
}

func TestBaseplateHeaderPropagation(t *testing.T) {
expectedHeaders := map[string][]string{
"x-bp-from-edge": {"true"},
"x-bp-test": {"foo"},
}
store, _, err := secrets.NewTestSecrets(context.TODO(), nil)
if err != nil {
t.Fatalf("failed to create test secrets: %v", err)
}
t.Cleanup(func() {
store.Close()
})
bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{
Config: baseplate.Config{
Addr: ":8081",
},
Store: store,
EdgeContextImpl: ecinterface.Mock(),
})
downstreamServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{
Baseplate: bp,
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/say-hello": {
Name: "say-hello",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
for wantKey, wantValue := range expectedHeaders {
if v := request.Header.Values(wantKey); len(v) == 0 {
t.Fatalf("missing header %q", wantKey)
} else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool {
return a < b
})); diff != "" {
t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff)
}
}
return nil
},
},
},
Middlewares: []httpbp.Middleware{
httpbp.ExtractBaseplateHeaders("originHTTPBPV0"),
},
})
if err != nil {
t.Fatalf("failed to create test downstreamServer: %v", err)
}
t.Cleanup(func() {
downstreamServer.Close()
})
go downstreamServer.Serve()

downstreamBaseURL, err := url.Parse("http://" + downstreamServer.Baseplate().GetConfig().Addr + "/")
if err != nil {
t.Fatalf("failed to parse test originServer base URL: %v", err)
}

downstreamClient, err := httpbp.NewClient(
httpbp.ClientConfig{
Slug: "downstreamHTTPBPV0",
InternalOnly: true,
},
withBaseURL(downstreamBaseURL),
)
if err != nil {
t.Fatalf("failed to create test client: %v", err)
}

originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{
Baseplate: bp,
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/say-hello": {
Name: "say-hello",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
for wantKey, wantValue := range expectedHeaders {
if v := request.Header.Values(wantKey); len(v) == 0 {
t.Fatalf("missing header %q", wantKey)
} else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool {
return a < b
})); diff != "" {
t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff)
}
}

req, err := http.NewRequest(
http.MethodGet,
downstreamBaseURL.JoinPath("say-hello").String(),
nil,
)
if err != nil {
t.Fatalf("creating request: %v", err)
}

resp, err := downstreamClient.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}

invalidReq, err := http.NewRequest(
http.MethodGet,
downstreamBaseURL.JoinPath("say-hello").String(),
nil,
)
if err != nil {
t.Fatalf("creating request: %v", err)
}
invalidReq.Header.Set("x-bp-test", "bar")

if _, err := downstreamClient.Do(req); !errors.Is(err, headerbp.ErrNewInternalHeaderNotAllowed) {
t.Fatalf("error mismatch, want %v, got %v", headerbp.ErrNewInternalHeaderNotAllowed, err)
}
return nil
},
},
},
Middlewares: []httpbp.Middleware{
httpbp.ExtractBaseplateHeaders("originHTTPBPV0"),
},
})
if err != nil {
t.Fatalf("failed to create test originServer: %v", err)
}
t.Cleanup(func() {
originServer.Close()
})
go originServer.Serve()

baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/")
if err != nil {
t.Fatalf("failed to parse test originServer base URL: %v", err)
}

client, err := httpbp.NewClient(
httpbp.ClientConfig{
Slug: "downstreamHTTPBPV0",
},
withBaseURL(baseURL),
)
if err != nil {
t.Fatalf("failed to create test client: %v", err)
}

req, err := http.NewRequest(
http.MethodGet,
baseURL.JoinPath("say-hello").String(),
nil,
)
if err != nil {
t.Fatalf("creating request: %v", err)
}
for name, values := range expectedHeaders {
req.Header.Set(name, values[0])
}

resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
}

func TestBaseplateHeaderPropagation_untrusted(t *testing.T) {
expectedHeaders := map[string][]string{
"x-bp-from-edge": {"true"},
"x-bp-test": {"foo"},
}
store, _, err := secrets.NewTestSecrets(context.TODO(), nil)
if err != nil {
t.Fatalf("failed to create test secrets: %v", err)
}
t.Cleanup(func() {
store.Close()
})
bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{
Config: baseplate.Config{
Addr: ":8081",
},
Store: store,
EdgeContextImpl: ecinterface.Mock(),
})

originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{
Baseplate: bp,
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/say-hello": {
Name: "say-hello",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
for wantKey := range expectedHeaders {
if v := request.Header.Values(wantKey); len(v) != 0 {
t.Fatalf("expected no values for header %q, got %+v", wantKey, v)
}
}

return nil
},
},
},
Middlewares: []httpbp.Middleware{
httpbp.ExtractBaseplateHeaders("originHTTPBPV0"),
},
})
if err != nil {
t.Fatalf("failed to create test originServer: %v", err)
}
t.Cleanup(func() {
originServer.Close()
})
go originServer.Serve()

baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/")
if err != nil {
t.Fatalf("failed to parse test originServer base URL: %v", err)
}

client, err := httpbp.NewClient(
httpbp.ClientConfig{
Slug: "downstreamHTTPBPV0",
},
withBaseURL(baseURL),
)
if err != nil {
t.Fatalf("failed to create test client: %v", err)
}

req, err := http.NewRequest(
http.MethodGet,
baseURL.JoinPath("say-hello").String(),
nil,
)
if err != nil {
t.Fatalf("creating request: %v", err)
}
req.Header.Set("X-Rddt-Untrusted", "1")
for name, values := range expectedHeaders {
req.Header.Set(name, values[0])
}

resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
}

func withBaseURL(baseURL *url.URL) httpbp.ClientMiddleware {
return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
resolved := req.Clone(req.Context())
resolved.URL = baseURL.ResolveReference(req.URL)
return next.RoundTrip(resolved)
})
}
}

type roundTripperFunc func(req *http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

0 comments on commit 954f000

Please sign in to comment.