Skip to content

Commit

Permalink
add headerbp middleware to thriftbp
Browse files Browse the repository at this point in the history
  • Loading branch information
pacejackson committed Jan 10, 2025
1 parent c52cdf1 commit 496bc9b
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 5 deletions.
44 changes: 44 additions & 0 deletions thriftbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/apache/thrift/lib/go/thrift"
"github.com/avast/retry-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/reddit/baseplate.go/internal/headerbp"

"github.com/reddit/baseplate.go/breakerbp"
"github.com/reddit/baseplate.go/ecinterface"
Expand Down Expand Up @@ -162,6 +163,7 @@ func BaseplateDefaultClientMiddlewares(args DefaultClientMiddlewareArgs) []thrif
BaseplateErrorWrapper,
thrift.ExtractIDLExceptionClientMiddleware,
SetDeadlineBudget,
ClientHeaderBPMiddleware(args.ServiceSlug, args.ClientName),
)
return middlewares
}
Expand Down Expand Up @@ -396,3 +398,45 @@ func getClientError(result thrift.TStruct, err error) error {
}
return thrift.ExtractExceptionFromResult(result)
}

// ClientHeaderBPMiddleware is a middleware that forwards baseplate headers from the context to the outgoing request.
//
// It will also verify that you are not adding any headers with the baseplate header prefix, if you try to send
// a header with the baseplate header prefix it will return an error.
func ClientHeaderBPMiddleware(service, client string) thrift.ClientMiddleware {
return func(next thrift.TClient) thrift.TClient {
return thrift.WrappedTClient{
Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) {
outgoing := thrift.GetWriteHeaderList(ctx)
for _, k := range outgoing {
if err := headerbp.CheckClientHeader(k,
headerbp.WithThriftClient(service, client, method),
); err != nil {
return thrift.ResponseMeta{}, err
}
}

var toAdd map[string]string
headerbp.SetOutgoingHeaders(
ctx,
headerbp.WithThriftClient(service, client, method),
headerbp.WithHeaderSetter(func(k, v string) {
if toAdd == nil {
toAdd = make(map[string]string)
}
toAdd[k] = v
outgoing = append(outgoing, k)
}),
)

if len(toAdd) > 0 {
for k, v := range toAdd {
ctx = thrift.SetHeader(ctx, k, v)
}
ctx = thrift.SetWriteHeaderList(ctx, outgoing)
}
return next.Call(ctx, method, args, result)
},
}
}
}
1 change: 0 additions & 1 deletion thriftbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ func TestRetry(t *testing.T) {
defer cancel()

store := newSecretsStore(t)
defer store.Close()

c := &counter{}
handler := BaseplateService{}
Expand Down
4 changes: 0 additions & 4 deletions thriftbp/client_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ func TestThriftHostnameHeader(t *testing.T) {
defer cancel()

store := newSecretsStore(t)
defer store.Close()

handler := thriftHostnameHandler{}
server, err := thrifttest.NewBaseplateServer(thrifttest.ServerConfig{
Expand Down Expand Up @@ -313,9 +312,6 @@ func TestUDS(t *testing.T) {
t.Cleanup(cancel)

store := newSecretsStore(t)
t.Cleanup(func() {
store.Close()
})

handler := thriftHostnameHandler{}
server, err := thriftbp.NewServer(thriftbp.ServerConfig{
Expand Down
3 changes: 3 additions & 0 deletions thriftbp/fixtures_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@ func newSecretsStore(t testing.TB) *secrets.Store {
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
store.Close()
})
return store
}
50 changes: 50 additions & 0 deletions thriftbp/server_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"errors"
"fmt"
"log/slog"
"slices"
"strconv"
"sync"
"time"

"github.com/apache/thrift/lib/go/thrift"
"github.com/prometheus/client_golang/prometheus"
"github.com/reddit/baseplate.go/internal/headerbp"

"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/errorsbp"
Expand Down Expand Up @@ -79,6 +81,7 @@ func BaseplateDefaultProcessorMiddlewares(args DefaultProcessorMiddlewaresArgs)
InjectEdgeContext(args.EdgeContextImpl),
ReportPayloadSizeMetrics(0),
PrometheusServerMiddleware,
ServerHeaderBPMiddleware(),
}
}

Expand Down Expand Up @@ -450,3 +453,50 @@ func PrometheusServerMiddleware(method string, next thrift.TProcessorFunction) t
}
return thrift.WrappedTProcessorFunction{Wrapped: process}
}

// ServerHeaderBPMiddleware is a middleware that extracts baseplate headers from the incoming request and adds them to the context.
func ServerHeaderBPMiddleware() thrift.ProcessorMiddleware {
return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
return thrift.WrappedTProcessorFunction{
Wrapped: func(ctx context.Context, seqID int32, in, out thrift.TProtocol) (bool, thrift.TException) {
readHeaderList := thrift.GetReadHeaderList(ctx)
if len(readHeaderList) == 0 {
return next.Process(ctx, seqID, in, out)
}

// check both the lower case and the http canonical case since thrift.GetHeader is case sensitive
var untrusted bool
if _, ok := thrift.GetHeader(ctx, headerbp.IsUntrustedRequestHeaderLower); ok {
untrusted = true
} else if _, ok := thrift.GetHeader(ctx, headerbp.IsUntrustedRequestHeaderCanonicalHTTP); ok {
untrusted = true
}
if untrusted {
var cleared bool
for i := 0; i < len(readHeaderList); i++ {
k := readHeaderList[i]
if headerbp.IsBaseplateHeader(k) {
cleared = true
readHeaderList = slices.Delete(readHeaderList, i, i+1)
thrift.SetHeader(ctx, k, "")
}
}
if cleared {
ctx = thrift.SetReadHeaderList(ctx, readHeaderList)
}
return next.Process(ctx, seqID, in, out)
}

headers := headerbp.NewIncomingHeaders(
headerbp.WithThriftService("", name),
)
for _, k := range readHeaderList {
v, _ := thrift.GetHeader(ctx, k)
headers.RecordHeader(k, v)
}
ctx = headers.SetOnContext(ctx)
return next.Process(ctx, seqID, in, out)
},
}
}
}
140 changes: 140 additions & 0 deletions thriftbp/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package thriftbp_test

import (
"context"
"errors"
"fmt"
"testing"

"github.com/apache/thrift/lib/go/thrift"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/reddit/baseplate.go/ecinterface"
baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate"
"github.com/reddit/baseplate.go/internal/headerbp"
"github.com/reddit/baseplate.go/thriftbp"
"github.com/reddit/baseplate.go/thriftbp/thrifttest"
)

type headerPropagationVerificationServic struct {
want map[string]string
wantUnset []string

client func() baseplatethrift.BaseplateServiceV2
}

func (s *headerPropagationVerificationServic) IsHealthy(ctx context.Context, _ *baseplatethrift.IsHealthyRequest) (bool, error) {
var errs []error
got := make(map[string]string, len(s.want))
for k := range s.want {
got[k], _ = thrift.GetHeader(ctx, k)
}
if diff := cmp.Diff(s.want, got, cmpopts.EquateEmpty()); diff != "" {
errs = append(errs, fmt.Errorf("header mismatch (-want +got): %s", diff))
}

var unwantedHeaders []string
for _, k := range s.wantUnset {
if _, ok := thrift.GetHeader(ctx, k); ok {
unwantedHeaders = append(unwantedHeaders, k)
}
}
if len(unwantedHeaders) > 0 {
errs = append(errs, fmt.Errorf("unwanted headers: %v", unwantedHeaders))
}

if err := errors.Join(errs...); err != nil {
return false, err
}

outgoingCtx := setHeader(ctx, "x-bp-test", "bar")
if _, err := s.client().IsHealthy(outgoingCtx, &baseplatethrift.IsHealthyRequest{}); !errors.Is(err, headerbp.ErrNewInternalHeaderNotAllowed) {
return false, fmt.Errorf("error mismatch, want %v, got %v", headerbp.ErrNewInternalHeaderNotAllowed, err)
}
return true, nil
}

type echoService struct{}

func (s *echoService) IsHealthy(ctx context.Context, req *baseplatethrift.IsHealthyRequest) (bool, error) {
return true, nil
}

func TestHeaderPropagation(t *testing.T) {
store := newSecretsStore(t)

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

ecImpl := ecinterface.Mock()

downstreamProcessor := baseplatethrift.NewBaseplateServiceV2Processor(&echoService{})
downstreamServer, err := thrifttest.NewBaseplateServer(thrifttest.ServerConfig{
Processor: downstreamProcessor,
SecretStore: store,
EdgeContextImpl: ecImpl,
})
if err != nil {
t.Fatal(err)
}
downstreamServer.Start(ctx)

originProcessor := baseplatethrift.NewBaseplateServiceV2Processor(&headerPropagationVerificationServic{
want: map[string]string{
"x-bp-test": "foo",
},
client: func() baseplatethrift.BaseplateServiceV2 {
return baseplatethrift.NewBaseplateServiceV2Client(downstreamServer.ClientPool.TClient())
},
})
server, err := thrifttest.NewBaseplateServer(thrifttest.ServerConfig{
Processor: originProcessor,
SecretStore: store,
})
if err != nil {
t.Fatal(err)
}
server.Start(ctx)

clientCfg := thriftbp.ClientPoolConfig{
ServiceSlug: thrifttest.DefaultServiceSlug,
Addr: server.Baseplate().GetConfig().Addr,
InitialConnections: thrifttest.InitialClientConnections,
MaxConnections: thrifttest.DefaultClientMaxConnections,
ConnectTimeout: thrifttest.DefaultClientConnectTimeout,
SocketTimeout: thrifttest.DefaultClientSocketTimeout,
EdgeContextImpl: ecImpl,
ClientName: "header-check",
}
// we have to use a custom pool to avoid using the default middleware which will block baseplate headers
pool, err := thriftbp.NewCustomClientPoolWithContext(
ctx,
clientCfg,
thriftbp.SingleAddressGenerator(clientCfg.Addr),
thrift.NewTHeaderProtocolFactoryConf(clientCfg.ToTConfiguration()),
)
if err != nil {
server.Close()
t.Fatalf("error creating client pool: %v", err)
}
client := baseplatethrift.NewBaseplateServiceV2Client(pool.TClient())
ctx = setHeader(ctx, "x-bp-test", "foo")
got, err := client.IsHealthy(ctx, &baseplatethrift.IsHealthyRequest{
Probe: baseplatethrift.IsHealthyProbePtr(baseplatethrift.IsHealthyProbe_READINESS),
})

if err != nil {
t.Errorf("expected no error, got %v", err)
}

const want = true
if got != want {
t.Errorf("success mismatch, want %v, got %v", want, got)
}
}

func setHeader(ctx context.Context, key, value string) context.Context {
ctx = thrift.SetHeader(ctx, key, value)
return thrift.SetWriteHeaderList(ctx, append(thrift.GetWriteHeaderList(ctx), key))
}

0 comments on commit 496bc9b

Please sign in to comment.