Skip to content

Commit

Permalink
[VAULT-3157] Move mergeStates utils from Agent to api module (hashi…
Browse files Browse the repository at this point in the history
…corp#12731)

* move merge and compare states to vault core

* move MergeState, CompareStates and ParseRequiredStates to api package

* fix merge state reference in API Proxy

* move mergeStates test to api package

* add changelog

* ghost commit to trigger CI

* rename CompareStates to CompareReplicationStates

* rename MergeStates and make compareStates and parseStates private methods

* improved error messaging in parseReplicationState

* export ParseReplicationState for enterprise files
  • Loading branch information
vinay-gopalan authored Oct 6, 2021
1 parent 547e4c9 commit 45b0179
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 190 deletions.
106 changes: 106 additions & 0 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package api

import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/http"
Expand All @@ -21,6 +25,8 @@ import (
rootcerts "github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)
Expand Down Expand Up @@ -1161,6 +1167,106 @@ func RequireState(states ...string) RequestCallback {
}
}

// compareReplicationStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareReplicationStates(s1, s2 string) (int, error) {
w1, err := ParseReplicationState(s1, nil)
if err != nil {
return 0, err
}
w2, err := ParseReplicationState(s2, nil)
if err != nil {
return 0, err
}

if w1.ClusterID != w2.ClusterID {
return 0, fmt.Errorf("can't compare replication states with different ClusterIDs")
}

switch {
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
return 1, nil
// We've already handled the case where both are equal above, so really we're
// asking here if one or both are lesser.
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
return -1, nil
}

return 0, nil
}

// MergeReplicationStates returns a merged array of replication states by iterating
// through all states in `old`. An iterated state is merged to the result before `new`
// based on the result of compareReplicationStates
func MergeReplicationStates(old []string, new string) []string {
if len(old) == 0 || len(old) > 2 {
return []string{new}
}

var ret []string
for _, o := range old {
c, err := compareReplicationStates(o, new)
if err != nil {
return []string{new}
}
switch c {
case 1:
ret = append(ret, o)
case -1:
ret = append(ret, new)
case 0:
ret = append(ret, o, new)
}
}
return strutil.RemoveDuplicates(ret, false)
}

func ParseReplicationState(raw string, hmacKey []byte) (*logical.WALState, error) {
cooked, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
s := string(cooked)

lastIndex := strings.LastIndexByte(s, ':')
if lastIndex == -1 {
return nil, fmt.Errorf("invalid full state header format")
}
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
stateHMAC, err := hex.DecodeString(stateHMACRaw)
if err != nil {
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
}

if len(hmacKey) != 0 {
hm := hmac.New(sha256.New, hmacKey)
hm.Write([]byte(state))
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
}
}

pieces := strings.Split(state, ":")
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
return nil, fmt.Errorf("invalid state header format")
}
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid local index in state header: %w", err)
}
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid replicated index in state header: %w", err)
}

return &logical.WALState{
ClusterID: pieces[1],
LocalIndex: localIndex,
ReplicatedIndex: replicatedIndex,
}, nil
}

// ForwardInconsistent returns a request callback that will add a request
// header which says: if the state required isn't present on the node receiving
// this request, forward it to the active node. This should be used in
Expand Down
84 changes: 84 additions & 0 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
Expand All @@ -14,6 +15,7 @@ import (
"testing"
"time"

"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/consts"
)
Expand Down Expand Up @@ -562,3 +564,85 @@ func TestSetHeadersRaceSafe(t *testing.T) {
}
}
}

func TestMergeReplicationStates(t *testing.T) {
type testCase struct {
name string
old []string
new string
expected []string
}

testCases := []testCase{
{
name: "empty-old",
old: nil,
new: "v1:cid:1:0:",
expected: []string{"v1:cid:1:0:"},
},
{
name: "old-smaller",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "old-bigger",
old: []string{"v1:cid:2:0:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "mixed-single",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:0:1:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-single-alt",
old: []string{"v1:cid:0:1:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-double",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
},
{
name: "newer-both",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:1:",
expected: []string{"v1:cid:2:1:"},
},
}

b64enc := func(ss []string) []string {
var ret []string
for _, s := range ss {
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
}
return ret
}
b64dec := func(ss []string) []string {
var ret []string
for _, s := range ss {
d, err := base64.StdEncoding.DecodeString(s)
if err != nil {
t.Fatal(err)
}
ret = append(ret, string(d))
}
return ret
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
}
})
}
}
3 changes: 3 additions & 0 deletions changelog/12731.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
api: Move mergeStates and other required utils from agent to api module
```
56 changes: 1 addition & 55 deletions command/agent/cache/api_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import (

hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
)

type EnforceConsistency int
Expand Down Expand Up @@ -60,58 +58,6 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
}, nil
}

// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareStates(s1, s2 string) (int, error) {
w1, err := vault.ParseRequiredState(s1, nil)
if err != nil {
return 0, err
}
w2, err := vault.ParseRequiredState(s2, nil)
if err != nil {
return 0, err
}

if w1.ClusterID != w2.ClusterID {
return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs")
}

switch {
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
return 1, nil
// We've already handled the case where both are equal above, so really we're
// asking here if one or both are lesser.
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
return -1, nil
}

return 0, nil
}

func mergeStates(old []string, new string) []string {
if len(old) == 0 || len(old) > 2 {
return []string{new}
}

var ret []string
for _, o := range old {
c, err := compareStates(o, new)
if err != nil {
return []string{new}
}
switch c {
case 1:
ret = append(ret, o)
case -1:
ret = append(ret, new)
case 0:
ret = append(ret, o, new)
}
}
return strutil.RemoveDuplicates(ret, false)
}

func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
client, err := ap.client.Clone()
if err != nil {
Expand Down Expand Up @@ -184,7 +130,7 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
// Vault uses to do so. So instead we compare any of the 0-2 saved states
// we have to the new header, keeping the newest 1-2 of these, and sending
// them to Vault to evaluate.
ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState)
ap.lastIndexStates = api.MergeReplicationStates(ap.lastIndexStates, newState)
ap.l.Unlock()
}

Expand Down
Loading

0 comments on commit 45b0179

Please sign in to comment.