Skip to content

Commit

Permalink
Juju 7464/reuse dial logic (#1529)
Browse files Browse the repository at this point in the history
* extract get addresses from controllers

* reuse dial logic in ssh server
  • Loading branch information
SimoneDutto authored Jan 21, 2025
1 parent 8de5428 commit c0b0b11
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 72 deletions.
22 changes: 20 additions & 2 deletions internal/jimm/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/errors"
"github.com/canonical/jimm/v3/internal/openfga"
"github.com/canonical/jimm/v3/internal/rpc"
)

// IdentityManager provides a means to fetch an identity from the identity service.
Expand Down Expand Up @@ -58,12 +59,29 @@ type sshManager struct {
func (s *sshManager) PublicKeyHandler(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) {
zapctx.Info(ctx, "PublicKeyHandler")
if ok, err := s.sshKeyManager.VerifyPublicKey(ctx, claimUser, key); !ok || err != nil {
return nil, fmt.Errorf("cannot verify key for user %s: %s", claimUser, err.Error())
return nil, errors.E(err, "cannot verify key for user")
}
user, err := s.identityManager.FetchIdentity(ctx, claimUser)
if err != nil {
zapctx.Info(ctx, fmt.Sprintf("cannot find user %s", claimUser))
return nil, fmt.Errorf("cannot find user %s: %s", claimUser, err.Error())
return nil, errors.E(err, "cannot find user")
}
return user, nil
}

// ResolveAddressesFromModelUUID is the method to resolve the address of the controller to contact given the model UUID.
func (s *sshManager) ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error) {
zapctx.Info(ctx, "ResolveAddressesFromModelUUID")

model, err := s.modelManager.GetModel(ctx, modelUUID)
if err != nil {
return nil, errors.E(err, "cannot find model")
}

addrs, _ := rpc.GetAddressesAndTLSConfig(ctx, &model.Controller)
if len(addrs) == 0 {
return nil, errors.E(err, "cannot find addresses for model's controller")
}

return addrs, nil
}
128 changes: 128 additions & 0 deletions internal/rpc/dial _test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package rpc_test

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net/http"
"testing"
Expand Down Expand Up @@ -65,3 +67,129 @@ func TestDialIPv6(t *testing.T) {
_, err = rpc.Dial(ctx, &controller, names.ModelTag{}, "", http.Header{})
c.Assert(err, qt.Equals, nil)
}

func TestGetAddressesAndTLSConfig(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

tests := []struct {
name string
controller dbmodel.Controller
expectedAddrs []string
expectedTLSCfg *tls.Config
}{
{
name: "With CACertificate and PublicAddress",
controller: dbmodel.Controller{
CACertificate: "test-ca-cert",
PublicAddress: "public.address.com",
TLSHostname: "tls.hostname.com",
},
expectedAddrs: []string{"public.address.com"},
expectedTLSCfg: &tls.Config{
RootCAs: x509.NewCertPool(),
ServerName: "tls.hostname.com",
MinVersion: tls.VersionTLS12,
},
},
{
name: "With IPv4 Address",
controller: dbmodel.Controller{
Addresses: [][]jujuparams.HostPort{
{
{
Address: jujuparams.Address{
Value: "192.168.1.1",
Type: "ipv4",
},
Port: 8080,
},
{
Address: jujuparams.Address{
Value: "192.168.1.1",
Type: "ipv4",
Scope: "non-exisiting-scope",
},
Port: 8080,
},
},
},
},
expectedAddrs: []string{"192.168.1.1:8080"},
expectedTLSCfg: nil,
},
{
name: "With IPv6 Address",
controller: dbmodel.Controller{
Addresses: [][]jujuparams.HostPort{
{
{
Address: jujuparams.Address{
Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
Type: "ipv6",
},
Port: 8080,
},
{
Address: jujuparams.Address{
Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7335",
Type: "ipv6",
Scope: string(network.ScopePublic),
},
Port: 8080,
},
},
},
},
expectedAddrs: []string{"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080", "[2001:0db8:85a3:0000:0000:8a2e:0370:7335]:8080"},
expectedTLSCfg: nil,
},
{
name: "With Both IPv4 and IPv6 Addresses",
controller: dbmodel.Controller{
Addresses: [][]jujuparams.HostPort{
{
{
Address: jujuparams.Address{
Value: "192.168.1.1",
Type: "ipv4",
},
Port: 8080,
},
{
Address: jujuparams.Address{
Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
Type: "ipv6",
},
Port: 8080,
},
},
},
},
expectedAddrs: []string{"192.168.1.1:8080", "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
expectedTLSCfg: nil,
},
{
name: "No Addresses",
controller: dbmodel.Controller{
Addresses: [][]jujuparams.HostPort{},
},
expectedAddrs: nil,
expectedTLSCfg: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addrs, tlsCfg := rpc.GetAddressesAndTLSConfig(ctx, &tt.controller)
c.Assert(addrs, qt.DeepEquals, tt.expectedAddrs)
if tt.expectedTLSCfg != nil {
c.Assert(tlsCfg, qt.Not(qt.IsNil))
c.Assert(tlsCfg.ServerName, qt.Equals, tt.expectedTLSCfg.ServerName)
c.Assert(tlsCfg.MinVersion, qt.Equals, tt.expectedTLSCfg.MinVersion)
} else {
c.Assert(tlsCfg, qt.IsNil)
}
})
}
}
45 changes: 25 additions & 20 deletions internal/rpc/dial.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Canonical.
// Copyright 2025 Canonical.

package rpc

Expand All @@ -15,7 +15,6 @@ import (
"github.com/gorilla/websocket"
"github.com/juju/juju/core/network"
"github.com/juju/names/v5"
"github.com/juju/zaputil"
"github.com/juju/zaputil/zapctx"
"go.uber.org/zap"

Expand Down Expand Up @@ -54,10 +53,8 @@ func (d Dialer) DialWebsocket(ctx context.Context, url string, headers http.Head
return conn, nil
}

// Dial connects to the controller/model and returns a raw websocket
// that can be used as is.
// It accepts the endpoints to dial, normally /api or /commands.
func Dial(ctx context.Context, ctl *dbmodel.Controller, modelTag names.ModelTag, finalPath string, headers http.Header) (*websocket.Conn, error) {
// GetAddressesAndTLSConfig returns the addresses and TLS configuration for the given controller.
func GetAddressesAndTLSConfig(ctx context.Context, ctl *dbmodel.Controller) ([]string, *tls.Config) {
var tlsConfig *tls.Config
if ctl.CACertificate != "" {
cp := x509.NewCertPool()
Expand All @@ -71,21 +68,14 @@ func Dial(ctx context.Context, ctl *dbmodel.Controller, modelTag names.ModelTag,
MinVersion: tls.VersionTLS12,
}
}
dialer := Dialer{
TLSConfig: tlsConfig,
}

if ctl.PublicAddress != "" {
// If there is a public-address configured it is almost
// certainly the one we want to use, try it first.
conn, err := dialer.DialWebsocket(ctx, websocketURL(ctl.PublicAddress, modelTag, finalPath), headers)
if err != nil {
zapctx.Error(ctx, "failed to dial public address", zaputil.Error(err))
} else {
return conn, nil
}
// certainly the one we want to use.
return []string{ctl.PublicAddress}, tlsConfig
}
var urls []string

var addrs []string
for _, hps := range ctl.Addresses {
for _, hp := range hps {
if maybeReachable(hp.Scope) {
Expand All @@ -95,12 +85,27 @@ func Dial(ctx context.Context, ctl *dbmodel.Controller, modelTag names.ModelTag,
} else {
ip = fmt.Sprintf("%s:%d", hp.Value, hp.Port)
}
urls = append(urls, websocketURL(ip, modelTag, finalPath))
addrs = append(addrs, ip)
}
}
}
zapctx.Debug(ctx, "Dialling all URLs", zap.Any("urls", urls))
conn, err := dialAll(ctx, &dialer, urls, headers)
return addrs, tlsConfig
}

// Dial connects to the controller/model and returns a raw websocket
// that can be used as is.
// It accepts the endpoints to dial, normally /api or /commands.
func Dial(ctx context.Context, ctl *dbmodel.Controller, modelTag names.ModelTag, finalPath string, headers http.Header) (*websocket.Conn, error) {
addrs, tlsConfig := GetAddressesAndTLSConfig(ctx, ctl)
dialer := Dialer{
TLSConfig: tlsConfig,
}
var websocketUrls []string
for _, addr := range addrs {
websocketUrls = append(websocketUrls, websocketURL(addr, modelTag, finalPath))
}
zapctx.Debug(ctx, "Dialling all URLs", zap.Any("urls", websocketUrls))
conn, err := dialAll(ctx, &dialer, websocketUrls, headers)
if err != nil {
return nil, err
}
Expand Down
41 changes: 41 additions & 0 deletions internal/ssh/dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2025 Canonical.

package ssh

import (
goerr "errors"
"fmt"
"net"
"time"

gossh "golang.org/x/crypto/ssh"

"github.com/canonical/jimm/v3/internal/errors"
)

// dialControllerSSHServer dials the controller ssh server, trying the addresses sequentially and returning a go ssh client.
func dialControllerSSHServer(addrs []string, destPort uint32) (*gossh.Client, error) {
var client *gossh.Client
var err error
var errs []error
for _, addr := range addrs {
dest := net.JoinHostPort(addr, fmt.Sprint(destPort))
client, err = gossh.Dial("tcp", dest, &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PasswordCallback(func() (secret string, err error) {
return "jwt", nil
}),
},
Timeout: 5 * time.Second,
})
if err != nil {
errs = append(errs, err)
}
}
if client == nil {
return nil, errors.E(goerr.Join(errs...), "cannot dial controller")
}
return client, nil
}
44 changes: 16 additions & 28 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ const defaultMaxConcurrentConnections = 100

type publicKeySSHUserKey struct{}

// SSHAuthorizer is the interface to authorize users via public key.
type SSHAuthorizer interface {
// SSHManager is the interface to enable the ssh server to operate. Performing public key verification and
// resolving addresses from model uuids.
type SSHManager interface {
// PublicKeyHandler is the method to verify the public key of the user. It returns a user if successful.
PublicKeyHandler(ctx context.Context, claimUser string, key []byte) (*openfga.User, error)
}

// TODO(simonedutto): this is going to change to reuse as much as our dial logic as we possibly can.
// SSHResolver is the interface to resolve controller's addresses.
type SSHResolver interface {
// AddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID.
AddrFromModelUUID(ctx context.Context, user *openfga.User, modelTag names.ModelTag) (string, error)
// ResolveAddressesFromModelUUID is the method to resolve the address of the controller to contact given the model UUID.
ResolveAddressesFromModelUUID(ctx context.Context, modelUUID string) ([]string, error)
}

// forwardMessage is the struct holding the information about the jump message received by the ssh client.
Expand Down Expand Up @@ -63,21 +60,21 @@ type Server struct {
}

// NewJumpServer creates the jump server struct.
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (Server, error) {
func NewJumpServer(ctx context.Context, config Config, sshManager SSHManager) (Server, error) {
zapctx.Info(ctx, "NewJumpServer")

if sshResolver == nil {
return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
if sshManager == nil {
return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil ssh manager.")
}
config = setConfigDefaults(config)
server := Server{
Server: &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
"direct-tcpip": directTCPIPHandler(sshManager),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
user, err := sshManager.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
Expand Down Expand Up @@ -122,7 +119,7 @@ func (srv Server) ListenAndServe() error {
return srv.Serve(ln)
}

func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
func directTCPIPHandler(sshManager SSHManager) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := forwardMessage{}
k := newChan.ExtraData()
Expand All @@ -139,29 +136,20 @@ func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gos
return
}
modelTag := names.NewModelTag(d.DestAddr)
user, err := fetchAndAuthorizeUser(ctx, modelTag)
// user is now ignored, but it will be needed for the jwt auth next-up.
_, err := fetchAndAuthorizeUser(ctx, modelTag)
if err != nil {
rejectConnectionAndLogError(ctx, newChan, err.Error(), err)
return
}
addr, err := sshResolver.AddrFromModelUUID(ctx, user, modelTag)
addrs, err := sshManager.ResolveAddressesFromModelUUID(ctx, modelTag.Id())
if err != nil {
rejectConnectionAndLogError(ctx, newChan, "failed to resolve address from model uuid", err)
return
}
dest := net.JoinHostPort(addr, fmt.Sprint(d.DestPort))
// this is temporary. The way we dial to the controller will heavily change.
client, err := gossh.Dial("tcp", dest, &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PasswordCallback(func() (secret string, err error) {
return "jwt", nil
}),
},
})
client, err := dialControllerSSHServer(addrs, d.DestPort)
if err != nil {
rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("failed to connect to %s: %v", dest, err), err)
rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("failed to dial controller ssh: %v", err), err)
return
}

Expand Down
Loading

0 comments on commit c0b0b11

Please sign in to comment.