Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add headerbp integration #669

Merged
merged 10 commits into from
Jan 21, 2025
29 changes: 29 additions & 0 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/reddit/baseplate.go/breakerbp"
"github.com/reddit/baseplate.go/internal/headerbp"
//lint:ignore SA1019 This library is internal only, not actually deprecated
"github.com/reddit/baseplate.go/internalv2compat"
"github.com/reddit/baseplate.go/retrybp"
Expand Down Expand Up @@ -88,6 +89,11 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien
if config.CircuitBreaker != nil {
defaults = append([]ClientMiddleware{CircuitBreaker(*config.CircuitBreaker)}, defaults...)
}

// only add the middleware to forward baseplate headers if the client is for internal calls
if config.InternalOnly {
defaults = append(defaults, ClientHeaderBPMiddleware(config.Slug))
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't have a distinct "internal vs non-internal" client type, I was thinking it would probably be good to have some sort of flag that you have to manually set to have HTTP clients automatically propagate headers?

middleware = append(middleware, defaults...)

return &http.Client{
Expand Down Expand Up @@ -349,3 +355,26 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware {
})
}
}

// ClientHeaderBPMiddleware is a middleware that forwards baseplate headers from the context to the outgoing request.
//
// If it detects any new baseplate headers set on the request, it will reject the request and return an error.
func ClientHeaderBPMiddleware(client string) ClientMiddleware {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func ClientHeaderBPMiddleware(client string) ClientMiddleware {
func ClientHeaderBaseplateMiddleware(client string) ClientMiddleware {

🔕 maybe best spelled out since we have Baseplate more often in function or method names than BP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

headerbp is the name of the actual library so I figured it was better to keep it consistent with that?

return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
for k := range req.Header {
if err := headerbp.CheckClientHeader(k,
headerbp.WithHTTPClient("", client, ""),
); err != nil {
return nil, err
}
Comment on lines +365 to +370
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for k := range req.Header {
if err := headerbp.CheckClientHeader(k,
headerbp.WithHTTPClient("", client, ""),
); err != nil {
return nil, err
}
for key := range req.Header {
if err := headerbp.CheckClientHeader(
key,
headerbp.WithHTTPClient("", client, ""),
); err != nil {
return nil, err
}

🔕 I find this slightly more readable 🙂

}
headerbp.SetOutgoingHeaders(
req.Context(),
headerbp.WithHTTPClient("", client, ""),
headerbp.WithHeaderSetter(req.Header.Set),
)
return next.RoundTrip(req)
})
}
}
1 change: 1 addition & 0 deletions httpbp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ClientConfig struct {
MaxConnections int `yaml:"maxConnections"`
CircuitBreaker *breakerbp.Config `yaml:"circuitBreaker"`
RetryOptions []retry.Option
InternalOnly bool
}

// Validate checks ClientConfig for any missing or erroneous values.
Expand Down
46 changes: 46 additions & 0 deletions httpbp/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/errorsbp"
"github.com/reddit/baseplate.go/internal/headerbp"
//lint:ignore SA1019 This library is internal only, not actually deprecated
"github.com/reddit/baseplate.go/internalv2compat"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -517,3 +518,48 @@ func (rr *responseRecorder) WriteHeader(code int) {
rr.ResponseWriter.WriteHeader(code)
rr.responseCode = code
}

type untrustedHeadersKey struct{}

func setUntrustedHeaders(ctx context.Context, h map[string]string) context.Context {
return context.WithValue(ctx, untrustedHeadersKey{}, h)
}

func GetUntrustedBaseplateHeaders(ctx context.Context) (map[string]string, bool) {
h, ok := ctx.Value(untrustedHeadersKey{}).(map[string]string)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a big fan of always checking if the assertion was successful, though I got the impression that if we tightly control it, it's okay to not check to make the API more user-friendly 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that really does anything here though

return h, ok
}

// ServerHeaderBPMiddleware is a middleware that extracts baseplate headers from the incoming request and adds them to the context.
//
// If the request is flagged as untrusted, it will remove the baseplate headers from the request and add them to the
// context. These can be retrieved using GetUntrustedBaseplateHeaders.
func ServerHeaderBPMiddleware(service string) Middleware {
return func(name string, next HandlerFunc) HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) {
if r.Header.Get(headerbp.IsUntrustedRequestHeaderCanonicalHTTP) != "" {
untrusted := make(map[string]string)
for k, v := range r.Header {
if headerbp.IsBaseplateHeader(k) {
if len(v) > 0 {
untrusted[strings.ToLower(k)] = v[0]
}
r.Header.Del(k)
}
}
ctx = setUntrustedHeaders(ctx, untrusted)
return next(ctx, w, r)
}
headers := headerbp.NewIncomingHeaders(
headerbp.WithHTTPService(service, name),
)
for k, v := range r.Header {
if len(v) > 0 {
headers.RecordHeader(k, v[0])
}
}
ctx = headers.SetOnContext(ctx)
return next(ctx, w, r.WithContext(ctx))
}
}
}
Loading
Loading