Skip to content

Commit

Permalink
Pass a client credential getter through auth.TLSServer (#43874)
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini authored Jul 5, 2024
1 parent e260760 commit 350dbe3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
14 changes: 8 additions & 6 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,19 +845,21 @@ func NewTestTLSServer(cfg TestTLSServerConfig) (*TestTLSServer, error) {
return nil, trace.Wrap(err)
}
tlsConfig.Time = cfg.AuthServer.Clock().Now
tlsCert := tlsConfig.Certificates[0]

srv.Listener, err = net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, trace.Wrap(err)
}

srv.TLSServer, err = NewTLSServer(context.Background(), TLSServerConfig{
Listener: srv.Listener,
AccessPoint: srv.AuthServer.AuthServer.Cache,
TLS: tlsConfig,
APIConfig: *srv.APIConfig,
LimiterConfig: *srv.Limiter,
AcceptedUsage: cfg.AcceptedUsage,
Listener: srv.Listener,
AccessPoint: srv.AuthServer.AuthServer.Cache,
TLS: tlsConfig,
GetClientCertificate: func() (*tls.Certificate, error) { return &tlsCert, nil },
APIConfig: *srv.APIConfig,
LimiterConfig: *srv.Limiter,
AcceptedUsage: cfg.AcceptedUsage,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
9 changes: 7 additions & 2 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ type AccessCacheWithEvents interface {
type TLSServerConfig struct {
// Listener is a listener to bind to
Listener net.Listener
// TLS is a base TLS configuration
// TLS is the server TLS configuration.
TLS *tls.Config
// GetClientCertificate returns auth client credentials.
GetClientCertificate func() (*tls.Certificate, error)
// API is API server configuration
APIConfig
// LimiterConfig is limiter config
Expand Down Expand Up @@ -119,6 +121,9 @@ func (c *TLSServerConfig) CheckAndSetDefaults() error {
if len(c.TLS.Certificates) == 0 {
return trace.BadParameter("missing parameter TLS.Certificates")
}
if c.GetClientCertificate == nil {
return trace.BadParameter("missing parameter GetClientCertificate")
}
if c.AccessPoint == nil {
return trace.BadParameter("missing parameter AccessPoint")
}
Expand Down Expand Up @@ -261,7 +266,7 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error)
}

if cfg.PluginRegistry != nil {
if err := cfg.PluginRegistry.RegisterAuthServices(ctx, server.grpcServer); err != nil {
if err := cfg.PluginRegistry.RegisterAuthServices(ctx, server.grpcServer, cfg.GetClientCertificate); err != nil {
return nil, trace.Wrap(err)
}
}
Expand Down
11 changes: 7 additions & 4 deletions lib/plugin/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ package plugin

import (
"context"
"crypto/tls"

"github.com/gravitational/trace"
)

type getCertFunc = func() (*tls.Certificate, error)

// Plugin describes interfaces of the teleport core plugin
type Plugin interface {
// GetName returns plugin name
Expand All @@ -33,7 +36,7 @@ type Plugin interface {
// RegisterAuthWebHandlers registers new methods with the Auth Web Handler
RegisterAuthWebHandlers(service interface{}) error
// RegisterAuthServices registers new services on the AuthServer
RegisterAuthServices(ctx context.Context, server interface{}) error
RegisterAuthServices(ctx context.Context, server any, getClientCert getCertFunc) error
}

// Registry is the plugin registry
Expand All @@ -47,7 +50,7 @@ type Registry interface {
// RegisterAuthWebHandlers registers Teleport Auth web handlers
RegisterAuthWebHandlers(handler interface{}) error
// RegisterAuthServices registers Teleport AuthServer services
RegisterAuthServices(ctx context.Context, server interface{}) error
RegisterAuthServices(ctx context.Context, server any, getClientCert getCertFunc) error
}

// NewRegistry creates an instance of the Registry
Expand Down Expand Up @@ -109,9 +112,9 @@ func (r *registry) RegisterAuthWebHandlers(handler interface{}) error {
return nil
}

func (r *registry) RegisterAuthServices(ctx context.Context, server interface{}) error {
func (r *registry) RegisterAuthServices(ctx context.Context, server any, getClientCert getCertFunc) error {
for _, p := range r.plugins {
if err := p.RegisterAuthServices(ctx, server); err != nil {
if err := p.RegisterAuthServices(ctx, server, getClientCert); err != nil {
return trace.Wrap(err, "plugin %v failed to register", p.GetName())
}
}
Expand Down
4 changes: 3 additions & 1 deletion lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2239,7 +2239,9 @@ func (process *TeleportProcess) initAuthService() error {
authMetrics := &auth.Metrics{GRPCServerLatency: cfg.Metrics.GRPCServerLatency}

tlsServer, err := auth.NewTLSServer(process.ExitContext(), auth.TLSServerConfig{
TLS: tlsConfig,
TLS: tlsConfig,
GetClientCertificate: connector.ClientGetCertificate,

APIConfig: *apiConf,
LimiterConfig: cfg.Auth.Limiter,
AccessPoint: authServer.Cache,
Expand Down

0 comments on commit 350dbe3

Please sign in to comment.