Skip to content

Commit

Permalink
update key algorithm for agentless dials (#47241)
Browse files Browse the repository at this point in the history
* update key algorithm for agentless

* fix tests
  • Loading branch information
nklaassen authored Oct 8, 2024
1 parent 06e2b5b commit 235b348
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 96 deletions.
26 changes: 18 additions & 8 deletions lib/agentless/agentless.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ import (

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/services"
)

Expand All @@ -49,9 +50,14 @@ type CertGenerator interface {
GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCertRequest) (*proto.OpenSSHCert, error)
}

// LocalAccessPoint should be a cache of the local cluster auth preference.
type LocalAccessPoint interface {
GetAuthPreference(context.Context) (types.AuthPreference, error)
}

// SignerCreator returns an [ssh.Signer] that can be used to authenticate
// with an agentless node.
type SignerCreator func(ctx context.Context, certGen CertGenerator) (ssh.Signer, error)
type SignerCreator func(ctx context.Context, localAccessPoint LocalAccessPoint, certGen CertGenerator) (ssh.Signer, error)

// SignerFromSSHCertificate returns a function that attempts to
// create a [ssh.Signer] for the Identity in the provided [ssh.Certificate]
Expand All @@ -60,7 +66,7 @@ type SignerCreator func(ctx context.Context, certGen CertGenerator) (ssh.Signer,
// passed into the returned function must be connected to the same cluster
// as the target node.
func SignerFromSSHCertificate(cert *ssh.Certificate, authClient AuthProvider, clusterName, teleportUser string) SignerCreator {
return func(ctx context.Context, certGen CertGenerator) (ssh.Signer, error) {
return func(ctx context.Context, localAccessPoint LocalAccessPoint, certGen CertGenerator) (ssh.Signer, error) {
u, err := authClient.GetUser(ctx, teleportUser, false)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -97,7 +103,7 @@ func SignerFromSSHCertificate(cert *ssh.Certificate, authClient AuthProvider, cl
roles: roles,
ttl: ttl,
}
signer, err := createAuthSigner(ctx, params, certGen)
signer, err := createAuthSigner(ctx, params, localAccessPoint, certGen)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -113,7 +119,7 @@ func SignerFromSSHCertificate(cert *ssh.Certificate, authClient AuthProvider, cl
// passed into the returned function must be connected to the same cluster
// as the target node.
func SignerFromAuthzContext(authzCtx *authz.Context, authClient AuthProvider, clusterName string) SignerCreator {
return func(ctx context.Context, certGen CertGenerator) (ssh.Signer, error) {
return func(ctx context.Context, localAccessPoint LocalAccessPoint, certGen CertGenerator) (ssh.Signer, error) {
u, ok := authzCtx.User.(*types.UserV2)
if !ok {
return nil, trace.BadParameter("unsupported user type %T", u)
Expand All @@ -139,7 +145,7 @@ func SignerFromAuthzContext(authzCtx *authz.Context, authClient AuthProvider, cl
roles: roles,
ttl: time.Until(identity.Expires),
}
signer, err := createAuthSigner(ctx, params, certGen)
signer, err := createAuthSigner(ctx, params, localAccessPoint, certGen)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -174,9 +180,13 @@ type certParams struct {

// createAuthSigner creates a [ssh.Signer] that is signed with
// OpenSSH CA and can be used to authenticate to agentless nodes.
func createAuthSigner(ctx context.Context, params certParams, certGen CertGenerator) (ssh.Signer, error) {
func createAuthSigner(ctx context.Context, params certParams, localAccessPoint LocalAccessPoint, certGen CertGenerator) (ssh.Signer, error) {
// generate a new key pair
priv, err := native.GeneratePrivateKey()
key, err := cryptosuites.GenerateKey(ctx, cryptosuites.GetCurrentSuiteFromAuthPreference(localAccessPoint), cryptosuites.UserSSH)
if err != nil {
return nil, trace.Wrap(err)
}
priv, err := keys.NewSoftwarePrivateKey(key)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
46 changes: 24 additions & 22 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ type SiteGetter interface {
GetSite(clusterName string) (reversetunnelclient.RemoteSite, error)
}

// RemoteClusterGetter provides access to remote cluster resources
type RemoteClusterGetter interface {
// LocalAccessPoint provides access to remote cluster resources
type LocalAccessPoint interface {
// GetRemoteCluster returns a remote cluster by name
GetRemoteCluster(ctx context.Context, clusterName string) (types.RemoteCluster, error)
// GetAuthPreference returns the local cluster auth preference.
GetAuthPreference(context.Context) (types.AuthPreference, error)
}

// RouterConfig contains all the dependencies required
Expand All @@ -122,8 +124,8 @@ type RouterConfig struct {
ClusterName string
// Log is the logger to use
Log *logrus.Entry
// AccessPoint is the proxy cache
RemoteClusterGetter RemoteClusterGetter
// LocalAccessPoint is the proxy cache
LocalAccessPoint LocalAccessPoint
// SiteGetter allows looking up sites
SiteGetter SiteGetter
// TracerProvider allows tracers to be created
Expand All @@ -143,8 +145,8 @@ func (c *RouterConfig) CheckAndSetDefaults() error {
return trace.BadParameter("ClusterName must be provided")
}

if c.RemoteClusterGetter == nil {
return trace.BadParameter("RemoteClusterGetter must be provided")
if c.LocalAccessPoint == nil {
return trace.BadParameter("LocalAccessPoint must be provided")
}

if c.SiteGetter == nil {
Expand All @@ -165,13 +167,13 @@ func (c *RouterConfig) CheckAndSetDefaults() error {
// Router is used by the proxy to establish connections to both
// nodes and other clusters.
type Router struct {
clusterName string
log *logrus.Entry
clusterGetter RemoteClusterGetter
localSite reversetunnelclient.RemoteSite
siteGetter SiteGetter
tracer oteltrace.Tracer
serverResolver serverResolverFn
clusterName string
log *logrus.Entry
localAccessPoint LocalAccessPoint
localSite reversetunnelclient.RemoteSite
siteGetter SiteGetter
tracer oteltrace.Tracer
serverResolver serverResolverFn
}

// NewRouter creates and returns a Router that is populated
Expand All @@ -187,13 +189,13 @@ func NewRouter(cfg RouterConfig) (*Router, error) {
}

return &Router{
clusterName: cfg.ClusterName,
log: cfg.Log,
clusterGetter: cfg.RemoteClusterGetter,
localSite: localSite,
siteGetter: cfg.SiteGetter,
tracer: cfg.TracerProvider.Tracer("Router"),
serverResolver: cfg.serverResolver,
clusterName: cfg.ClusterName,
log: cfg.Log,
localAccessPoint: cfg.LocalAccessPoint,
localSite: localSite,
siteGetter: cfg.SiteGetter,
tracer: cfg.TracerProvider.Tracer("Router"),
serverResolver: cfg.serverResolver,
}, nil
}

Expand Down Expand Up @@ -277,7 +279,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
if err != nil {
return nil, trace.Wrap(err)
}
sshSigner, err = signer(ctx, client)
sshSigner, err = signer(ctx, r.localAccessPoint, client)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -366,7 +368,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check
return nil, utils.OpaqueAccessDenied(err)
}

rc, err := r.clusterGetter.GetRemoteCluster(ctx, clusterName)
rc, err := r.localAccessPoint.GetRemoteCluster(ctx, clusterName)
if err != nil {
return nil, utils.OpaqueAccessDenied(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/proxy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ func TestRouter_DialHost(t *testing.T) {
agentGetter := func() (teleagent.Agent, error) {
return nil, nil
}
createSigner := func(_ context.Context, _ agentless.CertGenerator) (ssh.Signer, error) {
createSigner := func(_ context.Context, _ agentless.LocalAccessPoint, _ agentless.CertGenerator) (ssh.Signer, error) {
key, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519)
if err != nil {
return nil, err
Expand Down
10 changes: 5 additions & 5 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4420,11 +4420,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
var proxyRouter *proxy.Router
if !process.Config.Proxy.DisableReverseTunnel {
router, err := proxy.NewRouter(proxy.RouterConfig{
ClusterName: clusterName,
Log: process.log.WithField(teleport.ComponentKey, "router"),
RemoteClusterGetter: accessPoint,
SiteGetter: tsrv,
TracerProvider: process.TracingProvider,
ClusterName: clusterName,
Log: process.log.WithField(teleport.ComponentKey, "router"),
LocalAccessPoint: accessPoint,
SiteGetter: tsrv,
TracerProvider: process.TracingProvider,
})
if err != nil {
return trace.Wrap(err)
Expand Down
70 changes: 25 additions & 45 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ import (
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/cert"
)

// teleportTestUser is additional user used for tests
Expand Down Expand Up @@ -1259,25 +1258,6 @@ func TestAllowedLabels(t *testing.T) {
}
}

// TestKeyAlgorithms makes sure Teleport does not accept invalid user
// certificates. The main check is the certificate algorithms.
func TestKeyAlgorithms(t *testing.T) {
t.Parallel()
f := newFixtureWithoutDiskBasedLogging(t)

_, ellipticSigner, err := cert.CreateEllipticCertificate("foo", ssh.UserCert)
require.NoError(t, err)

sshConfig := &ssh.ClientConfig{
User: f.user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(ellipticSigner)},
HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()),
}

_, err = tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig)
require.Error(t, err)
}

func TestInvalidSessionID(t *testing.T) {
t.Parallel()
f := newFixtureWithoutDiskBasedLogging(t)
Expand Down Expand Up @@ -1475,11 +1455,11 @@ func TestProxyRoundRobin(t *testing.T) {
defer reverseTunnelServer.Close()

router, err := libproxy.NewRouter(libproxy.RouterConfig{
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -1615,11 +1595,11 @@ func TestProxyDirectAccess(t *testing.T) {
nodeClient, _ := newNodeClient(t, f.testSrv)

router, err := libproxy.NewRouter(libproxy.RouterConfig{
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -2332,11 +2312,11 @@ func TestParseSubsystemRequest(t *testing.T) {
require.NoError(t, err)

router, err := libproxy.NewRouter(libproxy.RouterConfig{
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -2593,11 +2573,11 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) {
nodeClient, _ := newNodeClient(t, f.testSrv)

router, err := libproxy.NewRouter(libproxy.RouterConfig{
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -3014,11 +2994,11 @@ func TestHostUserCreationProxy(t *testing.T) {
defer reverseTunnelServer.Close()

router, err := libproxy.NewRouter(libproxy.RouterConfig{
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: f.testSrv.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: proxyClient,
SiteGetter: reverseTunnelServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/transport/transportv1/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func newServer(t *testing.T, cfg ServerConfig) testPack {
}

func fakeSigner(authzCtx *authz.Context, clusterName string) agentless.SignerCreator {
return func(_ context.Context, _ agentless.CertGenerator) (ssh.Signer, error) {
return func(_ context.Context, _ agentless.LocalAccessPoint, _ agentless.CertGenerator) (ssh.Signer, error) {
return nil, nil
}
}
Expand Down
20 changes: 10 additions & 10 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,11 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
s.proxyTunnel = revTunServer

router, err := proxy.NewRouter(proxy.RouterConfig{
ClusterName: s.server.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
RemoteClusterGetter: s.proxyClient,
SiteGetter: revTunServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: s.server.ClusterName(),
Log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
LocalAccessPoint: s.proxyClient,
SiteGetter: revTunServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -8207,11 +8207,11 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula

clustername := authServer.ClusterName()
router, err := proxy.NewRouter(proxy.RouterConfig{
ClusterName: clustername,
Log: log.WithField(teleport.ComponentKey, "router"),
RemoteClusterGetter: client,
SiteGetter: revTunServer,
TracerProvider: tracing.NoopProvider(),
ClusterName: clustername,
Log: log.WithField(teleport.ComponentKey, "router"),
LocalAccessPoint: client,
SiteGetter: revTunServer,
TracerProvider: tracing.NoopProvider(),
})
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 235b348

Please sign in to comment.