Skip to content

Commit

Permalink
feat: add ssl support to sync service (#1479) (#1501)
Browse files Browse the repository at this point in the history
Adds SSL support to the flagd sync service

---------

Signed-off-by: Alexandra Oberaigner <[email protected]>
Co-authored-by: Simon Schrottner <[email protected]>
Co-authored-by: Todd Baert <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent 9891df2 commit d50fcc8
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 128 deletions.
2 changes: 1 addition & 1 deletion core/pkg/telemetry/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func buildTransportCredentials(_ context.Context, cfg CollectorConfig) (credenti

tlsConfig := &tls.Config{
RootCAs: capool,
MinVersion: tls.VersionTLS13,
MinVersion: tls.VersionTLS12,
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
certs, err := reloader.GetCertificate()
if err != nil {
Expand Down
33 changes: 32 additions & 1 deletion flagd/pkg/service/flag-sync/sync_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sync

import (
"context"
"crypto/tls"
"fmt"
"net"
"slices"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/open-feature/flagd/core/pkg/store"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type ISyncService interface {
Expand All @@ -28,6 +30,8 @@ type SvcConfigurations struct {
Sources []string
Store *store.Flags
ContextValues map[string]any
CertPath string
KeyPath string
}

type Service struct {
Expand All @@ -39,14 +43,41 @@ type Service struct {
startupTracker syncTracker
}

func loadTLSCredentials(certPath string, keyPath string) (credentials.TransportCredentials, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("failed to load key pair from certificate paths '%s' and '%s': %w", certPath, keyPath, err)
}

// Create the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
MinVersion: tls.VersionTLS12,
}

return credentials.NewTLS(config), nil
}

func NewSyncService(cfg SvcConfigurations) (*Service, error) {
l := cfg.Logger
mux, err := NewMux(cfg.Store, cfg.Sources)
if err != nil {
return nil, fmt.Errorf("error initializing multiplexer: %w", err)
}

server := grpc.NewServer()
var server *grpc.Server
if cfg.CertPath != "" && cfg.KeyPath != "" {
tlsCredentials, err := loadTLSCredentials(cfg.CertPath, cfg.KeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load TLS cert and key: %w", err)
}
server = grpc.NewServer(grpc.Creds(tlsCredentials))
} else {
server = grpc.NewServer()
}

syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{
mux: mux,
log: l,
Expand Down
292 changes: 166 additions & 126 deletions flagd/pkg/service/flag-sync/sync_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sync
import (
"context"
"fmt"
"log"
"testing"
"time"

Expand All @@ -14,134 +15,173 @@ import (
)

func TestSyncServiceEndToEnd(t *testing.T) {
// given
port := 18016
store, sources := getSimpleFlagStore()

service, err := NewSyncService(SvcConfigurations{
Logger: logger.NewLogger(nil, false),
Port: uint16(port),
Sources: sources,
Store: store,
})
if err != nil {
t.Fatal("error creating the service: %w", err)
return
testCases := []struct {
certPath string
keyPath string
clientCertPath string
tls bool
wantErr bool
}{
{"./test-cert/server-cert.pem", "./test-cert/server-key.pem", "./test-cert/ca-cert.pem", true, false},
{"", "", "", false, false},
{"./lol/not/a/cert", "./test-cert/server-key.pem", "./test-cert/ca-cert.pem", true, true},
}

ctx, cancelFunc := context.WithCancel(context.Background())
doneChan := make(chan interface{})

go func() {
// error ignored, tests will fail if start is not successful
_ = service.Start(ctx)
close(doneChan)
}()

// trigger manual emits matching sources, so that service can start
for _, source := range sources {
service.Emit(false, source)
}

// when - derive a client for sync service
con, err := grpc.DialContext(ctx, fmt.Sprintf("localhost:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(fmt.Printf("error creating grpc dial ctx: %v", err))
return
}

serviceClient := syncv1grpc.NewFlagSyncServiceClient(con)

// then

// sync flags request
flags, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{})
if err != nil {
t.Fatal(fmt.Printf("error from sync request: %v", err))
return
}

syncRsp, err := flags.Recv()
if err != nil {
t.Fatal(fmt.Printf("stream error: %v", err))
return
}

if len(syncRsp.GetFlagConfiguration()) == 0 {
t.Error("expected non empty sync response, but got empty")
}

// validate emits
dataReceived := make(chan interface{})
go func() {
_, err := flags.Recv()
if err != nil {
return
for _, tc := range testCases {
var testTitle string
if tc.tls {
testTitle = "Testing Sync Service with TLS Connection"
} else {
testTitle = "Testing Sync Service without TLS Connection"
}

dataReceived <- nil
}()

// Emit as a resync
service.Emit(true, "A")

select {
case <-dataReceived:
t.Fatal("expected no data as this is a resync")
case <-time.After(1 * time.Second):
break
}

// Emit as a resync
service.Emit(false, "A")

select {
case <-dataReceived:
break
case <-time.After(1 * time.Second):
t.Fatal("expected data but timeout waiting for sync")
}

// fetch all flags
allRsp, err := serviceClient.FetchAllFlags(ctx, &v1.FetchAllFlagsRequest{})
if err != nil {
t.Fatal(fmt.Printf("fetch all error: %v", err))
return
}

if allRsp.GetFlagConfiguration() != syncRsp.GetFlagConfiguration() {
t.Errorf("expected both sync and fetch all responses to be same, but got %s from sync & %s from fetch all",
syncRsp.GetFlagConfiguration(), allRsp.GetFlagConfiguration())
}

// metadata request
metadataRsp, err := serviceClient.GetMetadata(ctx, &v1.GetMetadataRequest{})
if err != nil {
t.Fatal(fmt.Printf("metadata error: %v", err))
return
}

asMap := metadataRsp.GetMetadata().AsMap()

// expect `sources` to be present
if asMap["sources"] == nil {
t.Fatal("expected sources entry in the metadata, but got nil")
}

if asMap["sources"] != "A,B,C" {
t.Fatal("incorrect sources entry in metadata")
}

// validate shutdown from context cancellation
go func() {
cancelFunc()
}()

select {
case <-doneChan:
// exit successful
return
case <-time.After(2 * time.Second):
t.Fatal("service did not exist within sufficient timeframe")
t.Run(testTitle, func(t *testing.T) {
// given
port := 18016
store, sources := getSimpleFlagStore()

service, err := NewSyncService(SvcConfigurations{
Logger: logger.NewLogger(nil, false),
Port: uint16(port),
Sources: sources,
Store: store,
CertPath: tc.certPath,
KeyPath: tc.keyPath,
})

if tc.wantErr {
if err == nil {
t.Fatal("expected error creating the service!")
}
return
} else if err != nil {
t.Fatal("unexpected error creating the service: %w", err)
return
}

ctx, cancelFunc := context.WithCancel(context.Background())
doneChan := make(chan interface{})

go func() {
// error ignored, tests will fail if start is not successful
_ = service.Start(ctx)
close(doneChan)
}()

// trigger manual emits matching sources, so that service can start
for _, source := range sources {
service.Emit(false, source)
}

// when - derive a client for sync service
var con *grpc.ClientConn
if tc.tls {
tlsCredentials, e := loadTLSClientCredentials(tc.clientCertPath)
if e != nil {
log.Fatal("cannot load TLS credentials: ", e)
}
con, err = grpc.Dial(fmt.Sprintf("0.0.0.0:%d", port), grpc.WithTransportCredentials(tlsCredentials))
} else {
con, err = grpc.DialContext(ctx, fmt.Sprintf("localhost:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
}
if err != nil {
t.Fatal(fmt.Printf("error creating grpc dial ctx: %v", err))
return
}

serviceClient := syncv1grpc.NewFlagSyncServiceClient(con)

// then

// sync flags request
flags, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{})
if err != nil {
t.Fatal(fmt.Printf("error from sync request: %v", err))
return
}

syncRsp, err := flags.Recv()
if err != nil {
t.Fatal(fmt.Printf("stream error: %v", err))
return
}

if len(syncRsp.GetFlagConfiguration()) == 0 {
t.Error("expected non empty sync response, but got empty")
}

// validate emits
dataReceived := make(chan interface{})
go func() {
_, err := flags.Recv()
if err != nil {
return
}

dataReceived <- nil
}()

// Emit as a resync
service.Emit(true, "A")

select {
case <-dataReceived:
t.Fatal("expected no data as this is a resync")
case <-time.After(1 * time.Second):
break
}

// Emit as a resync
service.Emit(false, "A")

select {
case <-dataReceived:
break
case <-time.After(1 * time.Second):
t.Fatal("expected data but timeout waiting for sync")
}

// fetch all flags
allRsp, err := serviceClient.FetchAllFlags(ctx, &v1.FetchAllFlagsRequest{})
if err != nil {
t.Fatal(fmt.Printf("fetch all error: %v", err))
return
}

if allRsp.GetFlagConfiguration() != syncRsp.GetFlagConfiguration() {
t.Errorf("expected both sync and fetch all responses to be same, but got %s from sync & %s from fetch all",
syncRsp.GetFlagConfiguration(), allRsp.GetFlagConfiguration())
}

// metadata request
metadataRsp, err := serviceClient.GetMetadata(ctx, &v1.GetMetadataRequest{})
if err != nil {
t.Fatal(fmt.Printf("metadata error: %v", err))
return
}

asMap := metadataRsp.GetMetadata().AsMap()

// expect `sources` to be present
if asMap["sources"] == nil {
t.Fatal("expected sources entry in the metadata, but got nil")
}

if asMap["sources"] != "A,B,C" {
t.Fatal("incorrect sources entry in metadata")
}

// validate shutdown from context cancellation
go func() {
cancelFunc()
}()

select {
case <-doneChan:
// exit successful
return
case <-time.After(2 * time.Second):
t.Fatal("service did not exist within sufficient timeframe")
}
})
}
}
Loading

0 comments on commit d50fcc8

Please sign in to comment.