Skip to content

Commit

Permalink
Test for signature algorithm on startup rather than per request (#376)
Browse files Browse the repository at this point in the history
Co-authored-by: Kaare Hoff Skovgaard <[email protected]>
  • Loading branch information
andsens and kastermester authored Dec 28, 2023
1 parent 47759c9 commit 5313cad
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
15 changes: 11 additions & 4 deletions auth_server/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type ServerConfig struct {

publicKey libtrust.PublicKey
privateKey libtrust.PrivateKey
sigAlg string
}

type LetsEncryptConfig struct {
Expand All @@ -86,6 +87,7 @@ type TokenConfig struct {

publicKey libtrust.PublicKey
privateKey libtrust.PrivateKey
sigAlg string
}

// TLSCipherSuitesValues maps CipherSuite names as strings to the actual values
Expand Down Expand Up @@ -335,7 +337,7 @@ func validate(c *Config) error {
return nil
}

func loadCertAndKey(certFile string, keyFile string) (pk libtrust.PublicKey, prk libtrust.PrivateKey, err error) {
func loadCertAndKey(certFile string, keyFile string) (pk libtrust.PublicKey, prk libtrust.PrivateKey, sigAlg string, err error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
Expand All @@ -349,6 +351,11 @@ func loadCertAndKey(certFile string, keyFile string) (pk libtrust.PublicKey, prk
return
}
prk, err = libtrust.FromCryptoPrivateKey(cert.PrivateKey)
_, sigAlg, errStr := prk.Sign(strings.NewReader("dummy"), 0)
if errStr != nil {
err = fmt.Errorf("failed to sign: %s", errStr)
return
}
return
}

Expand All @@ -370,7 +377,7 @@ func LoadConfig(fileName string) (*Config, error) {
if c.Server.CertFile == "" || c.Server.KeyFile == "" {
return nil, fmt.Errorf("failed to load server cert and key: both were not provided")
}
c.Server.publicKey, c.Server.privateKey, err = loadCertAndKey(c.Server.CertFile, c.Server.KeyFile)
c.Server.publicKey, c.Server.privateKey, c.Server.sigAlg, err = loadCertAndKey(c.Server.CertFile, c.Server.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load server cert and key: %s", err)
}
Expand All @@ -382,15 +389,15 @@ func LoadConfig(fileName string) (*Config, error) {
if c.Token.CertFile == "" || c.Token.KeyFile == "" {
return nil, fmt.Errorf("failed to load token cert and key: both were not provided")
}
c.Token.publicKey, c.Token.privateKey, err = loadCertAndKey(c.Token.CertFile, c.Token.KeyFile)
c.Token.publicKey, c.Token.privateKey, c.Token.sigAlg, err = loadCertAndKey(c.Token.CertFile, c.Token.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load token cert and key: %s", err)
}
tokenConfigured = true
}

if serverConfigured && !tokenConfigured {
c.Token.publicKey, c.Token.privateKey = c.Server.publicKey, c.Server.privateKey
c.Token.publicKey, c.Token.privateKey, c.Token.sigAlg = c.Server.publicKey, c.Server.privateKey, c.Server.sigAlg
tokenConfigured = true
}

Expand Down
9 changes: 2 additions & 7 deletions auth_server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,9 @@ func (as *AuthServer) CreateToken(ar *authRequest, ares []authzResult) (string,
now := time.Now().Unix()
tc := &as.config.Token

// Sign something dummy to find out which algorithm is used.
_, sigAlg, err := tc.privateKey.Sign(strings.NewReader("dummy"), 0)
if err != nil {
return "", fmt.Errorf("failed to sign: %s", err)
}
header := token.Header{
Type: "JWT",
SigningAlg: sigAlg,
SigningAlg: tc.sigAlg,
KeyID: tc.publicKey.KeyID(),
}
headerJSON, err := json.Marshal(header)
Expand Down Expand Up @@ -423,7 +418,7 @@ func (as *AuthServer) CreateToken(ar *authRequest, ares []authzResult) (string,
payload := fmt.Sprintf("%s%s%s", joseBase64UrlEncode(headerJSON), token.TokenSeparator, joseBase64UrlEncode(claimsJSON))

sig, sigAlg2, err := tc.privateKey.Sign(strings.NewReader(payload), 0)
if err != nil || sigAlg2 != sigAlg {
if err != nil || sigAlg2 != tc.sigAlg {
return "", fmt.Errorf("failed to sign token: %s", err)
}
glog.Infof("New token for %s %+v: %s", *ar, ar.Labels, claimsJSON)
Expand Down

0 comments on commit 5313cad

Please sign in to comment.