Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two PKI improvements: #5134

Merged
merged 1 commit into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions builtin/logical/pki/crl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package pki

import (
"crypto/x509"
"testing"

"github.com/hashicorp/vault/api"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)

func TestBackend_CRL_EnableDisable(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()

client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})

resp, err := client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
caSerial := resp.Data["serial_number"]

_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"allow_bare_domains": true,
"allow_subdomains": true,
"allowed_domains": "foobar.com",
"generate_lease": true,
})
if err != nil {
t.Fatal(err)
}

var serials = make(map[int]string)
for i := 0; i < 6; i++ {
resp, err := client.Logical().Write("pki/issue/test", map[string]interface{}{
"common_name": "test.foobar.com",
})
if err != nil {
t.Fatal(err)
}
serials[i] = resp.Data["serial_number"].(string)
}

test := func(num int) {
resp, err := client.Logical().Read("pki/cert/crl")
if err != nil {
t.Fatal(err)
}
crlPem := resp.Data["certificate"].(string)
certList, err := x509.ParseCRL([]byte(crlPem))
if err != nil {
t.Fatal(err)
}
lenList := len(certList.TBSCertList.RevokedCertificates)
if lenList != num {
t.Fatalf("expected %d, found %d", num, lenList)
}
}

revoke := func(num int) {
resp, err = client.Logical().Write("pki/revoke", map[string]interface{}{
"serial_number": serials[num],
})
if err != nil {
t.Fatal(err)
}

resp, err = client.Logical().Write("pki/revoke", map[string]interface{}{
"serial_number": caSerial,
})
if err == nil {
t.Fatal("expected error")
}
}

toggle := func(disabled bool) {
_, err = client.Logical().Write("pki/config/crl", map[string]interface{}{
"disable": disabled,
})
if err != nil {
t.Fatal(err)
}
}

test(0)
revoke(0)
revoke(1)
test(2)
toggle(true)
test(0)
revoke(2)
revoke(3)
test(0)
toggle(false)
test(4)
revoke(4)
revoke(5)
test(6)
toggle(true)
test(0)
toggle(false)
test(6)
}
71 changes: 52 additions & 19 deletions builtin/logical/pki/crl_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"strings"
"time"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
)
Expand All @@ -30,6 +33,21 @@ func revokeCert(ctx context.Context, b *backend, req *logical.Request, serial st
return nil, nil
}

signingBundle, caErr := fetchCAInfo(ctx, req)
switch caErr.(type) {
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("could not fetch the CA certificate: %s", caErr)), nil
case errutil.InternalError:
return nil, fmt.Errorf("error fetching CA certificate: %s", caErr)
}
if signingBundle == nil {
return nil, errors.New("CA info not found")
}
colonSerial := strings.Replace(strings.ToLower(serial), "-", ":", -1)
if colonSerial == certutil.GetHexFormatted(signingBundle.Certificate.SerialNumber.Bytes(), ":") {
return logical.ErrorResponse("adding CA to CRL is not allowed"), nil
}

alreadyRevoked := false
var revInfo revocationInfo

Expand Down Expand Up @@ -73,7 +91,9 @@ func revokeCert(ctx context.Context, b *backend, req *logical.Request, serial st
return nil, fmt.Errorf("got a nil certificate")
}

if cert.NotAfter.Before(time.Now()) {
// Add a little wiggle room because leases are stored with a second
// granularity
if cert.NotAfter.Before(time.Now().Add(2 * time.Second)) {
return nil, nil
}

Expand All @@ -100,7 +120,7 @@ func revokeCert(ctx context.Context, b *backend, req *logical.Request, serial st

}

crlErr := buildCRL(ctx, b, req)
crlErr := buildCRL(ctx, b, req, false)
switch crlErr.(type) {
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil
Expand All @@ -121,14 +141,39 @@ func revokeCert(ctx context.Context, b *backend, req *logical.Request, serial st

// Builds a CRL by going through the list of revoked certificates and building
// a new CRL with the stored revocation times and serial numbers.
func buildCRL(ctx context.Context, b *backend, req *logical.Request) error {
revokedSerials, err := req.Storage.List(ctx, "revoked/")
func buildCRL(ctx context.Context, b *backend, req *logical.Request, forceNew bool) error {
crlInfo, err := b.CRL(ctx, req.Storage)
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error fetching list of revoked certs: %s", err)}
return errutil.InternalError{Err: fmt.Sprintf("error fetching CRL config information: %s", err)}
}

revokedCerts := []pkix.RevokedCertificate{}
crlLifetime := b.crlLifetime
var revokedCerts []pkix.RevokedCertificate
var revInfo revocationInfo
var revokedSerials []string

if crlInfo != nil {
if crlInfo.Expiry != "" {
crlDur, err := time.ParseDuration(crlInfo.Expiry)
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error parsing CRL duration of %s", crlInfo.Expiry)}
}
crlLifetime = crlDur
}

if crlInfo.Disable {
if !forceNew {
return nil
}
goto WRITE
}
}

revokedSerials, err = req.Storage.List(ctx, "revoked/")
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error fetching list of revoked certs: %s", err)}
}

for _, serial := range revokedSerials {
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
if err != nil {
Expand Down Expand Up @@ -167,6 +212,7 @@ func buildCRL(ctx context.Context, b *backend, req *logical.Request) error {
revokedCerts = append(revokedCerts, newRevCert)
}

WRITE:
signingBundle, caErr := fetchCAInfo(ctx, req)
switch caErr.(type) {
case errutil.UserError:
Expand All @@ -175,19 +221,6 @@ func buildCRL(ctx context.Context, b *backend, req *logical.Request) error {
return errutil.InternalError{Err: fmt.Sprintf("error fetching CA certificate: %s", caErr)}
}

crlLifetime := b.crlLifetime
crlInfo, err := b.CRL(ctx, req.Storage)
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error fetching CRL config information: %s", err)}
}
if crlInfo != nil {
crlDur, err := time.ParseDuration(crlInfo.Expiry)
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error parsing CRL duration of %s", crlInfo.Expiry)}
}
crlLifetime = crlDur
}

crlBytes, err := signingBundle.Certificate.CreateCRL(rand.Reader, signingBundle.PrivateKey, revokedCerts, time.Now(), time.Now().Add(crlLifetime))
if err != nil {
return errutil.InternalError{Err: fmt.Sprintf("error creating new CRL: %s", err)}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_config_ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (b *backend) pathCAWrite(ctx context.Context, req *logical.Request, data *f
return nil, err
}

err = buildCRL(ctx, b, req)
err = buildCRL(ctx, b, req, true)

return nil, err
}
Expand Down
47 changes: 39 additions & 8 deletions builtin/logical/pki/path_config_crl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"fmt"
"time"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

// CRLConfig holds basic CRL configuration information
type crlConfig struct {
Expiry string `json:"expiry" mapstructure:"expiry" structs:"expiry"`
Expiry string `json:"expiry" mapstructure:"expiry"`
Disable bool `json:"disable"`
}

func pathConfigCRL(b *backend) *framework.Path {
Expand All @@ -24,6 +27,10 @@ func pathConfigCRL(b *backend) *framework.Path {
valid; defaults to 72 hours`,
Default: "72h",
},
"disable": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `If set to true, disables generating the CRL entirely.`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -64,21 +71,34 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, data *f

return &logical.Response{
Data: map[string]interface{}{
"expiry": config.Expiry,
"expiry": config.Expiry,
"disable": config.Disable,
},
}, nil
}

func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
expiry := d.Get("expiry").(string)

_, err := time.ParseDuration(expiry)
config, err := b.CRL(ctx, req.Storage)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Given expiry could not be decoded: %s", err)), nil
return nil, err
}
if config == nil {
config = &crlConfig{}
}

if expiryRaw, ok := d.GetOk("expiry"); ok {
expiry := expiryRaw.(string)
_, err := time.ParseDuration(expiry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("given expiry could not be decoded: %s", err)), nil
}
config.Expiry = expiry
}

config := &crlConfig{
Expiry: expiry,
var oldDisable bool
if disableRaw, ok := d.GetOk("disable"); ok {
oldDisable = config.Disable
config.Disable = disableRaw.(bool)
}

entry, err := logical.StorageEntryJSON("config/crl", config)
Expand All @@ -90,6 +110,17 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
return nil, err
}

if oldDisable != config.Disable {
// It wasn't disabled but now it is, rotate
crlErr := buildCRL(ctx, b, req, true)
switch crlErr.(type) {
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil
case errutil.InternalError:
return nil, errwrap.Wrapf("error encountered during CRL building: {{err}}", crlErr)
}
}

return nil, nil
}

Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_intermediate.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (b *backend) pathSetSignedIntermediate(ctx context.Context, req *logical.Re
}

// Build a fresh CRL
err = buildCRL(ctx, b, req)
err = buildCRL(ctx, b, req, true)

return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, d
b.revokeStorageLock.RLock()
defer b.revokeStorageLock.RUnlock()

crlErr := buildCRL(ctx, b, req)
crlErr := buildCRL(ctx, b, req, false)
switch crlErr.(type) {
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
}

// Build a fresh CRL
err = buildCRL(ctx, b, req)
err = buildCRL(ctx, b, req, true)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_tidy.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
}

if tidiedRevoked {
if err := buildCRL(ctx, b, req); err != nil {
if err := buildCRL(ctx, b, req, false); err != nil {
return err
}
}
Expand Down
Loading