diff --git a/pkg/cacheutil/memcached_client.go b/pkg/cacheutil/memcached_client.go index e700ffee7b..5737fac949 100644 --- a/pkg/cacheutil/memcached_client.go +++ b/pkg/cacheutil/memcached_client.go @@ -102,6 +102,10 @@ type updatableServerSelector interface { // resolve. No attempt is made to connect to the server. If any // error occurs, no changes are made to the internal server list. SetServers(servers ...string) error + + // PickServerForKeys is like PickServer but returns a map of server address + // and corresponding keys. + PickServerForKeys(keys []string) (map[string][]string, error) } // MemcachedClientConfig is the config accepted by RemoteCacheClient. @@ -571,20 +575,13 @@ func (c *memcachedClient) getMultiSingle(ctx context.Context, keys []string) (it // *except* that keys sharded to the same server will be together. The order of keys // returned may change from call to call. func (c *memcachedClient) sortKeysByServer(keys []string) []string { - bucketed := make(map[string][]string) - - for _, key := range keys { - addr, err := c.selector.PickServer(key) - // If we couldn't determine the correct server, return keys in existing order - if err != nil { - return keys - } - - addrString := addr.String() - bucketed[addrString] = append(bucketed[addrString], key) + bucketed, err := c.selector.PickServerForKeys(keys) + // No need to pick server and sort keys if no more than 1 server. + if err != nil || len(bucketed) <= 1 { + return keys } - var out []string + out := make([]string, 0, len(keys)) for srv := range bucketed { out = append(out, bucketed[srv]...) } diff --git a/pkg/cacheutil/memcached_client_test.go b/pkg/cacheutil/memcached_client_test.go index 4951075720..957fd060b0 100644 --- a/pkg/cacheutil/memcached_client_test.go +++ b/pkg/cacheutil/memcached_client_test.go @@ -454,13 +454,9 @@ func TestMemcachedClient_sortKeysByServer(t *testing.T) { config.Addresses = []string{"127.0.0.1:11211", "127.0.0.2:11211"} backendMock := newMemcachedClientBackendMock() selector := &mockServerSelector{ - serversByKey: map[string]mockAddr{ - "key1": "127.0.0.1:11211", - "key2": "127.0.0.2:11211", - "key3": "127.0.0.1:11211", - "key4": "127.0.0.2:11211", - "key5": "127.0.0.1:11211", - "key6": "127.0.0.2:11211", + resp: map[string][]string{ + "127.0.0.1:11211": {"key1", "key2", "key4"}, + "127.0.0.2:11211": {"key5", "key3", "key6"}, }, } @@ -478,41 +474,44 @@ func TestMemcachedClient_sortKeysByServer(t *testing.T) { } sorted := client.sortKeysByServer(keys) - testutil.ContainsStringSlice(t, sorted, []string{"key1", "key3", "key5"}) - testutil.ContainsStringSlice(t, sorted, []string{"key2", "key4", "key6"}) -} - -type mockAddr string + testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key4"}) + testutil.ContainsStringSlice(t, sorted, []string{"key5", "key3", "key6"}) -func (m mockAddr) Network() string { - return "mock" -} + // 1 server no need to sort. + client.selector = &mockServerSelector{ + resp: map[string][]string{ + "127.0.0.1:11211": {}, + }, + } + sorted = client.sortKeysByServer(keys) + testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key3", "key4", "key5", "key6"}) -func (m mockAddr) String() string { - return string(m) + // 0 server no need to sort. + client.selector = &mockServerSelector{ + resp: map[string][]string{}, + err: memcache.ErrCacheMiss, + } + sorted = client.sortKeysByServer(keys) + testutil.ContainsStringSlice(t, sorted, []string{"key1", "key2", "key3", "key4", "key5", "key6"}) } type mockServerSelector struct { - serversByKey map[string]mockAddr + resp map[string][]string + err error } +// PickServer is not used here. func (m *mockServerSelector) PickServer(key string) (net.Addr, error) { - if srv, ok := m.serversByKey[key]; ok { - return srv, nil - } - panic(fmt.Sprintf("unmapped key: %s", key)) } +// Each is not used here. func (m *mockServerSelector) Each(f func(net.Addr) error) error { - for k := range m.serversByKey { - addr := m.serversByKey[k] - if err := f(addr); err != nil { - return err - } - } + panic("not implemented") +} - return nil +func (m *mockServerSelector) PickServerForKeys(keys []string) (map[string][]string, error) { + return m.resp, m.err } func (m *mockServerSelector) SetServers(...string) error { diff --git a/pkg/cacheutil/memcached_server_selector.go b/pkg/cacheutil/memcached_server_selector.go index 5426d6af33..bc1f36706f 100644 --- a/pkg/cacheutil/memcached_server_selector.go +++ b/pkg/cacheutil/memcached_server_selector.go @@ -5,6 +5,7 @@ package cacheutil import ( "net" + "strings" "sync" "github.com/bradfitz/gomemcache/memcache" @@ -12,15 +13,6 @@ import ( "github.com/facette/natsort" ) -var ( - addrsPool = sync.Pool{ - New: func() interface{} { - addrs := make([]net.Addr, 0, 64) - return &addrs - }, - } -) - // MemcachedJumpHashSelector implements the memcache.ServerSelector // interface, utilizing a jump hash to distribute keys to servers. // @@ -30,9 +22,8 @@ var ( // with consistent DNS names where the naturally sorted order // is predictable (ie. Kubernetes statefulsets). type MemcachedJumpHashSelector struct { - // To avoid copy and pasting all memcache server list logic, - // we embed it and implement our features on top of it. - servers memcache.ServerList + mu sync.RWMutex + addrs []net.Addr } // SetServers changes a MemcachedJumpHashSelector's set of servers at @@ -53,52 +44,111 @@ func (s *MemcachedJumpHashSelector) SetServers(servers ...string) error { copy(sortedServers, servers) natsort.Sort(sortedServers) - return s.servers.SetServers(sortedServers...) + naddr := make([]net.Addr, len(servers)) + var err error + for i, server := range sortedServers { + naddr[i], err = parseStaticAddr(server) + if err != nil { + return err + } + } + + s.mu.Lock() + defer s.mu.Unlock() + s.addrs = naddr + return nil } // PickServer returns the server address that a given item // should be shared onto. func (s *MemcachedJumpHashSelector) PickServer(key string) (net.Addr, error) { - // Unfortunately we can't read the list of server addresses from - // the original implementation, so we use Each() to fetch all of them. - addrs := *(addrsPool.Get().(*[]net.Addr)) - err := s.servers.Each(func(addr net.Addr) error { - addrs = append(addrs, addr) - return nil - }) - if err != nil { - return nil, err + s.mu.RLock() + defer s.mu.RUnlock() + if len(s.addrs) == 0 { + return nil, memcache.ErrNoServers + } else if len(s.addrs) == 1 { + return s.addrs[0], nil } + return pickServerWithJumpHash(s.addrs, key), nil +} + +// Each iterates over each server and calls the given function. +// If f returns a non-nil error, iteration will stop and that +// error will be returned. +func (s *MemcachedJumpHashSelector) Each(f func(net.Addr) error) error { + s.mu.RLock() + defer s.mu.RUnlock() + for _, def := range s.addrs { + if err := f(def); err != nil { + return err + } + } + return nil +} + +// PickServerForKeys is like PickServer but returns a map of server address +// and corresponding keys. +func (s *MemcachedJumpHashSelector) PickServerForKeys(keys []string) (map[string][]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() // No need of a jump hash in case of 0 or 1 servers. - if len(addrs) == 0 { - addrs = (addrs)[:0] - addrsPool.Put(&addrs) + if len(s.addrs) <= 0 { return nil, memcache.ErrNoServers } - if len(addrs) == 1 { - picked := addrs[0] - addrs = (addrs)[:0] - addrsPool.Put(&addrs) + m := make(map[string][]string, len(keys)) + if len(s.addrs) == 1 { + m[s.addrs[0].String()] = keys + return m, nil + } - return picked, nil + for _, key := range keys { + // Pick a server using the jump hash. + picked := pickServerWithJumpHash(s.addrs, key).String() + m[picked] = append(m[picked], key) } + return m, nil +} + +// pickServerWithJumpHash returns the server address that a given item should be shared onto. +func pickServerWithJumpHash(addrs []net.Addr, key string) net.Addr { // Pick a server using the jump hash. cs := xxhash.Sum64String(key) idx := jumpHash(cs, len(addrs)) picked := (addrs)[idx] + return picked +} - addrs = (addrs)[:0] - addrsPool.Put(&addrs) +// Copied from https://github.com/bradfitz/gomemcache/blob/master/memcache/selector.go#L68. +func parseStaticAddr(server string) (net.Addr, error) { + if strings.Contains(server, "/") { + addr, err := net.ResolveUnixAddr("unix", server) + if err != nil { + return nil, err + } + return newStaticAddr(addr), nil + } + tcpaddr, err := net.ResolveTCPAddr("tcp", server) + if err != nil { + return nil, err + } + return newStaticAddr(tcpaddr), nil +} - return picked, nil +// Copied from https://github.com/bradfitz/gomemcache/blob/master/memcache/selector.go#L45 +// staticAddr caches the Network() and String() values from any net.Addr. +type staticAddr struct { + ntw, str string } -// Each iterates over each server and calls the given function. -// If f returns a non-nil error, iteration will stop and that -// error will be returned. -func (s *MemcachedJumpHashSelector) Each(f func(net.Addr) error) error { - return s.servers.Each(f) +func newStaticAddr(a net.Addr) net.Addr { + return &staticAddr{ + ntw: a.Network(), + str: a.String(), + } } + +func (s *staticAddr) Network() string { return s.ntw } +func (s *staticAddr) String() string { return s.str }