-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add headerbp integration #669
Changes from 5 commits
1bc1961
c52cdf1
496bc9b
75e5b96
7bdf051
6342673
a752b34
94c324c
49741ba
4200301
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,6 +15,7 @@ import ( | |||||||||||||||||||||||||||
"github.com/prometheus/client_golang/prometheus" | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
"github.com/reddit/baseplate.go/breakerbp" | ||||||||||||||||||||||||||||
"github.com/reddit/baseplate.go/internal/headerbp" | ||||||||||||||||||||||||||||
//lint:ignore SA1019 This library is internal only, not actually deprecated | ||||||||||||||||||||||||||||
"github.com/reddit/baseplate.go/internalv2compat" | ||||||||||||||||||||||||||||
"github.com/reddit/baseplate.go/retrybp" | ||||||||||||||||||||||||||||
|
@@ -88,6 +89,11 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien | |||||||||||||||||||||||||||
if config.CircuitBreaker != nil { | ||||||||||||||||||||||||||||
defaults = append([]ClientMiddleware{CircuitBreaker(*config.CircuitBreaker)}, defaults...) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
// only add the middleware to forward baseplate headers if the client is for internal calls | ||||||||||||||||||||||||||||
if config.InternalOnly { | ||||||||||||||||||||||||||||
defaults = append(defaults, ClientHeaderBPMiddleware(config.Slug)) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
middleware = append(middleware, defaults...) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
return &http.Client{ | ||||||||||||||||||||||||||||
|
@@ -349,3 +355,26 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { | |||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
// ClientHeaderBPMiddleware is a middleware that forwards baseplate headers from the context to the outgoing request. | ||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||
// If it detects any new baseplate headers set on the request, it will reject the request and return an error. | ||||||||||||||||||||||||||||
func ClientHeaderBPMiddleware(client string) ClientMiddleware { | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
🔕 maybe best spelled out since we have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||
return func(next http.RoundTripper) http.RoundTripper { | ||||||||||||||||||||||||||||
return roundTripperFunc(func(req *http.Request) (*http.Response, error) { | ||||||||||||||||||||||||||||
for k := range req.Header { | ||||||||||||||||||||||||||||
if err := headerbp.CheckClientHeader(k, | ||||||||||||||||||||||||||||
headerbp.WithHTTPClient("", client, ""), | ||||||||||||||||||||||||||||
); err != nil { | ||||||||||||||||||||||||||||
return nil, err | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
Comment on lines
+365
to
+370
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
🔕 I find this slightly more readable 🙂 |
||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
headerbp.SetOutgoingHeaders( | ||||||||||||||||||||||||||||
req.Context(), | ||||||||||||||||||||||||||||
headerbp.WithHTTPClient("", client, ""), | ||||||||||||||||||||||||||||
headerbp.WithHeaderSetter(req.Header.Set), | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
return next.RoundTrip(req) | ||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,13 +7,19 @@ import ( | |
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/go-cmp/cmp/cmpopts" | ||
|
||
"github.com/reddit/baseplate.go" | ||
"github.com/reddit/baseplate.go/ecinterface" | ||
"github.com/reddit/baseplate.go/httpbp" | ||
"github.com/reddit/baseplate.go/internal/headerbp" | ||
"github.com/reddit/baseplate.go/log" | ||
"github.com/reddit/baseplate.go/secrets" | ||
) | ||
|
||
func TestEndpoint(t *testing.T) { | ||
|
@@ -427,3 +433,271 @@ func TestPanicRecovery(t *testing.T) { | |
t.Fatalf("unexpected service code") | ||
} | ||
} | ||
|
||
func TestBaseplateHeaderPropagation(t *testing.T) { | ||
expectedHeaders := map[string][]string{ | ||
"x-bp-from-edge": {"true"}, | ||
"x-bp-test": {"foo"}, | ||
} | ||
store, _, err := secrets.NewTestSecrets(context.TODO(), nil) | ||
if err != nil { | ||
t.Fatalf("failed to create test secrets: %v", err) | ||
} | ||
t.Cleanup(func() { | ||
store.Close() | ||
}) | ||
bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{ | ||
Config: baseplate.Config{ | ||
Addr: ":8081", | ||
}, | ||
Store: store, | ||
EdgeContextImpl: ecinterface.Mock(), | ||
}) | ||
downstreamServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ | ||
Baseplate: bp, | ||
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ | ||
"/say-hello": { | ||
Name: "say-hello", | ||
Methods: []string{http.MethodGet}, | ||
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { | ||
for wantKey, wantValue := range expectedHeaders { | ||
if v := request.Header.Values(wantKey); len(v) == 0 { | ||
t.Fatalf("missing header %q", wantKey) | ||
} else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool { | ||
return a < b | ||
})); diff != "" { | ||
t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff) | ||
} | ||
} | ||
return nil | ||
}, | ||
}, | ||
}, | ||
Middlewares: []httpbp.Middleware{ | ||
httpbp.ServerHeaderBPMiddleware("originHTTPBPV0"), | ||
}, | ||
}) | ||
if err != nil { | ||
t.Fatalf("failed to create test downstreamServer: %v", err) | ||
} | ||
t.Cleanup(func() { | ||
downstreamServer.Close() | ||
}) | ||
go downstreamServer.Serve() | ||
|
||
downstreamBaseURL, err := url.Parse("http://" + downstreamServer.Baseplate().GetConfig().Addr + "/") | ||
if err != nil { | ||
t.Fatalf("failed to parse test originServer base URL: %v", err) | ||
} | ||
|
||
downstreamClient, err := httpbp.NewClient( | ||
httpbp.ClientConfig{ | ||
Slug: "downstreamHTTPBPV0", | ||
InternalOnly: true, | ||
}, | ||
withBaseURL(downstreamBaseURL), | ||
) | ||
if err != nil { | ||
t.Fatalf("failed to create test client: %v", err) | ||
} | ||
|
||
originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ | ||
Baseplate: bp, | ||
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ | ||
"/say-hello": { | ||
Name: "say-hello", | ||
Methods: []string{http.MethodGet}, | ||
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { | ||
for wantKey, wantValue := range expectedHeaders { | ||
if v := request.Header.Values(wantKey); len(v) == 0 { | ||
t.Fatalf("missing header %q", wantKey) | ||
} else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool { | ||
return a < b | ||
})); diff != "" { | ||
t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff) | ||
} | ||
} | ||
|
||
req, err := http.NewRequest( | ||
http.MethodGet, | ||
downstreamBaseURL.JoinPath("say-hello").String(), | ||
nil, | ||
) | ||
if err != nil { | ||
t.Fatalf("creating request: %v", err) | ||
} | ||
|
||
resp, err := downstreamClient.Do(req) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
if resp.StatusCode != http.StatusOK { | ||
t.Fatalf("unexpected status code: %d", resp.StatusCode) | ||
} | ||
|
||
invalidReq, err := http.NewRequest( | ||
http.MethodGet, | ||
downstreamBaseURL.JoinPath("say-hello").String(), | ||
nil, | ||
) | ||
if err != nil { | ||
t.Fatalf("creating request: %v", err) | ||
} | ||
invalidReq.Header.Set("x-bp-test", "bar") | ||
|
||
if _, err := downstreamClient.Do(req); !errors.Is(err, headerbp.ErrNewInternalHeaderNotAllowed) { | ||
t.Fatalf("error mismatch, want %v, got %v", headerbp.ErrNewInternalHeaderNotAllowed, err) | ||
} | ||
return nil | ||
}, | ||
}, | ||
}, | ||
Middlewares: []httpbp.Middleware{ | ||
httpbp.ServerHeaderBPMiddleware("originHTTPBPV0"), | ||
}, | ||
}) | ||
if err != nil { | ||
t.Fatalf("failed to create test originServer: %v", err) | ||
} | ||
t.Cleanup(func() { | ||
originServer.Close() | ||
}) | ||
go originServer.Serve() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can probably add a small sleep after this |
||
|
||
baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/") | ||
if err != nil { | ||
t.Fatalf("failed to parse test originServer base URL: %v", err) | ||
} | ||
|
||
client, err := httpbp.NewClient( | ||
httpbp.ClientConfig{ | ||
Slug: "downstreamHTTPBPV0", | ||
}, | ||
withBaseURL(baseURL), | ||
) | ||
if err != nil { | ||
t.Fatalf("failed to create test client: %v", err) | ||
} | ||
|
||
req, err := http.NewRequest( | ||
http.MethodGet, | ||
baseURL.JoinPath("say-hello").String(), | ||
nil, | ||
) | ||
if err != nil { | ||
t.Fatalf("creating request: %v", err) | ||
} | ||
for name, values := range expectedHeaders { | ||
req.Header.Set(name, values[0]) | ||
} | ||
|
||
resp, err := client.Do(req) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
if resp.StatusCode != http.StatusOK { | ||
t.Fatalf("unexpected status code: %d", resp.StatusCode) | ||
} | ||
} | ||
|
||
func TestBaseplateHeaderPropagation_untrusted(t *testing.T) { | ||
expectedHeaders := map[string][]string{ | ||
"x-bp-from-edge": {"true"}, | ||
"x-bp-test": {"foo"}, | ||
} | ||
store, _, err := secrets.NewTestSecrets(context.TODO(), nil) | ||
if err != nil { | ||
t.Fatalf("failed to create test secrets: %v", err) | ||
} | ||
t.Cleanup(func() { | ||
store.Close() | ||
}) | ||
bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{ | ||
Config: baseplate.Config{ | ||
Addr: ":8081", | ||
}, | ||
Store: store, | ||
EdgeContextImpl: ecinterface.Mock(), | ||
}) | ||
|
||
originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ | ||
Baseplate: bp, | ||
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ | ||
"/say-hello": { | ||
Name: "say-hello", | ||
Methods: []string{http.MethodGet}, | ||
Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { | ||
for wantKey := range expectedHeaders { | ||
if v := request.Header.Values(wantKey); len(v) != 0 { | ||
t.Fatalf("expected no values for header %q, got %+v", wantKey, v) | ||
} | ||
} | ||
|
||
return nil | ||
}, | ||
}, | ||
}, | ||
Middlewares: []httpbp.Middleware{ | ||
httpbp.ServerHeaderBPMiddleware("originHTTPBPV0"), | ||
}, | ||
}) | ||
if err != nil { | ||
t.Fatalf("failed to create test originServer: %v", err) | ||
} | ||
t.Cleanup(func() { | ||
originServer.Close() | ||
}) | ||
go originServer.Serve() | ||
|
||
baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/") | ||
if err != nil { | ||
t.Fatalf("failed to parse test originServer base URL: %v", err) | ||
} | ||
|
||
client, err := httpbp.NewClient( | ||
httpbp.ClientConfig{ | ||
Slug: "downstreamHTTPBPV0", | ||
}, | ||
withBaseURL(baseURL), | ||
) | ||
if err != nil { | ||
t.Fatalf("failed to create test client: %v", err) | ||
} | ||
|
||
req, err := http.NewRequest( | ||
http.MethodGet, | ||
baseURL.JoinPath("say-hello").String(), | ||
nil, | ||
) | ||
if err != nil { | ||
t.Fatalf("creating request: %v", err) | ||
} | ||
req.Header.Set("X-Rddt-Untrusted", "1") | ||
for name, values := range expectedHeaders { | ||
req.Header.Set(name, values[0]) | ||
} | ||
|
||
resp, err := client.Do(req) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
if resp.StatusCode != http.StatusOK { | ||
t.Fatalf("unexpected status code: %d", resp.StatusCode) | ||
} | ||
} | ||
|
||
func withBaseURL(baseURL *url.URL) httpbp.ClientMiddleware { | ||
return func(next http.RoundTripper) http.RoundTripper { | ||
return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { | ||
resolved := req.Clone(req.Context()) | ||
resolved.URL = baseURL.ResolveReference(req.URL) | ||
return next.RoundTrip(resolved) | ||
}) | ||
} | ||
} | ||
|
||
type roundTripperFunc func(req *http.Request) (*http.Response, error) | ||
|
||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { | ||
return f(req) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Package headerbp provides the shared code for propagating baseplate headers using server and client middlewares. | ||
// | ||
// It is meant to be used by middlewares for different rpc frameworks like http and grpc, not used directly by services. | ||
// | ||
// It is only meant to propagate headers that the server receives, the client middlewares will return an error if they | ||
// detect a baseplate header in the request being sent. | ||
package headerbp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we don't have a distinct "internal vs non-internal" client type, I was thinking it would probably be good to have some sort of flag that you have to manually set to have HTTP clients automatically propagate headers?