Skip to content

Commit

Permalink
Optimize sort keys by server in memcache client (#8026)
Browse files Browse the repository at this point in the history
* optimize sort keys by server in memcache client

Signed-off-by: Ben Ye <[email protected]>

* address comments

Signed-off-by: Ben Ye <[email protected]>

* remove unused mockAddr

Signed-off-by: Ben Ye <[email protected]>

---------

Signed-off-by: Ben Ye <[email protected]>
  • Loading branch information
yeya24 authored Jan 6, 2025
1 parent bed76cf commit 2ff07b2
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 79 deletions.
21 changes: 9 additions & 12 deletions pkg/cacheutil/memcached_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ type updatableServerSelector interface {
// resolve. No attempt is made to connect to the server. If any
// error occurs, no changes are made to the internal server list.
SetServers(servers ...string) error

// PickServerForKeys is like PickServer but returns a map of server address
// and corresponding keys.
PickServerForKeys(keys []string) (map[string][]string, error)
}

// MemcachedClientConfig is the config accepted by RemoteCacheClient.
Expand Down Expand Up @@ -571,20 +575,13 @@ func (c *memcachedClient) getMultiSingle(ctx context.Context, keys []string) (it
// *except* that keys sharded to the same server will be together. The order of keys
// returned may change from call to call.
func (c *memcachedClient) sortKeysByServer(keys []string) []string {
bucketed := make(map[string][]string)

for _, key := range keys {
addr, err := c.selector.PickServer(key)
// If we couldn't determine the correct server, return keys in existing order
if err != nil {
return keys
}

addrString := addr.String()
bucketed[addrString] = append(bucketed[addrString], key)
bucketed, err := c.selector.PickServerForKeys(keys)
// No need to pick server and sort keys if no more than 1 server.
if err != nil || len(bucketed) <= 1 {
return keys
}

var out []string
out := make([]string, 0, len(keys))
for srv := range bucketed {
out = append(out, bucketed[srv]...)
}
Expand Down
57 changes: 28 additions & 29 deletions pkg/cacheutil/memcached_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,9 @@ func TestMemcachedClient_sortKeysByServer(t *testing.T) {
config.Addresses = []string{"127.0.0.1:11211", "127.0.0.2:11211"}
backendMock := newMemcachedClientBackendMock()
selector := &mockServerSelector{
serversByKey: map[string]mockAddr{
"key1": "127.0.0.1:11211",
"key2": "127.0.0.2:11211",
"key3": "127.0.0.1:11211",
"key4": "127.0.0.2:11211",
"key5": "127.0.0.1:11211",
"key6": "127.0.0.2:11211",
resp: map[string][]string{
"127.0.0.1:11211": {"key1", "key2", "key4"},
"127.0.0.2:11211": {"key5", "key3", "key6"},
},
}

Expand All @@ -478,41 +474,44 @@ func TestMemcachedClient_sortKeysByServer(t *testing.T) {
}

sorted := client.sortKeysByServer(keys)
testutil.ContainsStringSlice(t, sorted, []string{"key1", "key3", "key5"})
testutil.ContainsStringSlice(t, sorted, []string{"key2", "key4", "key6"})
}

type mockAddr string
testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key4"})
testutil.ContainsStringSlice(t, sorted, []string{"key5", "key3", "key6"})

func (m mockAddr) Network() string {
return "mock"
}
// 1 server no need to sort.
client.selector = &mockServerSelector{
resp: map[string][]string{
"127.0.0.1:11211": {},
},
}
sorted = client.sortKeysByServer(keys)
testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key3", "key4", "key5", "key6"})

func (m mockAddr) String() string {
return string(m)
// 0 server no need to sort.
client.selector = &mockServerSelector{
resp: map[string][]string{},
err: memcache.ErrCacheMiss,
}
sorted = client.sortKeysByServer(keys)
testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key3", "key4", "key5", "key6"})
}

type mockServerSelector struct {
serversByKey map[string]mockAddr
resp map[string][]string
err error
}

// PickServer is not used here.
func (m *mockServerSelector) PickServer(key string) (net.Addr, error) {
if srv, ok := m.serversByKey[key]; ok {
return srv, nil
}

panic(fmt.Sprintf("unmapped key: %s", key))
}

// Each is not used here.
func (m *mockServerSelector) Each(f func(net.Addr) error) error {
for k := range m.serversByKey {
addr := m.serversByKey[k]
if err := f(addr); err != nil {
return err
}
}
panic("not implemented")
}

return nil
func (m *mockServerSelector) PickServerForKeys(keys []string) (map[string][]string, error) {
return m.resp, m.err
}

func (m *mockServerSelector) SetServers(...string) error {
Expand Down
126 changes: 88 additions & 38 deletions pkg/cacheutil/memcached_server_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,14 @@ package cacheutil

import (
"net"
"strings"
"sync"

"github.com/bradfitz/gomemcache/memcache"
"github.com/cespare/xxhash/v2"
"github.com/facette/natsort"
)

var (
addrsPool = sync.Pool{
New: func() interface{} {
addrs := make([]net.Addr, 0, 64)
return &addrs
},
}
)

// MemcachedJumpHashSelector implements the memcache.ServerSelector
// interface, utilizing a jump hash to distribute keys to servers.
//
Expand All @@ -30,9 +22,8 @@ var (
// with consistent DNS names where the naturally sorted order
// is predictable (ie. Kubernetes statefulsets).
type MemcachedJumpHashSelector struct {
// To avoid copy and pasting all memcache server list logic,
// we embed it and implement our features on top of it.
servers memcache.ServerList
mu sync.RWMutex
addrs []net.Addr
}

// SetServers changes a MemcachedJumpHashSelector's set of servers at
Expand All @@ -53,52 +44,111 @@ func (s *MemcachedJumpHashSelector) SetServers(servers ...string) error {
copy(sortedServers, servers)
natsort.Sort(sortedServers)

return s.servers.SetServers(sortedServers...)
naddr := make([]net.Addr, len(servers))
var err error
for i, server := range sortedServers {
naddr[i], err = parseStaticAddr(server)
if err != nil {
return err
}
}

s.mu.Lock()
defer s.mu.Unlock()
s.addrs = naddr
return nil
}

// PickServer returns the server address that a given item
// should be shared onto.
func (s *MemcachedJumpHashSelector) PickServer(key string) (net.Addr, error) {
// Unfortunately we can't read the list of server addresses from
// the original implementation, so we use Each() to fetch all of them.
addrs := *(addrsPool.Get().(*[]net.Addr))
err := s.servers.Each(func(addr net.Addr) error {
addrs = append(addrs, addr)
return nil
})
if err != nil {
return nil, err
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.addrs) == 0 {
return nil, memcache.ErrNoServers
} else if len(s.addrs) == 1 {
return s.addrs[0], nil
}
return pickServerWithJumpHash(s.addrs, key), nil
}

// Each iterates over each server and calls the given function.
// If f returns a non-nil error, iteration will stop and that
// error will be returned.
func (s *MemcachedJumpHashSelector) Each(f func(net.Addr) error) error {
s.mu.RLock()
defer s.mu.RUnlock()
for _, def := range s.addrs {
if err := f(def); err != nil {
return err
}
}
return nil
}

// PickServerForKeys is like PickServer but returns a map of server address
// and corresponding keys.
func (s *MemcachedJumpHashSelector) PickServerForKeys(keys []string) (map[string][]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()

// No need of a jump hash in case of 0 or 1 servers.
if len(addrs) == 0 {
addrs = (addrs)[:0]
addrsPool.Put(&addrs)
if len(s.addrs) <= 0 {
return nil, memcache.ErrNoServers
}
if len(addrs) == 1 {
picked := addrs[0]

addrs = (addrs)[:0]
addrsPool.Put(&addrs)
m := make(map[string][]string, len(keys))
if len(s.addrs) == 1 {
m[s.addrs[0].String()] = keys
return m, nil
}

return picked, nil
for _, key := range keys {
// Pick a server using the jump hash.
picked := pickServerWithJumpHash(s.addrs, key).String()
m[picked] = append(m[picked], key)
}

return m, nil
}

// pickServerWithJumpHash returns the server address that a given item should be shared onto.
func pickServerWithJumpHash(addrs []net.Addr, key string) net.Addr {
// Pick a server using the jump hash.
cs := xxhash.Sum64String(key)
idx := jumpHash(cs, len(addrs))
picked := (addrs)[idx]
return picked
}

addrs = (addrs)[:0]
addrsPool.Put(&addrs)
// Copied from https://github.com/bradfitz/gomemcache/blob/master/memcache/selector.go#L68.
func parseStaticAddr(server string) (net.Addr, error) {
if strings.Contains(server, "/") {
addr, err := net.ResolveUnixAddr("unix", server)
if err != nil {
return nil, err
}
return newStaticAddr(addr), nil
}
tcpaddr, err := net.ResolveTCPAddr("tcp", server)
if err != nil {
return nil, err
}
return newStaticAddr(tcpaddr), nil
}

return picked, nil
// Copied from https://github.com/bradfitz/gomemcache/blob/master/memcache/selector.go#L45
// staticAddr caches the Network() and String() values from any net.Addr.
type staticAddr struct {
ntw, str string
}

// Each iterates over each server and calls the given function.
// If f returns a non-nil error, iteration will stop and that
// error will be returned.
func (s *MemcachedJumpHashSelector) Each(f func(net.Addr) error) error {
return s.servers.Each(f)
func newStaticAddr(a net.Addr) net.Addr {
return &staticAddr{
ntw: a.Network(),
str: a.String(),
}
}

func (s *staticAddr) Network() string { return s.ntw }
func (s *staticAddr) String() string { return s.str }

0 comments on commit 2ff07b2

Please sign in to comment.