diff --git a/client/clientfactory.go b/client/clientfactory.go index 82401974f2f..564f6eb3d34 100644 --- a/client/clientfactory.go +++ b/client/clientfactory.go @@ -152,10 +152,12 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout( peerResolver := matching.NewPeerResolver(cf.resolver, namedPort) + defaultLoadBalancer := matching.NewLoadBalancer(domainIDToName, cf.dynConfig) + roundRobinLoadBalancer := matching.NewRoundRobinLoadBalancer(domainIDToName, cf.dynConfig) client := matching.NewClient( rawClient, peerResolver, - matching.NewLoadBalancer(domainIDToName, cf.dynConfig), + matching.NewMultiLoadBalancer(defaultLoadBalancer, roundRobinLoadBalancer, domainIDToName, cf.dynConfig), ) client = timeoutwrapper.NewMatchingClient(client, longPollTimeout, timeout) if errorRate := cf.dynConfig.GetFloat64Property(dynamicconfig.MatchingErrorInjectionRate)(); errorRate != 0 { diff --git a/client/matching/multi_loadbalancer.go b/client/matching/multi_loadbalancer.go new file mode 100644 index 00000000000..d71b132ca98 --- /dev/null +++ b/client/matching/multi_loadbalancer.go @@ -0,0 +1,83 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/types" +) + +type ( + multiLoadBalancer struct { + random LoadBalancer + roundRobin LoadBalancer + domainIDToName func(string) (string, error) + loadbalancerStrategy dynamicconfig.StringPropertyFnWithTaskListInfoFilters + } +) + +func NewMultiLoadBalancer( + random LoadBalancer, + roundRobin LoadBalancer, + domainIDToName func(string) (string, error), + dc *dynamicconfig.Collection, +) LoadBalancer { + return &multiLoadBalancer{ + random: random, + roundRobin: roundRobin, + domainIDToName: domainIDToName, + loadbalancerStrategy: dc.GetStringPropertyFilteredByTaskListInfo(dynamicconfig.TasklistLoadBalancerStrategy), + } +} + +func (lb *multiLoadBalancer) PickWritePartition( + domainID string, + taskList types.TaskList, + taskListType int, + forwardedFrom string, +) string { + domainName, err := lb.domainIDToName(domainID) + if err != nil { + return lb.random.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + } + if lb.loadbalancerStrategy(domainName, taskList.GetName(), taskListType) == "round-robin" { + return lb.roundRobin.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + } + return lb.random.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) +} + +func (lb *multiLoadBalancer) PickReadPartition( + domainID string, + taskList types.TaskList, + taskListType int, + forwardedFrom string, +) string { + domainName, err := lb.domainIDToName(domainID) + if err != nil { + return lb.random.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + } + if lb.loadbalancerStrategy(domainName, taskList.GetName(), taskListType) == "round-robin" { + return lb.roundRobin.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + } + return lb.random.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) +} diff --git a/client/matching/multi_loadbalancer_test.go b/client/matching/multi_loadbalancer_test.go new file mode 100644 index 00000000000..7a1d3175a4a --- /dev/null +++ b/client/matching/multi_loadbalancer_test.go @@ -0,0 +1,213 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/types" +) + +func TestNewMultiLoadBalancer(t *testing.T) { + ctrl := gomock.NewController(t) + randomMock := NewMockLoadBalancer(ctrl) + roundRobinMock := NewMockLoadBalancer(ctrl) + domainIDToName := func(domainID string) (string, error) { + return "testDomainName", nil + } + dc := dynamicconfig.NewCollection(dynamicconfig.NewNopClient(), testlogger.New(t)) + lb := NewMultiLoadBalancer(randomMock, roundRobinMock, domainIDToName, dc) + assert.NotNil(t, lb) + multiLB, ok := lb.(*multiLoadBalancer) + assert.NotNil(t, multiLB) + assert.True(t, ok) + assert.NotNil(t, multiLB.random) + assert.NotNil(t, multiLB.roundRobin) + assert.NotNil(t, multiLB.domainIDToName) + assert.NotNil(t, multiLB.loadbalancerStrategy) +} + +func TestMultiLoadBalancer_PickWritePartition(t *testing.T) { + + // Mock the domainIDToName function + domainIDToName := func(domainID string) (string, error) { + if domainID == "valid-domain" { + return "valid-domain-name", nil + } + return "", errors.New("domain not found") + } + + // Test cases + tests := []struct { + name string + domainID string + taskList types.TaskList + taskListType int + forwardedFrom string + loadbalancerStrategy string + expectedPartition string + }{ + { + name: "random partition when domainIDToName fails", + domainID: "invalid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "random-partition", + }, + { + name: "round-robin partition enabled", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "round-robin", + expectedPartition: "roundrobin-partition", + }, + { + name: "random partition when round-robin disabled", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "invalid-enum", + expectedPartition: "random-partition", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock behavior for random and round robin load balancers + ctrl := gomock.NewController(t) + + // Mock the LoadBalancer interface + randomMock := NewMockLoadBalancer(ctrl) + roundRobinMock := NewMockLoadBalancer(ctrl) + randomMock.EXPECT().PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("random-partition").AnyTimes() + roundRobinMock.EXPECT().PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("roundrobin-partition").AnyTimes() + + loadbalancerStrategyFn := func(domainName, taskListName string, taskListType int) string { + return tt.loadbalancerStrategy + } + + // Create multiLoadBalancer + lb := &multiLoadBalancer{ + random: randomMock, + roundRobin: roundRobinMock, + domainIDToName: domainIDToName, + loadbalancerStrategy: loadbalancerStrategyFn, + } + + // Call PickWritePartition and assert result + partition := lb.PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + assert.Equal(t, tt.expectedPartition, partition) + }) + } +} + +func TestMultiLoadBalancer_PickReadPartition(t *testing.T) { + + // Mock the domainIDToName function + domainIDToName := func(domainID string) (string, error) { + if domainID == "valid-domain" { + return "valid-domain-name", nil + } + return "", errors.New("domain not found") + } + + // Test cases + tests := []struct { + name string + domainID string + taskList types.TaskList + taskListType int + forwardedFrom string + loadbalancerStrategy string + expectedPartition string + }{ + { + name: "random partition when domainIDToName fails", + domainID: "invalid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "random-partition", + }, + { + name: "round-robin partition enabled", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "round-robin", + expectedPartition: "roundrobin-partition", + }, + { + name: "random partition when round-robin disabled", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "invalid-enum", + expectedPartition: "random-partition", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock behavior for random and round robin load balancers + ctrl := gomock.NewController(t) + + // Mock the LoadBalancer interface + randomMock := NewMockLoadBalancer(ctrl) + roundRobinMock := NewMockLoadBalancer(ctrl) + randomMock.EXPECT().PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("random-partition").AnyTimes() + roundRobinMock.EXPECT().PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("roundrobin-partition").AnyTimes() + + // Mock dynamic config for loadbalancer strategy + loadbalancerStrategyFn := func(domainName, taskListName string, taskListType int) string { + return tt.loadbalancerStrategy + } + + // Create multiLoadBalancer + lb := &multiLoadBalancer{ + random: randomMock, + roundRobin: roundRobinMock, + domainIDToName: domainIDToName, + loadbalancerStrategy: loadbalancerStrategyFn, + } + + // Call PickReadPartition and assert result + partition := lb.PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + assert.Equal(t, tt.expectedPartition, partition) + }) + } +} diff --git a/client/matching/rb_loadbalancer.go b/client/matching/rb_loadbalancer.go new file mode 100644 index 00000000000..044d0a00a0f --- /dev/null +++ b/client/matching/rb_loadbalancer.go @@ -0,0 +1,158 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "fmt" + "strings" + "sync/atomic" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/types" +) + +type ( + key struct { + domainName string + taskListName string + taskListType int + } + + roundRobinLoadBalancer struct { + nReadPartitions dynamicconfig.IntPropertyFnWithTaskListInfoFilters + nWritePartitions dynamicconfig.IntPropertyFnWithTaskListInfoFilters + domainIDToName func(string) (string, error) + readCache cache.Cache + writeCache cache.Cache + + pickPartitionFn func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string + } +) + +func NewRoundRobinLoadBalancer( + domainIDToName func(string) (string, error), + dc *dynamicconfig.Collection, +) LoadBalancer { + return &roundRobinLoadBalancer{ + domainIDToName: domainIDToName, + nReadPartitions: dc.GetIntPropertyFilteredByTaskListInfo(dynamicconfig.MatchingNumTasklistReadPartitions), + nWritePartitions: dc.GetIntPropertyFilteredByTaskListInfo(dynamicconfig.MatchingNumTasklistWritePartitions), + readCache: cache.New(&cache.Options{ + TTL: 0, + InitialCapacity: 100, + Pin: false, + MaxCount: 3000, + ActivelyEvict: false, + }), + writeCache: cache.New(&cache.Options{ + TTL: 0, + InitialCapacity: 100, + Pin: false, + MaxCount: 3000, + ActivelyEvict: false, + }), + pickPartitionFn: pickPartition, + } +} + +func (lb *roundRobinLoadBalancer) PickWritePartition( + domainID string, + taskList types.TaskList, + taskListType int, + forwardedFrom string, +) string { + domainName, err := lb.domainIDToName(domainID) + if err != nil { + return taskList.GetName() + } + nPartitions := lb.nWritePartitions(domainName, taskList.GetName(), taskListType) + + // checks to make sure number of writes never exceeds number of reads + if nRead := lb.nReadPartitions(domainName, taskList.GetName(), taskListType); nPartitions > nRead { + nPartitions = nRead + } + return lb.pickPartitionFn(domainName, taskList, taskListType, forwardedFrom, nPartitions, lb.writeCache) +} + +func (lb *roundRobinLoadBalancer) PickReadPartition( + domainID string, + taskList types.TaskList, + taskListType int, + forwardedFrom string, +) string { + domainName, err := lb.domainIDToName(domainID) + if err != nil { + return taskList.GetName() + } + n := lb.nReadPartitions(domainName, taskList.GetName(), taskListType) + return lb.pickPartitionFn(domainName, taskList, taskListType, forwardedFrom, n, lb.readCache) +} + +func pickPartition( + domainName string, + taskList types.TaskList, + taskListType int, + forwardedFrom string, + nPartitions int, + partitionCache cache.Cache, +) string { + taskListName := taskList.GetName() + if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { + return taskListName + } + if strings.HasPrefix(taskListName, common.ReservedTaskListPrefix) { + // this should never happen when forwardedFrom is empty + return taskListName + } + if nPartitions <= 1 { + return taskListName + } + + taskListKey := key{ + domainName: domainName, + taskListName: taskListName, + taskListType: taskListType, + } + + valI := partitionCache.Get(taskListKey) + if valI == nil { + val := int64(-1) + var err error + valI, err = partitionCache.PutIfNotExist(taskListKey, &val) + if err != nil { + return taskListName + } + } + valAddr, ok := valI.(*int64) + if !ok { + return taskListName + } + + p := atomic.AddInt64(valAddr, 1) % int64(nPartitions) + if p == 0 { + return taskListName + } + return fmt.Sprintf("%v%v/%v", common.ReservedTaskListPrefix, taskListName, p) +} diff --git a/client/matching/rb_loadbalancer_test.go b/client/matching/rb_loadbalancer_test.go new file mode 100644 index 00000000000..6d563faa093 --- /dev/null +++ b/client/matching/rb_loadbalancer_test.go @@ -0,0 +1,382 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/types" +) + +func TestNewRoundRobinLoadBalancer(t *testing.T) { + domainIDToName := func(domainID string) (string, error) { + return "testDomainName", nil + } + dc := dynamicconfig.NewCollection(dynamicconfig.NewNopClient(), testlogger.New(t)) + + lb := NewRoundRobinLoadBalancer(domainIDToName, dc) + assert.NotNil(t, lb) + rb, ok := lb.(*roundRobinLoadBalancer) + assert.NotNil(t, rb) + assert.True(t, ok) + + assert.NotNil(t, rb.domainIDToName) + assert.NotNil(t, rb.nReadPartitions) + assert.NotNil(t, rb.nWritePartitions) + assert.NotNil(t, rb.readCache) + assert.NotNil(t, rb.writeCache) +} + +func TestPickPartition(t *testing.T) { + tests := []struct { + name string + domainName string + taskList types.TaskList + taskListType int + forwardedFrom string + nPartitions int + setupCache func(mockCache *cache.MockCache) + expectedResult string + }{ + { + name: "ForwardedFrom is not empty", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "otherDomain", + nPartitions: 3, + setupCache: nil, + expectedResult: "testTaskList", + }, + { + name: "Sticky task list", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindSticky.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 3, + setupCache: nil, + expectedResult: "testTaskList", + }, + { + name: "Reserved task list prefix", + domainName: "testDomain", + taskList: types.TaskList{Name: fmt.Sprintf("%vTest", common.ReservedTaskListPrefix), Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 3, + setupCache: nil, + expectedResult: fmt.Sprintf("%vTest", common.ReservedTaskListPrefix), + }, + { + name: "nPartitions <= 1", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 1, + setupCache: nil, + expectedResult: "testTaskList", + }, + { + name: "Cache miss and partitioned task list", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 3, + setupCache: func(mockCache *cache.MockCache) { + mockCache.EXPECT().Get(key{ + domainName: "testDomain", + taskListName: "testTaskList", + taskListType: 1, + }).Return(nil) + mockCache.EXPECT().PutIfNotExist(key{ + domainName: "testDomain", + taskListName: "testTaskList", + taskListType: 1, + }, gomock.Any()).DoAndReturn(func(key key, val interface{}) (interface{}, error) { + if *val.(*int64) != -1 { + panic("Expected value to be -1") + } + return val, nil + }) + }, + expectedResult: "testTaskList", + }, + { + name: "Cache error and partitioned task list", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 3, + setupCache: func(mockCache *cache.MockCache) { + mockCache.EXPECT().Get(key{ + domainName: "testDomain", + taskListName: "testTaskList", + taskListType: 1, + }).Return(nil) + mockCache.EXPECT().PutIfNotExist(key{ + domainName: "testDomain", + taskListName: "testTaskList", + taskListType: 1, + }, gomock.Any()).Return(nil, fmt.Errorf("cache error")) + }, + expectedResult: "testTaskList", + }, + { + name: "Cache hit and partitioned task list", + domainName: "testDomain", + taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, + taskListType: 1, + forwardedFrom: "", + nPartitions: 3, + setupCache: func(mockCache *cache.MockCache) { + mockCache.EXPECT().Get(key{ + domainName: "testDomain", + taskListName: "testTaskList", + taskListType: 1, + }).Return(new(int64)) + }, + expectedResult: "/__cadence_sys/testTaskList/1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockCache := cache.NewMockCache(ctrl) + + // If the test requires setting up cache behavior, call setupCache + if tt.setupCache != nil { + tt.setupCache(mockCache) + } + + // Call the pickPartition function + result := pickPartition( + tt.domainName, + tt.taskList, + tt.taskListType, + tt.forwardedFrom, + tt.nPartitions, + mockCache, + ) + + // Assert that the result matches the expected result + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestPickWritePartition(t *testing.T) { + tests := []struct { + name string + domainID string + taskList types.TaskList + taskListType int + forwardedFrom string + domainIDToName func(string, *testing.T) (string, error) + nReadPartitions func(string, string, int, *testing.T) int + nWritePartitions func(string, string, int, *testing.T) int + pickPartitionFn func(string, types.TaskList, int, string, int, cache.Cache, *testing.T) string + expectedPartition string + expectError bool + }{ + { + name: "successful partition pick", + domainID: "testDomainID", + taskList: types.TaskList{Name: "testTaskList"}, + taskListType: 1, + forwardedFrom: "", + domainIDToName: func(domainID string, t *testing.T) (string, error) { + assert.Equal(t, "testDomainID", domainID) // Assert parameter with t + return "testDomainName", nil + }, + nReadPartitions: func(domainName, taskListName string, taskListType int, t *testing.T) int { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskListName) + assert.Equal(t, 1, taskListType) + return 3 + }, + nWritePartitions: func(domainName, taskListName string, taskListType int, t *testing.T) int { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskListName) + assert.Equal(t, 1, taskListType) + return 4 + }, + pickPartitionFn: func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache, t *testing.T) string { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskList.GetName()) + assert.Equal(t, 1, taskListType) + assert.Equal(t, "", forwardedFrom) + assert.Equal(t, 3, nPartitions) + return "partition1" + }, + expectedPartition: "partition1", + expectError: false, + }, + { + name: "domainIDToName returns error", + domainID: "badDomainID", + taskList: types.TaskList{Name: "testTaskList"}, + taskListType: 1, + forwardedFrom: "", + domainIDToName: func(domainID string, t *testing.T) (string, error) { + assert.Equal(t, "badDomainID", domainID) // Assert parameter with t + return "", fmt.Errorf("domain not found") + }, + expectedPartition: "testTaskList", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb := &roundRobinLoadBalancer{ + domainIDToName: func(domainID string) (string, error) { + return tt.domainIDToName(domainID, t) + }, + nReadPartitions: func(domainName, taskListName string, taskListType int) int { + return tt.nReadPartitions(domainName, taskListName, taskListType, t) + }, + nWritePartitions: func(domainName, taskListName string, taskListType int) int { + return tt.nWritePartitions(domainName, taskListName, taskListType, t) + }, + pickPartitionFn: func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string { + return tt.pickPartitionFn(domainName, taskList, taskListType, forwardedFrom, nPartitions, partitionCache, t) + }, + writeCache: cache.New(&cache.Options{ + TTL: 0, + InitialCapacity: 100, + Pin: false, + MaxCount: 3000, + ActivelyEvict: false, + }), + } + + partition := lb.PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + assert.Equal(t, tt.expectedPartition, partition) + }) + } +} + +func TestPickReadPartition(t *testing.T) { + tests := []struct { + name string + domainID string + taskList types.TaskList + taskListType int + forwardedFrom string + domainIDToName func(string, *testing.T) (string, error) + nReadPartitions func(string, string, int, *testing.T) int + nWritePartitions func(string, string, int, *testing.T) int + pickPartitionFn func(string, types.TaskList, int, string, int, cache.Cache, *testing.T) string + expectedPartition string + expectError bool + }{ + { + name: "successful partition pick", + domainID: "testDomainID", + taskList: types.TaskList{Name: "testTaskList"}, + taskListType: 1, + forwardedFrom: "", + domainIDToName: func(domainID string, t *testing.T) (string, error) { + assert.Equal(t, "testDomainID", domainID) // Assert parameter with t + return "testDomainName", nil + }, + nReadPartitions: func(domainName, taskListName string, taskListType int, t *testing.T) int { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskListName) + assert.Equal(t, 1, taskListType) + return 3 + }, + nWritePartitions: func(domainName, taskListName string, taskListType int, t *testing.T) int { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskListName) + assert.Equal(t, 1, taskListType) + return 4 + }, + pickPartitionFn: func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache, t *testing.T) string { + assert.Equal(t, "testDomainName", domainName) // Assert parameters with t + assert.Equal(t, "testTaskList", taskList.GetName()) + assert.Equal(t, 1, taskListType) + assert.Equal(t, "", forwardedFrom) + assert.Equal(t, 3, nPartitions) + return "partition1" + }, + expectedPartition: "partition1", + expectError: false, + }, + { + name: "domainIDToName returns error", + domainID: "badDomainID", + taskList: types.TaskList{Name: "testTaskList"}, + taskListType: 1, + forwardedFrom: "", + domainIDToName: func(domainID string, t *testing.T) (string, error) { + assert.Equal(t, "badDomainID", domainID) // Assert parameter with t + return "", fmt.Errorf("domain not found") + }, + expectedPartition: "testTaskList", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb := &roundRobinLoadBalancer{ + domainIDToName: func(domainID string) (string, error) { + return tt.domainIDToName(domainID, t) + }, + nReadPartitions: func(domainName, taskListName string, taskListType int) int { + return tt.nReadPartitions(domainName, taskListName, taskListType, t) + }, + nWritePartitions: func(domainName, taskListName string, taskListType int) int { + return tt.nWritePartitions(domainName, taskListName, taskListType, t) + }, + pickPartitionFn: func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string { + return tt.pickPartitionFn(domainName, taskList, taskListType, forwardedFrom, nPartitions, partitionCache, t) + }, + writeCache: cache.New(&cache.Options{ + TTL: 0, + InitialCapacity: 100, + Pin: false, + MaxCount: 3000, + ActivelyEvict: false, + }), + } + + partition := lb.PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + assert.Equal(t, tt.expectedPartition, partition) + }) + } +} diff --git a/common/dynamicconfig/config.go b/common/dynamicconfig/config.go index de1462d0eb4..6a43188fc30 100644 --- a/common/dynamicconfig/config.go +++ b/common/dynamicconfig/config.go @@ -125,6 +125,9 @@ type MapPropertyFn func(opts ...FilterOption) map[string]interface{} // StringPropertyFnWithDomainFilter is a wrapper to get string property from dynamic config type StringPropertyFnWithDomainFilter func(domain string) string +// StringPropertyFnWithTaskListInfoFilters is a wrapper to get string property from dynamic config with domainID as filter +type StringPropertyFnWithTaskListInfoFilters func(domain string, taskList string, taskType int) string + // BoolPropertyFnWithDomainFilter is a wrapper to get bool property from dynamic config with domain as filter type BoolPropertyFnWithDomainFilter func(domain string) bool @@ -447,6 +450,25 @@ func (c *Collection) GetStringPropertyFilteredByDomain(key StringKey) StringProp } } +func (c *Collection) GetStringPropertyFilteredByTaskListInfo(key StringKey) StringPropertyFnWithTaskListInfoFilters { + return func(domain string, taskList string, taskType int) string { + filters := c.toFilterMap( + DomainFilter(domain), + TaskListFilter(taskList), + TaskTypeFilter(taskType), + ) + val, err := c.client.GetStringValue( + key, + filters, + ) + if err != nil { + c.logError(key, filters, err) + return key.DefaultString() + } + return val + } +} + // GetBoolPropertyFilteredByDomain gets property with domain filter and asserts that it's a bool func (c *Collection) GetBoolPropertyFilteredByDomain(key BoolKey) BoolPropertyFnWithDomainFilter { return func(domain string) bool { diff --git a/common/dynamicconfig/config_test.go b/common/dynamicconfig/config_test.go index d02b2b68d6a..4faba763ab6 100644 --- a/common/dynamicconfig/config_test.go +++ b/common/dynamicconfig/config_test.go @@ -109,6 +109,17 @@ func (s *configSuite) TestGetStringPropertyFnWithDomainFilter() { s.Equal("efg", value(domain)) } +func (s *configSuite) TestGetStringPropertyFnByTaskListInfo() { + key := TasklistLoadBalancerStrategy + domain := "testDomain" + taskList := "testTaskList" + taskType := 0 + value := s.cln.GetStringPropertyFilteredByTaskListInfo(key) + s.Equal(key.DefaultString(), value(domain, taskList, taskType)) + s.client.SetValue(key, "round-robin") + s.Equal("round-robin", value(domain, taskList, taskType)) +} + func (s *configSuite) TestGetStringPropertyFilteredByRatelimitKey() { key := FrontendGlobalRatelimiterMode ratelimitKey := "user:testDomain" diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index f17f10ab876..39cecaea1e1 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -2393,6 +2393,8 @@ const ( // Allowed filters: RatelimitKey (on global key, e.g. prefixed by collection name) FrontendGlobalRatelimiterMode + TasklistLoadBalancerStrategy + // LastStringKey must be the last one in this const group LastStringKey ) @@ -4740,6 +4742,12 @@ var StringKeys = map[StringKey]DynamicString{ DefaultValue: "disabled", Filters: []Filter{RatelimitKey}, }, + TasklistLoadBalancerStrategy: { + KeyName: "system.tasklistLoadBalancerStrategy", + Description: "TasklistLoadBalancerStrategy is the key for tasklist load balancer strategy", + DefaultValue: "random", // other options: "round-robin" + Filters: []Filter{DomainName, TaskListName, TaskType}, + }, } var DurationKeys = map[DurationKey]DynamicDuration{