diff --git a/builtin/logical/totp/backend.go b/builtin/logical/totp/backend.go index 4e3554bdbdf3..e26926aa73f4 100644 --- a/builtin/logical/totp/backend.go +++ b/builtin/logical/totp/backend.go @@ -2,9 +2,11 @@ package totp import ( "strings" + "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" + cache "github.com/patrickmn/go-cache" ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { @@ -25,11 +27,15 @@ func Backend(conf *logical.BackendConfig) *backend { Secrets: []*framework.Secret{}, } + b.usedCodes = cache.New(0, 30*time.Second) + return &b } type backend struct { *framework.Backend + + usedCodes *cache.Cache } const backendHelp = ` diff --git a/builtin/logical/totp/backend_test.go b/builtin/logical/totp/backend_test.go index 2a18a056f6e8..a3304c23209f 100644 --- a/builtin/logical/totp/backend_test.go +++ b/builtin/logical/totp/backend_test.go @@ -258,8 +258,10 @@ func TestBackend_keyCrudDefaultValues(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), - testAccStepValidateCode(t, "test", code, true), - testAccStepValidateCode(t, "test", invalidCode, false), + testAccStepValidateCode(t, "test", code, true, false), + // Next step should fail because it should be in the used cache + testAccStepValidateCode(t, "test", code, false, true), + testAccStepValidateCode(t, "test", invalidCode, false, false), testAccStepDeleteKey(t, "test"), testAccStepReadKey(t, "test", nil), }, @@ -1091,13 +1093,14 @@ func testAccStepReadKey(t *testing.T, name string, expected map[string]interface } } -func testAccStepValidateCode(t *testing.T, name string, code string, valid bool) logicaltest.TestStep { +func testAccStepValidateCode(t *testing.T, name string, code string, valid, expectError bool) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "code/" + name, Data: map[string]interface{}{ "code": code, }, + ErrorOk: expectError, Check: func(resp *logical.Response) error { if resp == nil { return fmt.Errorf("bad: %#v", resp) diff --git a/builtin/logical/totp/path_code.go b/builtin/logical/totp/path_code.go index 0481db145f1a..ebc3d47fc757 100644 --- a/builtin/logical/totp/path_code.go +++ b/builtin/logical/totp/path_code.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" otplib "github.com/pquerna/otp" @@ -84,6 +85,13 @@ func (b *backend) pathValidateCode( return logical.ErrorResponse(fmt.Sprintf("unknown key: %s", name)), nil } + usedName := fmt.Sprintf("%s_%s", name, code) + + _, ok := b.usedCodes.Get(usedName) + if ok { + return logical.ErrorResponse("code already used; wait until the next time period"), nil + } + valid, err := totplib.ValidateCustom(code, key.Key, time.Now(), totplib.ValidateOpts{ Period: key.Period, Skew: key.Skew, @@ -94,6 +102,16 @@ func (b *backend) pathValidateCode( return logical.ErrorResponse("an error occured while validating the code"), err } + // Take the key skew, add two for behind and in front, and multiple that by + // the period to cover the full possibility of the validity of the key + err = b.usedCodes.Add(usedName, nil, time.Duration( + int64(time.Second)* + int64(key.Period)* + int64((2+key.Skew)))) + if err != nil { + return nil, errwrap.Wrapf("error adding code to used cache: {{err}}", err) + } + return &logical.Response{ Data: map[string]interface{}{ "valid": valid,