Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redis mTLS support #3429

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions bindings/redis/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,19 @@ metadata:
type: bool
required: false
description: |
If the Redis instance supports TLS with public certificates, can be
configured to be enabled or disabled.
If the Redis instance supports TLS; can be configured to be enabled or disabled.
example: "true"
default: "false"
- name: clientCert
required: false
description: Client certificate for Redis host. No Default. Can be secretKeyRef to use a secret reference
example: ""
type: string
- name: clientKey
required: false
description: Client key for Redis host. No Default. Can be secretKeyRef to use a secret reference
example: ""
type: string
- name: redisMaxRetries
type: number
required: false
Expand Down
30 changes: 18 additions & 12 deletions common/component/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
switch componentType {
case metadata.PubSubType:
if val, ok := properties[processingTimeoutKey]; ok && val != "" {
if processingTimeoutMs, err := strconv.ParseUint(val, 10, 64); err == nil {
if processingTimeoutMs, parseErr := strconv.ParseUint(val, 10, 64); parseErr == nil {
// because of legacy reasons, we need to interpret a number as milliseconds
// the library would default to seconds otherwise
settings.ProcessingTimeout = time.Duration(processingTimeoutMs) * time.Millisecond
Expand All @@ -149,7 +149,7 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
}

if val, ok := properties[redeliverIntervalKey]; ok && val != "" {
if redeliverIntervalMs, err := strconv.ParseUint(val, 10, 64); err == nil {
if redeliverIntervalMs, parseErr := strconv.ParseUint(val, 10, 64); parseErr == nil {
// because of legacy reasons, we need to interpret a number as milliseconds
// the library would default to seconds otherwise
settings.RedeliverInterval = time.Duration(redeliverIntervalMs) * time.Millisecond
Expand All @@ -159,11 +159,16 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
}

var c RedisClient
newClientFunc := newV8Client
if settings.Failover {
c = newV8FailoverClient(settings)
} else {
c = newV8Client(settings)
newClientFunc = newV8FailoverClient
}

c, err = newClientFunc(settings)
if err != nil {
return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
}

version, versionErr := GetServerVersion(c)
c.Close() // close the client to avoid leaking connections

Expand All @@ -175,17 +180,18 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
// if the server version is >= 7, we will use the v9 client
useNewClient = true
}

if useNewClient {
newClientFunc = newV9Client
if settings.Failover {
return newV9FailoverClient(settings), settings, nil
newClientFunc = newV9FailoverClient
}
return newV9Client(settings), settings, nil
} else {
if settings.Failover {
return newV8FailoverClient(settings), settings, nil
}
return newV8Client(settings), settings, nil
}
c, err = newClientFunc(settings)
if err != nil {
return nil, nil, fmt.Errorf("redis client configuration error: %w", err)
}
return c, settings, nil
}

func ClientHasJSONSupport(c RedisClient) bool {
Expand Down
6 changes: 6 additions & 0 deletions common/component/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ const (
idleCheckFrequency = "idleCheckFrequency"
maxConnAge = "maxConnAge"
enableTLS = "enableTLS"
clientCert = "clientCert"
clientKey = "clientKey"
failover = "failover"
sentinelMasterName = "sentinelMasterName"
)
Expand All @@ -51,6 +53,8 @@ func getFakeProperties() map[string]string {
username: "fakeUsername",
redisType: "node",
enableTLS: "true",
clientCert: "fakeCert",
clientKey: "fakeKey",
dialTimeout: "5s",
readTimeout: "5s",
writeTimeout: "50000",
Expand Down Expand Up @@ -84,6 +88,8 @@ func TestParseRedisMetadata(t *testing.T) {
assert.Equal(t, fakeProperties[username], m.Username)
assert.Equal(t, fakeProperties[redisType], m.RedisType)
assert.True(t, m.EnableTLS)
assert.Equal(t, fakeProperties[clientCert], m.ClientCert)
assert.Equal(t, fakeProperties[clientKey], m.ClientKey)
assert.Equal(t, 5*time.Second, time.Duration(m.DialTimeout))
assert.Equal(t, 5*time.Second, time.Duration(m.ReadTimeout))
assert.Equal(t, 50000*time.Millisecond, time.Duration(m.WriteTimeout))
Expand Down
17 changes: 17 additions & 0 deletions common/component/redis/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
package redis

import (
"crypto/tls"
"fmt"
"strconv"
"time"
Expand Down Expand Up @@ -78,6 +79,10 @@ type Settings struct {
// A flag to enables TLS by setting InsecureSkipVerify to true
EnableTLS bool `mapstructure:"enableTLS"`

// Client certificate and key
ClientCert string `mapstructure:"clientCert"`
ClientKey string `mapstructure:"clientKey"`

// == state only properties ==
TTLInSeconds *int `mapstructure:"ttlInSeconds" mdonly:"state"`
QueryIndexes string `mapstructure:"queryIndexes" mdonly:"state"`
Expand Down Expand Up @@ -106,6 +111,18 @@ func (s *Settings) Decode(in interface{}) error {
return nil
}

func (s *Settings) SetCertificate(fn func(cert *tls.Certificate)) error {
if s.ClientCert == "" || s.ClientKey == "" {
return nil
}
cert, err := tls.X509KeyPair([]byte(s.ClientCert), []byte(s.ClientKey))
if err != nil {
return err
}
fn(&cert)
return nil
}

type Duration time.Duration

func (r *Duration) DecodeString(value string) error {
Expand Down
42 changes: 42 additions & 0 deletions common/component/redis/settings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package redis

import (
"crypto/tls"
"testing"

"github.com/stretchr/testify/require"
)

func TestSettings(t *testing.T) {
t.Run("test set certificate, missing data", func(t *testing.T) {
var c *tls.Certificate
settings := &Settings{}
err := settings.SetCertificate(func(cert *tls.Certificate) {
c = cert
})
require.NoError(t, err)
require.Nil(t, c)
})

t.Run("test set certificate, invalid data", func(t *testing.T) {
settings := &Settings{
ClientCert: "foo",
ClientKey: "bar",
}
err := settings.SetCertificate(nil)
require.Error(t, err)
})

t.Run("test set certificate, valid data", func(t *testing.T) {
var c *tls.Certificate
settings := &Settings{
ClientCert: "-----BEGIN CERTIFICATE-----\nMIIC+jCCAeKgAwIBAgIUcVW+K5LM+rLj80F0XWG0YXh6Hq4wDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJTXlSZWRpc0NBMB4XDTI0MDUyNzExMzQyMloXDTI1MDUy\nNzExMzQyMlowFjEUMBIGA1UEAwwLUmVkaXNDbGllbnQwggEiMA0GCSqGSIb3DQEB\nAQUAA4IBDwAwggEKAoIBAQCpSQKejofOA42jBSsfDVE5FdSxGEU+ktpqcp2CBZ8Z\nD9YLW4H6JTMU2JzPQLrwd5oF+FBdVDYkpunFs8lPGlvR7KMzXv130PSJ4ieSAEwJ\n7ocxKvqYpYmyFsPUHHOVJEYxlUK0nd8KvBw7OKbdk5tL/gEDoHKHJOZpiDmcMFqw\nlMfNrGGlsgZjcWvnZfEa3Q4D7hD3iNYJbLT9ETZZF36V5I8sXrexnlzN4EXyCZuF\nV9M/+5V+JwYamvpHrTiCR9oDVrHJytjSyvyysW7PKmLjs9C12opo6LBHoKEidCKV\nNyicgBfkvvHnlRDaANmELJpX5vNuW9lsEG+Rxiyf47rtAgMBAAGjQjBAMB0GA1Ud\nDgQWBBQal6ypaK/1V0SGwfLefKrIUkl2jTAfBgNVHSMEGDAWgBQ4NNFfx71nrJ19\nF9/rtg8TAqZbjjANBgkqhkiG9w0BAQsFAAOCAQEAEqN0Ja31Tkv6oHp35ksFFM2X\npej6BljJH61X4jIGJ7qFacG8OxkpA3mspF8NZx4SG1NeZVC5eYMixxqDDDxz5cli\nLVaxP9T3TiIU/lYpqnGBTaKJ6q4ngwTSTdZ9Xp1cVhYx80F9SK3l3TeLAUCVl/HK\neepV2BOfr/B/sK1gVTOcmxRHh8piPEcY49WCDA6+zd6UYXiaYtegMAomaoPA76Kf\nNRcNkWrm5sbyRKijw2AmjRxFGH2lTdCGg1CNXwQhCPrsGQKuDh5JPN/wqxF2GjSK\ncpuArGhwuCLf4kLoEB+0VgFwaKqrUwsWyD9P+vBh5kNf76B+NkBtp+19AzSasw==\n-----END CERTIFICATE-----",
ClientKey: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCpSQKejofOA42j\nBSsfDVE5FdSxGEU+ktpqcp2CBZ8ZD9YLW4H6JTMU2JzPQLrwd5oF+FBdVDYkpunF\ns8lPGlvR7KMzXv130PSJ4ieSAEwJ7ocxKvqYpYmyFsPUHHOVJEYxlUK0nd8KvBw7\nOKbdk5tL/gEDoHKHJOZpiDmcMFqwlMfNrGGlsgZjcWvnZfEa3Q4D7hD3iNYJbLT9\nETZZF36V5I8sXrexnlzN4EXyCZuFV9M/+5V+JwYamvpHrTiCR9oDVrHJytjSyvyy\nsW7PKmLjs9C12opo6LBHoKEidCKVNyicgBfkvvHnlRDaANmELJpX5vNuW9lsEG+R\nxiyf47rtAgMBAAECggEAChFC/B362pgggLzqcxbOKUhwlTWdzJpcLe9yCYz/CLUF\n5DgFc1RqBMfbD4JIe8uJF+jMErjS3Xwls/G8u50UL9hUXlY8WbdOC7Ms6kRlQUPz\nu0tUiuZxWWt8Ku2kPA7js8guJuKqpI9KWIVGey/vkOXito4AsaPSph0JXA4OHqkp\nqdT6Q+ZIiSoCopaubmcToYr28g/g2K37IPDRaStCMEm2KFEL+D1VzCfucKgXrZdK\np9M9e5WoUR4J9p4yxNAwafXvdRxZq40SRJTB+3x172Epg2vhVJLAs0b8ArI9z57+\nOodWfCiebfQQ4EKIZLY0Djr5kYRv1yrwKjViktv4WwKBgQDcTB2KILsDSw6KP+c5\nzNm6WLP6SQ8mlXj5r0VWDDfKqo7w8UTzU+PMB1rDDmjLV0c7PKFadB6LsMlXUoKr\n+/iRoea9UNOwdy9JQ2DesjCFWbu2KigOO06X2LxckHBUNe0VbvCVXZ9hz3d8PkpA\n/Ib8m55Q2uY/iM3jtufF9XT5qwKBgQDEuHW2MXp7Kt9+3EAmy9yBggEkX+4MjhKM\ntPSLNWnxOz6ZiF0KZfwdZ1TmQxBJ0CPSoKuf6DrdFY6Rm2ZLkDKhRdw9LkN2hvne\nnIkteVSLzok0a1ryHYENXRRUOVN6s3E7X+b/uFiR7bemx/Tzrpc2yBJk3gJm+ffP\nVbFjug/1xwKBgQDats8VFg3V5SzYYT2GCzWXZv243dQm8HudGUBzf8nccp1b5Y4Z\nLw6YwCyCP8oXJ93WmAlyLpstASXEhmypp45PuDfHeXnSV2IhEL4aGztFCaPt5cjC\n6GrNIydPly+Oy8NIZk6BXOQiTcJJHebGwnCaVz5E9C9ooMAY9r0BswKh5QKBgFGt\nsRo7wvoe2/slYfF51Y1kOCstNX67ApKvk5W1UM6bZauDxfXKUHq467RLhhjPtf//\nPCNB3ibri22Dk16ueYcipYY1jkdJVbgLUJ2z8dm2oJtGM9WxUGMHEajCwJmCpfIc\nKKJmnUfB5u31ugvvotNZEOIWl/K/uRe6IdQhbf0DAoGAdw7GEmzs6Rdq4LTs9cfQ\nQaXtx8vAXAU5X0StEtJWIlDnRbEvV8NaiPEwXfcuvx+Y1nvPSXuaq+YhxgrlBviL\nuKt2V2ONIrUsfRuPePSHjUipMR92A/8hPhOxw2SCYemtzniYJbeulUhIAWbQplhT\nIeadrRMdaWmKA1OGeIZRRw4=\n-----END PRIVATE KEY-----",
}
err := settings.SetCertificate(func(cert *tls.Certificate) {
c = cert
})
require.NoError(t, err)
require.NotNil(t, c)
})
}
34 changes: 26 additions & 8 deletions common/component/redis/v8client.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,9 @@ func (c v8Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV8FailoverClient(s *Settings) RedisClient {
func newV8FailoverClient(s *Settings) (RedisClient, error) {
if s == nil {
return nil
return nil, nil
}
opts := &v8.FailoverOptions{
DB: s.DB,
Expand All @@ -345,6 +345,12 @@ func newV8FailoverClient(s *Settings) RedisClient {
opts.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
opts.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

if s.RedisType == ClusterType {
Expand All @@ -355,20 +361,20 @@ func newV8FailoverClient(s *Settings) RedisClient {
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

return v8Client{
client: v8.NewFailoverClient(opts),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

func newV8Client(s *Settings) RedisClient {
func newV8Client(s *Settings) (RedisClient, error) {
if s == nil {
return nil
return nil, nil
}
if s.RedisType == ClusterType {
options := &v8.ClusterOptions{
Expand All @@ -393,14 +399,20 @@ func newV8Client(s *Settings) RedisClient {
options.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
options.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

return v8Client{
client: v8.NewClusterClient(options),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

options := &v8.Options{
Expand All @@ -427,14 +439,20 @@ func newV8Client(s *Settings) RedisClient {
options.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
options.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

return v8Client{
client: v8.NewClient(options),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

func ClientFromV8Client(client v8.UniversalClient) RedisClient {
Expand Down
38 changes: 29 additions & 9 deletions common/component/redis/v9client.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ func (c v9Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV9FailoverClient(s *Settings) RedisClient {
func newV9FailoverClient(s *Settings) (RedisClient, error) {
if s == nil {
return nil
return nil, nil
}
opts := &v9.FailoverOptions{
DB: s.DB,
Expand All @@ -346,6 +346,12 @@ func newV9FailoverClient(s *Settings) RedisClient {
opts.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
opts.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

if s.RedisType == ClusterType {
Expand All @@ -356,21 +362,22 @@ func newV9FailoverClient(s *Settings) RedisClient {
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

return v9Client{
client: v9.NewFailoverClient(opts),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

func newV9Client(s *Settings) RedisClient {
func newV9Client(s *Settings) (RedisClient, error) {
if s == nil {
return nil
return nil, nil
}

if s.RedisType == ClusterType {
options := &v9.ClusterOptions{
Addrs: strings.Split(s.Host, ","),
Expand All @@ -391,17 +398,24 @@ func newV9Client(s *Settings) RedisClient {
}
/* #nosec */
if s.EnableTLS {
/* #nosec */
options.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
options.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

return v9Client{
client: v9.NewClusterClient(options),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}

options := &v9.Options{
Expand All @@ -423,17 +437,23 @@ func newV9Client(s *Settings) RedisClient {
ContextTimeoutEnabled: true,
}

/* #nosec */
if s.EnableTLS {
/* #nosec */
options.TLSConfig = &tls.Config{
InsecureSkipVerify: s.EnableTLS,
}
err := s.SetCertificate(func(cert *tls.Certificate) {
options.TLSConfig.Certificates = []tls.Certificate{*cert}
})
if err != nil {
return nil, err
}
}

return v9Client{
client: v9.NewClient(options),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}, nil
}
Loading
Loading