diff --git a/client/clientfactory.go b/client/clientfactory.go index bf0e1c86385..339e8e0b8ac 100644 --- a/client/clientfactory.go +++ b/client/clientfactory.go @@ -72,6 +72,7 @@ type ( metricsClient metrics.Client dynConfig *dynamicconfig.Collection numberOfHistoryShards int + allIsolationGroups func() []string logger log.Logger } ) @@ -83,6 +84,7 @@ func NewRPCClientFactory( metricsClient metrics.Client, dc *dynamicconfig.Collection, numberOfHistoryShards int, + allIsolationGroups func() []string, logger log.Logger, ) Factory { return &rpcClientFactory{ @@ -91,6 +93,7 @@ func NewRPCClientFactory( metricsClient: metricsClient, dynConfig: dc, numberOfHistoryShards: numberOfHistoryShards, + allIsolationGroups: allIsolationGroups, logger: logger, } } @@ -155,10 +158,12 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout( defaultLoadBalancer := matching.NewLoadBalancer(partitionConfigProvider) roundRobinLoadBalancer := matching.NewRoundRobinLoadBalancer(partitionConfigProvider) weightedLoadBalancer := matching.NewWeightedLoadBalancer(roundRobinLoadBalancer, partitionConfigProvider, cf.logger) + igLoadBalancer := matching.NewIsolationLoadBalancer(weightedLoadBalancer, partitionConfigProvider, cf.allIsolationGroups) loadBalancers := map[string]matching.LoadBalancer{ "random": defaultLoadBalancer, "round-robin": roundRobinLoadBalancer, "weighted": weightedLoadBalancer, + "isolation": igLoadBalancer, } client := matching.NewClient( rawClient, diff --git a/client/matching/client.go b/client/matching/client.go index 43f39c34442..8205f8edf1a 100644 --- a/client/matching/client.go +++ b/client/matching/client.go @@ -60,10 +60,8 @@ func (c *clientImpl) AddActivityTask( opts ...yarpc.CallOption, ) (*types.AddActivityTaskResponse, error) { partition := c.loadBalancer.PickWritePartition( - request.GetDomainUUID(), - *request.GetTaskList(), persistence.TaskListTypeActivity, - request.GetForwardedFrom(), + request, ) originalTaskListName := request.TaskList.GetName() request.TaskList.Name = partition @@ -91,10 +89,8 @@ func (c *clientImpl) AddDecisionTask( opts ...yarpc.CallOption, ) (*types.AddDecisionTaskResponse, error) { partition := c.loadBalancer.PickWritePartition( - request.GetDomainUUID(), - *request.GetTaskList(), persistence.TaskListTypeDecision, - request.GetForwardedFrom(), + request, ) originalTaskListName := request.TaskList.GetName() request.TaskList.Name = partition @@ -122,10 +118,9 @@ func (c *clientImpl) PollForActivityTask( opts ...yarpc.CallOption, ) (*types.MatchingPollForActivityTaskResponse, error) { partition := c.loadBalancer.PickReadPartition( - request.GetDomainUUID(), - *request.PollRequest.GetTaskList(), persistence.TaskListTypeActivity, - request.GetForwardedFrom(), + request, + request.GetIsolationGroup(), ) originalTaskListName := request.PollRequest.GetTaskList().GetName() request.PollRequest.TaskList.Name = partition @@ -145,10 +140,8 @@ func (c *clientImpl) PollForActivityTask( resp.PartitionConfig, ) c.loadBalancer.UpdateWeight( - request.GetDomainUUID(), - *request.PollRequest.GetTaskList(), persistence.TaskListTypeActivity, - request.GetForwardedFrom(), + request, partition, resp.LoadBalancerHints, ) @@ -163,10 +156,9 @@ func (c *clientImpl) PollForDecisionTask( opts ...yarpc.CallOption, ) (*types.MatchingPollForDecisionTaskResponse, error) { partition := c.loadBalancer.PickReadPartition( - request.GetDomainUUID(), - *request.PollRequest.GetTaskList(), persistence.TaskListTypeDecision, - request.GetForwardedFrom(), + request, + request.GetIsolationGroup(), ) originalTaskListName := request.PollRequest.GetTaskList().GetName() request.PollRequest.TaskList.Name = partition @@ -186,10 +178,8 @@ func (c *clientImpl) PollForDecisionTask( resp.PartitionConfig, ) c.loadBalancer.UpdateWeight( - request.GetDomainUUID(), - *request.PollRequest.GetTaskList(), persistence.TaskListTypeDecision, - request.GetForwardedFrom(), + request, partition, resp.LoadBalancerHints, ) @@ -204,10 +194,9 @@ func (c *clientImpl) QueryWorkflow( opts ...yarpc.CallOption, ) (*types.QueryWorkflowResponse, error) { partition := c.loadBalancer.PickReadPartition( - request.GetDomainUUID(), - *request.GetTaskList(), persistence.TaskListTypeDecision, - request.GetForwardedFrom(), + request, + "", ) request.TaskList.Name = partition peer, err := c.peerResolver.FromTaskList(request.TaskList.GetName()) diff --git a/client/matching/client_test.go b/client/matching/client_test.go index 61091cf299e..2312201884b 100644 --- a/client/matching/client_test.go +++ b/client/matching/client_test.go @@ -160,7 +160,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddActivityTask(context.Background(), testAddActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeActivity, testAddActivityTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().AddActivityTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.AddActivityTaskResponse{}, nil) mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, nil) @@ -173,7 +173,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddActivityTask(context.Background(), testAddActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeActivity, testAddActivityTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", assert.AnError) }, wantError: true, @@ -184,7 +184,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddActivityTask(context.Background(), testAddActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeActivity, testAddActivityTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().AddActivityTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) }, @@ -196,7 +196,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddDecisionTask(context.Background(), testAddDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeDecision, testAddDecisionTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().AddDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.AddDecisionTaskResponse{}, nil) mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, nil) @@ -209,7 +209,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddDecisionTask(context.Background(), testAddDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeDecision, testAddDecisionTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", assert.AnError) }, wantError: true, @@ -220,7 +220,7 @@ func TestClient_withResponse(t *testing.T) { return c.AddDecisionTask(context.Background(), testAddDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickWritePartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickWritePartition(persistence.TaskListTypeDecision, testAddDecisionTaskRequest()).Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().AddDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) }, @@ -232,11 +232,11 @@ func TestClient_withResponse(t *testing.T) { return c.PollForActivityTask(context.Background(), testMatchingPollForActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeActivity, testMatchingPollForActivityTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().PollForActivityTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingPollForActivityTaskResponse{}, nil) mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, nil) - balancer.EXPECT().UpdateWeight(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "", _testPartition, nil) + balancer.EXPECT().UpdateWeight(persistence.TaskListTypeActivity, testMatchingPollForActivityTaskRequest(), _testPartition, nil) }, want: &types.MatchingPollForActivityTaskResponse{}, }, @@ -246,7 +246,7 @@ func TestClient_withResponse(t *testing.T) { return c.PollForActivityTask(context.Background(), testMatchingPollForActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeActivity, testMatchingPollForActivityTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", assert.AnError) }, want: nil, @@ -258,7 +258,7 @@ func TestClient_withResponse(t *testing.T) { return c.PollForActivityTask(context.Background(), testMatchingPollForActivityTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeActivity, testMatchingPollForActivityTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().PollForActivityTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) }, @@ -271,11 +271,11 @@ func TestClient_withResponse(t *testing.T) { return c.PollForDecisionTask(context.Background(), testMatchingPollForDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingPollForDecisionTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().PollForDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingPollForDecisionTaskResponse{}, nil) mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, nil) - balancer.EXPECT().UpdateWeight(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "", _testPartition, nil) + balancer.EXPECT().UpdateWeight(persistence.TaskListTypeDecision, testMatchingPollForDecisionTaskRequest(), _testPartition, nil) }, want: &types.MatchingPollForDecisionTaskResponse{}, }, @@ -285,7 +285,7 @@ func TestClient_withResponse(t *testing.T) { return c.PollForDecisionTask(context.Background(), testMatchingPollForDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingPollForDecisionTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", assert.AnError) }, want: nil, @@ -297,7 +297,7 @@ func TestClient_withResponse(t *testing.T) { return c.PollForDecisionTask(context.Background(), testMatchingPollForDecisionTaskRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingPollForDecisionTaskRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().PollForDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) }, @@ -310,7 +310,7 @@ func TestClient_withResponse(t *testing.T) { return c.QueryWorkflow(context.Background(), testMatchingQueryWorkflowRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingQueryWorkflowRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().QueryWorkflow(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.QueryWorkflowResponse{}, nil) }, @@ -322,7 +322,7 @@ func TestClient_withResponse(t *testing.T) { return c.QueryWorkflow(context.Background(), testMatchingQueryWorkflowRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingQueryWorkflowRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", assert.AnError) }, want: nil, @@ -334,7 +334,7 @@ func TestClient_withResponse(t *testing.T) { return c.QueryWorkflow(context.Background(), testMatchingQueryWorkflowRequest()) }, mock: func(p *MockPeerResolver, balancer *MockLoadBalancer, c *MockClient, mp *MockPartitionConfigProvider) { - balancer.EXPECT().PickReadPartition(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "").Return(_testPartition) + balancer.EXPECT().PickReadPartition(persistence.TaskListTypeDecision, testMatchingQueryWorkflowRequest(), "").Return(_testPartition) p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil) c.EXPECT().QueryWorkflow(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(nil, assert.AnError) }, diff --git a/client/matching/isolation_loadbalancer.go b/client/matching/isolation_loadbalancer.go new file mode 100644 index 00000000000..5c79c295b5b --- /dev/null +++ b/client/matching/isolation_loadbalancer.go @@ -0,0 +1,147 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "math/rand" + "slices" + + "golang.org/x/exp/maps" + + "github.com/uber/cadence/common/partition" + "github.com/uber/cadence/common/types" +) + +type isolationLoadBalancer struct { + provider PartitionConfigProvider + fallback LoadBalancer + allIsolationGroups func() []string +} + +func NewIsolationLoadBalancer(fallback LoadBalancer, provider PartitionConfigProvider, allIsolationGroups func() []string) LoadBalancer { + return &isolationLoadBalancer{ + provider: provider, + fallback: fallback, + allIsolationGroups: allIsolationGroups, + } +} + +func (i *isolationLoadBalancer) PickWritePartition(taskListType int, req WriteRequest) string { + taskList := *req.GetTaskList() + nPartitions := i.provider.GetNumberOfWritePartitions(req.GetDomainUUID(), taskList, taskListType) + taskListName := req.GetTaskList().Name + + if nPartitions <= 1 { + return taskListName + } + + taskGroup, ok := req.GetPartitionConfig()[partition.IsolationGroupKey] + if !ok { + return i.fallback.PickWritePartition(taskListType, req) + } + + partitions, ok := i.getPartitionsForGroup(taskGroup, nPartitions) + if !ok { + return i.fallback.PickWritePartition(taskListType, req) + } + + p := i.pickBetween(partitions) + + return getPartitionTaskListName(taskList.GetName(), p) +} + +func (i *isolationLoadBalancer) PickReadPartition(taskListType int, req ReadRequest, isolationGroup string) string { + taskList := *req.GetTaskList() + nRead := i.provider.GetNumberOfReadPartitions(req.GetDomainUUID(), taskList, taskListType) + taskListName := taskList.Name + + if nRead <= 1 { + return taskListName + } + + partitions, ok := i.getPartitionsForGroup(isolationGroup, nRead) + if !ok { + return i.fallback.PickReadPartition(taskListType, req, isolationGroup) + } + + // Scaling down, we need to consider both sets of partitions + if numWrite := i.provider.GetNumberOfWritePartitions(req.GetDomainUUID(), taskList, taskListType); numWrite != nRead { + writePartitions, ok := i.getPartitionsForGroup(isolationGroup, numWrite) + if ok { + for p := range writePartitions { + partitions[p] = struct{}{} + } + } + } + + p := i.pickBetween(partitions) + + return getPartitionTaskListName(taskList.GetName(), p) +} + +func (i *isolationLoadBalancer) UpdateWeight(taskListType int, req ReadRequest, partition string, info *types.LoadBalancerHints) { +} + +func (i *isolationLoadBalancer) getPartitionsForGroup(taskGroup string, partitionCount int) (map[int]any, bool) { + if taskGroup == "" { + return nil, false + } + isolationGroups := slices.Clone(i.allIsolationGroups()) + slices.Sort(isolationGroups) + index := slices.Index(isolationGroups, taskGroup) + if index == -1 { + return nil, false + } + partitions := make(map[int]any, 1) + // 3 groups [a, b, c] and 4 partitions gives us a mapping like this: + // 0, 3: a + // 1: b + // 2: c + // 4 groups [a, b, c, d] and 10 partitions gives us a mapping like this: + // 0, 4, 8: a + // 1, 5, 9: b + // 2, 6: c + // 3, 7: d + if len(isolationGroups) <= partitionCount { + for j := index; j < partitionCount; j += len(isolationGroups) { + partitions[j] = struct{}{} + } + // 4 groups [a,b,c,d] and 3 partitions gives us a mapping like this: + // 0: a, d + // 1: b + // 2: c + } else { + partitions[index%partitionCount] = struct{}{} + } + if len(partitions) == 0 { + return nil, false + } + return partitions, true +} + +func (i *isolationLoadBalancer) pickBetween(partitions map[int]any) int { + // Could alternatively use backlog weights to make a smarter choice + total := len(partitions) + picked := rand.Intn(total) + return maps.Keys(partitions)[picked] +} diff --git a/client/matching/isolation_loadbalancer_test.go b/client/matching/isolation_loadbalancer_test.go new file mode 100644 index 00000000000..a5a6cd0ef0e --- /dev/null +++ b/client/matching/isolation_loadbalancer_test.go @@ -0,0 +1,283 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package matching + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/partition" + "github.com/uber/cadence/common/types" +) + +func TestIsolationPickWritePartition(t *testing.T) { + tl := "tl" + cases := []struct { + name string + group string + isolationGroups []string + numWrite int + shouldFallback bool + allowed []string + }{ + { + name: "single partition", + group: "a", + numWrite: 1, + isolationGroups: []string{"a"}, + allowed: []string{tl}, + }, + { + name: "multiple partitions - single option", + group: "b", + numWrite: 2, + isolationGroups: []string{"a", "b"}, + allowed: []string{getPartitionTaskListName(tl, 1)}, + }, + { + name: "multiple partitions - multiple options", + group: "a", + numWrite: 2, + isolationGroups: []string{"a"}, + allowed: []string{tl, getPartitionTaskListName(tl, 1)}, + }, + { + name: "fallback - no group", + numWrite: 2, + isolationGroups: []string{"a"}, + shouldFallback: true, + allowed: []string{"fallback"}, + }, + { + name: "fallback - no groups", + group: "a", + numWrite: 2, + isolationGroups: []string{""}, + shouldFallback: true, + allowed: []string{"fallback"}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lb, fallback := createWithMocks(t, tc.isolationGroups, tc.numWrite, tc.numWrite) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "domainId", + TaskList: &types.TaskList{ + Name: tl, + Kind: types.TaskListKindSticky.Ptr(), + }, + } + if tc.group != "" { + req.PartitionConfig = map[string]string{ + partition.IsolationGroupKey: tc.group, + } + } + if tc.shouldFallback { + fallback.EXPECT().PickWritePartition(int(types.TaskListTypeDecision), req).Return("fallback").Times(1) + } + p := lb.PickWritePartition(0, req) + assert.Contains(t, tc.allowed, p) + }) + } +} + +func TestIsolationPickReadPartition(t *testing.T) { + tl := "tl" + cases := []struct { + name string + group string + isolationGroups []string + numRead int + numWrite int + shouldFallback bool + allowed []string + }{ + { + name: "single partition", + group: "a", + numRead: 1, + numWrite: 1, + isolationGroups: []string{"a"}, + allowed: []string{tl}, + }, + { + name: "multiple partitions - single option", + group: "b", + numRead: 2, + numWrite: 2, + isolationGroups: []string{"a", "b"}, + allowed: []string{getPartitionTaskListName(tl, 1)}, + }, + { + name: "multiple partitions - multiple options", + group: "a", + numRead: 2, + numWrite: 2, + isolationGroups: []string{"a"}, + allowed: []string{tl, getPartitionTaskListName(tl, 1)}, + }, + { + name: "scaling - multiple options", + group: "d", + numRead: 4, + numWrite: 3, + isolationGroups: []string{"a", "b", "c", "d"}, + // numRead = 4 means tasks for d could be in the last partition (idx=3) + // numWrite = 3 means new tasks for d are being written to the root (idx=0) + allowed: []string{tl, getPartitionTaskListName(tl, 3)}, + }, + { + name: "fallback - no group", + numRead: 2, + numWrite: 2, + isolationGroups: []string{"a"}, + shouldFallback: true, + allowed: []string{"fallback"}, + }, + { + name: "fallback - no groups", + group: "a", + numRead: 2, + numWrite: 2, + isolationGroups: []string{""}, + shouldFallback: true, + allowed: []string{"fallback"}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lb, fallback := createWithMocks(t, tc.isolationGroups, tc.numWrite, tc.numRead) + req := &types.MatchingQueryWorkflowRequest{ + DomainUUID: "domainId", + TaskList: &types.TaskList{ + Name: tl, + Kind: types.TaskListKindSticky.Ptr(), + }, + } + if tc.shouldFallback { + fallback.EXPECT().PickReadPartition(int(types.TaskListTypeDecision), req, tc.group).Return("fallback").Times(1) + } + p := lb.PickReadPartition(0, req, tc.group) + assert.Contains(t, tc.allowed, p) + }) + } +} + +func TestIsolationGetPartitionsForGroup(t *testing.T) { + cases := []struct { + name string + group string + isolationGroups []string + partitions int + expected []int + }{ + { + name: "single partition", + group: "a", + isolationGroups: []string{"a", "b", "c"}, + partitions: 1, + expected: []int{0}, + }, + { + name: "partitions less than groups", + group: "b", + isolationGroups: []string{"a", "b", "c"}, + partitions: 2, + expected: []int{1}, + }, + { + name: "partitions equals groups", + group: "c", + isolationGroups: []string{"a", "b", "c"}, + partitions: 3, + expected: []int{2}, + }, + { + name: "partitions greater than groups", + group: "c", + isolationGroups: []string{"a", "b", "c"}, + partitions: 4, + expected: []int{2}, + }, + { + name: "partitions greater than groups - multiple assigned", + group: "a", + isolationGroups: []string{"a", "b", "c"}, + partitions: 4, + expected: []int{0, 3}, + }, + { + name: "not ok - no isolation group", + group: "", + isolationGroups: []string{"a"}, + partitions: 4, + }, + { + name: "not ok - no isolation groups", + group: "a", + isolationGroups: []string{}, + partitions: 4, + }, + { + name: "not ok - unknown isolation group", + group: "d", + isolationGroups: []string{"a", "b", "c"}, + partitions: 4, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lb, _ := createWithMocks(t, tc.isolationGroups, tc.partitions, tc.partitions) + actual, ok := lb.getPartitionsForGroup(tc.group, tc.partitions) + if tc.expected == nil { + assert.Nil(t, actual) + assert.False(t, ok) + } else { + expectedSet := make(map[int]any, len(tc.expected)) + for _, expectedPartition := range tc.expected { + expectedSet[expectedPartition] = struct{}{} + } + assert.Equal(t, expectedSet, actual) + assert.True(t, ok) + } + }) + } +} + +func createWithMocks(t *testing.T, isolationGroups []string, writePartitions, readPartitions int) (*isolationLoadBalancer, *MockLoadBalancer) { + ctrl := gomock.NewController(t) + fallback := NewMockLoadBalancer(ctrl) + cfg := NewMockPartitionConfigProvider(ctrl) + cfg.EXPECT().GetNumberOfWritePartitions(gomock.Any(), gomock.Any(), gomock.Any()).Return(writePartitions).AnyTimes() + cfg.EXPECT().GetNumberOfReadPartitions(gomock.Any(), gomock.Any(), gomock.Any()).Return(readPartitions).AnyTimes() + allIsolationGroups := func() []string { + return isolationGroups + } + return &isolationLoadBalancer{ + provider: cfg, + fallback: fallback, + allIsolationGroups: allIsolationGroups, + }, fallback +} diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go index 9a051815e30..cea7adac0bd 100644 --- a/client/matching/loadbalancer.go +++ b/client/matching/loadbalancer.go @@ -25,13 +25,23 @@ package matching import ( "fmt" "math/rand" - "strings" "github.com/uber/cadence/common" "github.com/uber/cadence/common/types" ) type ( + // WriteRequest is the interface for all types of AddTask* requests + WriteRequest interface { + ReadRequest + GetPartitionConfig() map[string]string + } + // ReadRequest is the interface for all types of Poll* requests + ReadRequest interface { + GetDomainUUID() string + GetTaskList() *types.TaskList + GetForwardedFrom() string + } // LoadBalancer is the interface for implementers of // component that distributes add/poll api calls across // available task list partitions when possible @@ -43,30 +53,25 @@ type ( // to a parent partition in which case, no load balancing should be // performed PickWritePartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + request WriteRequest, ) string // PickReadPartition returns the task list partition to send a poller to. // Input is name of the original task list as specified by caller. When // forwardedFrom is non-empty, no load balancing should be done. PickReadPartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + request ReadRequest, + isolationGroup string, ) string // UpdateWeight updates the weight of a task list partition. // Input is name of the original task list as specified by caller. When // the original task list is a partition, no update should be done. UpdateWeight( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + request ReadRequest, partition string, info *types.LoadBalancerHints, ) @@ -88,24 +93,18 @@ func NewLoadBalancer( } func (lb *defaultLoadBalancer) PickWritePartition( - domainID string, - taskList types.TaskList, - taskListType int, - forwardedFrom string, + taskListType int, req WriteRequest, ) string { - nPartitions := lb.provider.GetNumberOfWritePartitions(domainID, taskList, taskListType) - return lb.pickPartition(taskList, forwardedFrom, nPartitions) + nPartitions := lb.provider.GetNumberOfWritePartitions(req.GetDomainUUID(), *req.GetTaskList(), taskListType) + return lb.pickPartition(*req.GetTaskList(), req.GetForwardedFrom(), nPartitions) } func (lb *defaultLoadBalancer) PickReadPartition( - domainID string, - taskList types.TaskList, - taskListType int, - forwardedFrom string, + taskListType int, req ReadRequest, _ string, ) string { - n := lb.provider.GetNumberOfReadPartitions(domainID, taskList, taskListType) - return lb.pickPartition(taskList, forwardedFrom, n) + n := lb.provider.GetNumberOfReadPartitions(req.GetDomainUUID(), *req.GetTaskList(), taskListType) + return lb.pickPartition(*req.GetTaskList(), req.GetForwardedFrom(), n) } @@ -115,16 +114,7 @@ func (lb *defaultLoadBalancer) pickPartition( nPartitions int, ) string { - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { - return taskList.GetName() - } - - if strings.HasPrefix(taskList.GetName(), common.ReservedTaskListPrefix) { - // this should never happen when forwardedFrom is empty - return taskList.GetName() - } - - if nPartitions <= 0 { + if nPartitions <= 1 { return taskList.GetName() } @@ -133,10 +123,8 @@ func (lb *defaultLoadBalancer) pickPartition( } func (lb *defaultLoadBalancer) UpdateWeight( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, partition string, info *types.LoadBalancerHints, ) { diff --git a/client/matching/loadbalancer_mock.go b/client/matching/loadbalancer_mock.go index 1edc26d30b1..0bed5bbaa93 100644 --- a/client/matching/loadbalancer_mock.go +++ b/client/matching/loadbalancer_mock.go @@ -34,6 +34,150 @@ import ( types "github.com/uber/cadence/common/types" ) +// MockWriteRequest is a mock of WriteRequest interface. +type MockWriteRequest struct { + ctrl *gomock.Controller + recorder *MockWriteRequestMockRecorder +} + +// MockWriteRequestMockRecorder is the mock recorder for MockWriteRequest. +type MockWriteRequestMockRecorder struct { + mock *MockWriteRequest +} + +// NewMockWriteRequest creates a new mock instance. +func NewMockWriteRequest(ctrl *gomock.Controller) *MockWriteRequest { + mock := &MockWriteRequest{ctrl: ctrl} + mock.recorder = &MockWriteRequestMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWriteRequest) EXPECT() *MockWriteRequestMockRecorder { + return m.recorder +} + +// GetDomainUUID mocks base method. +func (m *MockWriteRequest) GetDomainUUID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDomainUUID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDomainUUID indicates an expected call of GetDomainUUID. +func (mr *MockWriteRequestMockRecorder) GetDomainUUID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomainUUID", reflect.TypeOf((*MockWriteRequest)(nil).GetDomainUUID)) +} + +// GetForwardedFrom mocks base method. +func (m *MockWriteRequest) GetForwardedFrom() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetForwardedFrom") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetForwardedFrom indicates an expected call of GetForwardedFrom. +func (mr *MockWriteRequestMockRecorder) GetForwardedFrom() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForwardedFrom", reflect.TypeOf((*MockWriteRequest)(nil).GetForwardedFrom)) +} + +// GetPartitionConfig mocks base method. +func (m *MockWriteRequest) GetPartitionConfig() map[string]string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPartitionConfig") + ret0, _ := ret[0].(map[string]string) + return ret0 +} + +// GetPartitionConfig indicates an expected call of GetPartitionConfig. +func (mr *MockWriteRequestMockRecorder) GetPartitionConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPartitionConfig", reflect.TypeOf((*MockWriteRequest)(nil).GetPartitionConfig)) +} + +// GetTaskList mocks base method. +func (m *MockWriteRequest) GetTaskList() *types.TaskList { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskList") + ret0, _ := ret[0].(*types.TaskList) + return ret0 +} + +// GetTaskList indicates an expected call of GetTaskList. +func (mr *MockWriteRequestMockRecorder) GetTaskList() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskList", reflect.TypeOf((*MockWriteRequest)(nil).GetTaskList)) +} + +// MockReadRequest is a mock of ReadRequest interface. +type MockReadRequest struct { + ctrl *gomock.Controller + recorder *MockReadRequestMockRecorder +} + +// MockReadRequestMockRecorder is the mock recorder for MockReadRequest. +type MockReadRequestMockRecorder struct { + mock *MockReadRequest +} + +// NewMockReadRequest creates a new mock instance. +func NewMockReadRequest(ctrl *gomock.Controller) *MockReadRequest { + mock := &MockReadRequest{ctrl: ctrl} + mock.recorder = &MockReadRequestMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReadRequest) EXPECT() *MockReadRequestMockRecorder { + return m.recorder +} + +// GetDomainUUID mocks base method. +func (m *MockReadRequest) GetDomainUUID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDomainUUID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDomainUUID indicates an expected call of GetDomainUUID. +func (mr *MockReadRequestMockRecorder) GetDomainUUID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomainUUID", reflect.TypeOf((*MockReadRequest)(nil).GetDomainUUID)) +} + +// GetForwardedFrom mocks base method. +func (m *MockReadRequest) GetForwardedFrom() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetForwardedFrom") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetForwardedFrom indicates an expected call of GetForwardedFrom. +func (mr *MockReadRequestMockRecorder) GetForwardedFrom() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForwardedFrom", reflect.TypeOf((*MockReadRequest)(nil).GetForwardedFrom)) +} + +// GetTaskList mocks base method. +func (m *MockReadRequest) GetTaskList() *types.TaskList { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskList") + ret0, _ := ret[0].(*types.TaskList) + return ret0 +} + +// GetTaskList indicates an expected call of GetTaskList. +func (mr *MockReadRequestMockRecorder) GetTaskList() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskList", reflect.TypeOf((*MockReadRequest)(nil).GetTaskList)) +} + // MockLoadBalancer is a mock of LoadBalancer interface. type MockLoadBalancer struct { ctrl *gomock.Controller @@ -58,41 +202,41 @@ func (m *MockLoadBalancer) EXPECT() *MockLoadBalancerMockRecorder { } // PickReadPartition mocks base method. -func (m *MockLoadBalancer) PickReadPartition(domainID string, taskList types.TaskList, taskListType int, forwardedFrom string) string { +func (m *MockLoadBalancer) PickReadPartition(taskListType int, request ReadRequest, isolationGroup string) string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PickReadPartition", domainID, taskList, taskListType, forwardedFrom) + ret := m.ctrl.Call(m, "PickReadPartition", taskListType, request, isolationGroup) ret0, _ := ret[0].(string) return ret0 } // PickReadPartition indicates an expected call of PickReadPartition. -func (mr *MockLoadBalancerMockRecorder) PickReadPartition(domainID, taskList, taskListType, forwardedFrom interface{}) *gomock.Call { +func (mr *MockLoadBalancerMockRecorder) PickReadPartition(taskListType, request, isolationGroup interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PickReadPartition", reflect.TypeOf((*MockLoadBalancer)(nil).PickReadPartition), domainID, taskList, taskListType, forwardedFrom) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PickReadPartition", reflect.TypeOf((*MockLoadBalancer)(nil).PickReadPartition), taskListType, request, isolationGroup) } // PickWritePartition mocks base method. -func (m *MockLoadBalancer) PickWritePartition(domainID string, taskList types.TaskList, taskListType int, forwardedFrom string) string { +func (m *MockLoadBalancer) PickWritePartition(taskListType int, request WriteRequest) string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PickWritePartition", domainID, taskList, taskListType, forwardedFrom) + ret := m.ctrl.Call(m, "PickWritePartition", taskListType, request) ret0, _ := ret[0].(string) return ret0 } // PickWritePartition indicates an expected call of PickWritePartition. -func (mr *MockLoadBalancerMockRecorder) PickWritePartition(domainID, taskList, taskListType, forwardedFrom interface{}) *gomock.Call { +func (mr *MockLoadBalancerMockRecorder) PickWritePartition(taskListType, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PickWritePartition", reflect.TypeOf((*MockLoadBalancer)(nil).PickWritePartition), domainID, taskList, taskListType, forwardedFrom) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PickWritePartition", reflect.TypeOf((*MockLoadBalancer)(nil).PickWritePartition), taskListType, request) } // UpdateWeight mocks base method. -func (m *MockLoadBalancer) UpdateWeight(domainID string, taskList types.TaskList, taskListType int, forwardedFrom, partition string, info *types.LoadBalancerHints) { +func (m *MockLoadBalancer) UpdateWeight(taskListType int, request ReadRequest, partition string, info *types.LoadBalancerHints) { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateWeight", domainID, taskList, taskListType, forwardedFrom, partition, info) + m.ctrl.Call(m, "UpdateWeight", taskListType, request, partition, info) } // UpdateWeight indicates an expected call of UpdateWeight. -func (mr *MockLoadBalancerMockRecorder) UpdateWeight(domainID, taskList, taskListType, forwardedFrom, partition, info interface{}) *gomock.Call { +func (mr *MockLoadBalancerMockRecorder) UpdateWeight(taskListType, request, partition, info interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWeight", reflect.TypeOf((*MockLoadBalancer)(nil).UpdateWeight), domainID, taskList, taskListType, forwardedFrom, partition, info) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWeight", reflect.TypeOf((*MockLoadBalancer)(nil).UpdateWeight), taskListType, request, partition, info) } diff --git a/client/matching/loadbalancer_test.go b/client/matching/loadbalancer_test.go index 7b2c0eb47d5..8d9a33af194 100644 --- a/client/matching/loadbalancer_test.go +++ b/client/matching/loadbalancer_test.go @@ -28,7 +28,6 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/uber/cadence/common" "github.com/uber/cadence/common/types" ) @@ -66,14 +65,6 @@ func Test_defaultLoadBalancer_PickWritePartition(t *testing.T) { taskListKind: types.TaskListKindNormal, expectedPartitions: []string{"test-task-list", "/__cadence_sys/test-task-list/1", "/__cadence_sys/test-task-list/2"}, }, - { - name: "sticky task list", - forwardedFrom: "", - taskListType: 0, - nPartitions: 3, - taskListKind: types.TaskListKindSticky, - expectedPartitions: []string{"test-task-list"}, - }, } for _, tc := range testCases { @@ -87,9 +78,12 @@ func Test_defaultLoadBalancer_PickWritePartition(t *testing.T) { Times(1) // Pick write partition - kind := tc.taskListKind - taskList := types.TaskList{Name: "test-task-list", Kind: &kind} - partition := loadBalancer.PickWritePartition("test-domain-id", taskList, tc.taskListType, tc.forwardedFrom) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{Name: "test-task-list", Kind: &tc.taskListKind}, + ForwardedFrom: tc.forwardedFrom, + } + partition := loadBalancer.PickWritePartition(tc.taskListType, req) // Validate result assert.Contains(t, tc.expectedPartitions, partition) @@ -122,14 +116,6 @@ func Test_defaultLoadBalancer_PickReadPartition(t *testing.T) { taskListKind: types.TaskListKindNormal, expectedPartitions: []string{"test-task-list", "/__cadence_sys/test-task-list/1", "/__cadence_sys/test-task-list/2"}, }, - { - name: "sticky task list", - forwardedFrom: "", - taskListType: 0, - nPartitions: 3, - taskListKind: types.TaskListKindSticky, - expectedPartitions: []string{"test-task-list"}, - }, } for _, tc := range testCases { @@ -143,9 +129,12 @@ func Test_defaultLoadBalancer_PickReadPartition(t *testing.T) { Times(1) // Pick read partition - kind := tc.taskListKind - taskList := types.TaskList{Name: "test-task-list", Kind: &kind} - partition := loadBalancer.PickReadPartition("test-domain-id", taskList, tc.taskListType, tc.forwardedFrom) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "test-domain-id", + TaskList: &types.TaskList{Name: "test-task-list", Kind: &tc.taskListKind}, + ForwardedFrom: tc.forwardedFrom, + } + partition := loadBalancer.PickReadPartition(tc.taskListType, req, "") // Validate result assert.Contains(t, tc.expectedPartitions, partition) @@ -161,7 +150,11 @@ func Test_defaultLoadBalancer_UpdateWeight(t *testing.T) { taskList := types.TaskList{Name: "test-task-list", Kind: types.TaskListKindNormal.Ptr()} // Call UpdateWeight, should do nothing - loadBalancer.UpdateWeight("test-domain-id", taskList, 0, "", "partition", nil) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "test-domain-id", + TaskList: &taskList, + } + loadBalancer.UpdateWeight(0, req, "partition", nil) // No expectations, just ensure no-op }) @@ -178,42 +171,6 @@ func Test_defaultLoadBalancer_pickPartition(t *testing.T) { args args want string }{ - { - name: "Test: ForwardedFrom not empty", - args: args{ - taskList: types.TaskList{ - Name: "taskList1", - Kind: types.TaskListKindSticky.Ptr(), - }, - forwardedFrom: "forwardedFromVal", - nPartitions: 10, - }, - want: "taskList1", - }, - { - name: "Test: TaskList kind is Sticky", - args: args{ - taskList: types.TaskList{ - Name: "taskList2", - Kind: types.TaskListKindSticky.Ptr(), - }, - forwardedFrom: "", - nPartitions: 10, - }, - want: "taskList2", - }, - { - name: "Test: TaskList name starts with ReservedTaskListPrefix", - args: args{ - taskList: types.TaskList{ - Name: common.ReservedTaskListPrefix + "taskList3", - Kind: types.TaskListKindNormal.Ptr(), - }, - forwardedFrom: "", - nPartitions: 10, - }, - want: common.ReservedTaskListPrefix + "taskList3", - }, { name: "Test: nPartitions <= 0", args: args{ diff --git a/client/matching/multi_loadbalancer.go b/client/matching/multi_loadbalancer.go index 555f59eef7b..bf521476210 100644 --- a/client/matching/multi_loadbalancer.go +++ b/client/matching/multi_loadbalancer.go @@ -23,6 +23,9 @@ package matching import ( + "strings" + + "github.com/uber/cadence/common" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" @@ -56,60 +59,68 @@ func NewMultiLoadBalancer( } func (lb *multiLoadBalancer) PickWritePartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req WriteRequest, ) string { - domainName, err := lb.domainIDToName(domainID) + if !lb.canRedirectToPartition(req) { + return req.GetTaskList().GetName() + } + domainName, err := lb.domainIDToName(req.GetDomainUUID()) if err != nil { - return lb.defaultLoadBalancer.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + return lb.defaultLoadBalancer.PickWritePartition(taskListType, req) } - strategy := lb.loadbalancerStrategy(domainName, taskList.GetName(), taskListType) + strategy := lb.loadbalancerStrategy(domainName, req.GetTaskList().GetName(), taskListType) loadBalancer, ok := lb.loadBalancers[strategy] if !ok { lb.logger.Warn("unsupported load balancer strategy", tag.Value(strategy)) - return lb.defaultLoadBalancer.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + return lb.defaultLoadBalancer.PickWritePartition(taskListType, req) } - return loadBalancer.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + return loadBalancer.PickWritePartition(taskListType, req) } func (lb *multiLoadBalancer) PickReadPartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, + isolationGroup string, ) string { - domainName, err := lb.domainIDToName(domainID) + if !lb.canRedirectToPartition(req) { + return req.GetTaskList().GetName() + } + domainName, err := lb.domainIDToName(req.GetDomainUUID()) if err != nil { - return lb.defaultLoadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return lb.defaultLoadBalancer.PickReadPartition(taskListType, req, isolationGroup) } - strategy := lb.loadbalancerStrategy(domainName, taskList.GetName(), taskListType) + strategy := lb.loadbalancerStrategy(domainName, req.GetTaskList().GetName(), taskListType) loadBalancer, ok := lb.loadBalancers[strategy] if !ok { lb.logger.Warn("unsupported load balancer strategy", tag.Value(strategy)) - return lb.defaultLoadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return lb.defaultLoadBalancer.PickReadPartition(taskListType, req, isolationGroup) } - return loadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return loadBalancer.PickReadPartition(taskListType, req, isolationGroup) } func (lb *multiLoadBalancer) UpdateWeight( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, partition string, info *types.LoadBalancerHints, ) { - domainName, err := lb.domainIDToName(domainID) + if !lb.canRedirectToPartition(req) { + return + } + domainName, err := lb.domainIDToName(req.GetDomainUUID()) if err != nil { return } - strategy := lb.loadbalancerStrategy(domainName, taskList.GetName(), taskListType) + strategy := lb.loadbalancerStrategy(domainName, req.GetTaskList().GetName(), taskListType) loadBalancer, ok := lb.loadBalancers[strategy] if !ok { lb.logger.Warn("unsupported load balancer strategy", tag.Value(strategy)) return } - loadBalancer.UpdateWeight(domainID, taskList, taskListType, forwardedFrom, partition, info) + loadBalancer.UpdateWeight(taskListType, req, partition, info) +} + +func (lb *multiLoadBalancer) canRedirectToPartition(req ReadRequest) bool { + return req.GetForwardedFrom() == "" && req.GetTaskList().GetKind() != types.TaskListKindSticky && !strings.HasPrefix(req.GetTaskList().GetName(), common.ReservedTaskListPrefix) } diff --git a/client/matching/multi_loadbalancer_test.go b/client/matching/multi_loadbalancer_test.go index 9a7f0939487..6b242e65f8c 100644 --- a/client/matching/multi_loadbalancer_test.go +++ b/client/matching/multi_loadbalancer_test.go @@ -105,18 +105,50 @@ func TestMultiLoadBalancer_PickWritePartition(t *testing.T) { loadbalancerStrategy: "invalid-enum", expectedPartition: "random-partition", }, + { + name: "cannot repartition - forwarded", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "somewhere", + loadbalancerStrategy: "random", + expectedPartition: "test-tasklist", + }, + { + name: "cannot repartition - sticky", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist", Kind: types.TaskListKindSticky.Ptr()}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "test-tasklist", + }, + { + name: "cannot repartition - partition", + domainID: "valid-domain", + taskList: types.TaskList{Name: "/__cadence_sys/test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "/__cadence_sys/test-tasklist", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: tt.domainID, + TaskList: &tt.taskList, + ForwardedFrom: tt.forwardedFrom, + } // Mock behavior for random and round robin load balancers ctrl := gomock.NewController(t) // Mock the LoadBalancer interface randomMock := NewMockLoadBalancer(ctrl) roundRobinMock := NewMockLoadBalancer(ctrl) - randomMock.EXPECT().PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("random-partition").AnyTimes() - roundRobinMock.EXPECT().PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("roundrobin-partition").AnyTimes() + randomMock.EXPECT().PickWritePartition(tt.taskListType, req).Return("random-partition").AnyTimes() + roundRobinMock.EXPECT().PickWritePartition(tt.taskListType, req).Return("roundrobin-partition").AnyTimes() loadbalancerStrategyFn := func(domainName, taskListName string, taskListType int) string { return tt.loadbalancerStrategy @@ -134,7 +166,7 @@ func TestMultiLoadBalancer_PickWritePartition(t *testing.T) { } // Call PickWritePartition and assert result - partition := lb.PickWritePartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + partition := lb.PickWritePartition(tt.taskListType, req) assert.Equal(t, tt.expectedPartition, partition) }) } @@ -187,18 +219,50 @@ func TestMultiLoadBalancer_PickReadPartition(t *testing.T) { loadbalancerStrategy: "invalid-enum", expectedPartition: "random-partition", }, + { + name: "cannot repartition - forwarded", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist"}, + taskListType: 1, + forwardedFrom: "somewhere", + loadbalancerStrategy: "random", + expectedPartition: "test-tasklist", + }, + { + name: "cannot repartition - sticky", + domainID: "valid-domain", + taskList: types.TaskList{Name: "test-tasklist", Kind: types.TaskListKindSticky.Ptr()}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "test-tasklist", + }, + { + name: "cannot repartition - partition", + domainID: "valid-domain", + taskList: types.TaskList{Name: "/__cadence_sys/test-tasklist"}, + taskListType: 1, + forwardedFrom: "", + loadbalancerStrategy: "random", + expectedPartition: "/__cadence_sys/test-tasklist", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: tt.domainID, + TaskList: &tt.taskList, + ForwardedFrom: tt.forwardedFrom, + } // Mock behavior for random and round robin load balancers ctrl := gomock.NewController(t) // Mock the LoadBalancer interface randomMock := NewMockLoadBalancer(ctrl) roundRobinMock := NewMockLoadBalancer(ctrl) - randomMock.EXPECT().PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("random-partition").AnyTimes() - roundRobinMock.EXPECT().PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom).Return("roundrobin-partition").AnyTimes() + randomMock.EXPECT().PickReadPartition(tt.taskListType, req, "").Return("random-partition").AnyTimes() + roundRobinMock.EXPECT().PickReadPartition(tt.taskListType, req, "").Return("roundrobin-partition").AnyTimes() // Mock dynamic config for loadbalancer strategy loadbalancerStrategyFn := func(domainName, taskListName string, taskListType int) string { @@ -217,7 +281,7 @@ func TestMultiLoadBalancer_PickReadPartition(t *testing.T) { } // Call PickReadPartition and assert result - partition := lb.PickReadPartition(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom) + partition := lb.PickReadPartition(tt.taskListType, req, "") assert.Equal(t, tt.expectedPartition, partition) }) } @@ -290,6 +354,11 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: tt.domainID, + TaskList: &tt.taskList, + ForwardedFrom: tt.forwardedFrom, + } // Mock behavior for random and round-robin load balancers ctrl := gomock.NewController(t) @@ -298,9 +367,9 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) { roundRobinMock := NewMockLoadBalancer(ctrl) if tt.shouldUpdate { - roundRobinMock.EXPECT().UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.loadBalancerHints).Times(1) + roundRobinMock.EXPECT().UpdateWeight(tt.taskListType, req, tt.partition, tt.loadBalancerHints).Times(1) } else { - roundRobinMock.EXPECT().UpdateWeight(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + roundRobinMock.EXPECT().UpdateWeight(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) } loadbalancerStrategyFn := func(domainName, taskListName string, taskListType int) string { @@ -319,7 +388,7 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) { } // Call UpdateWeight - lb.UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.loadBalancerHints) + lb.UpdateWeight(tt.taskListType, req, tt.partition, tt.loadBalancerHints) }) } } diff --git a/client/matching/rr_loadbalancer.go b/client/matching/rr_loadbalancer.go index 597b96a2e9d..8c77315c9da 100644 --- a/client/matching/rr_loadbalancer.go +++ b/client/matching/rr_loadbalancer.go @@ -23,10 +23,8 @@ package matching import ( - "strings" "sync/atomic" - "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/types" ) @@ -43,7 +41,7 @@ type ( readCache cache.Cache writeCache cache.Cache - pickPartitionFn func(domainName string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string + pickPartitionFn func(domainName string, taskList types.TaskList, taskListType int, nPartitions int, partitionCache cache.Cache) string } ) @@ -71,41 +69,30 @@ func NewRoundRobinLoadBalancer( } func (lb *roundRobinLoadBalancer) PickWritePartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req WriteRequest, ) string { - nPartitions := lb.provider.GetNumberOfWritePartitions(domainID, taskList, taskListType) - return lb.pickPartitionFn(domainID, taskList, taskListType, forwardedFrom, nPartitions, lb.writeCache) + nPartitions := lb.provider.GetNumberOfWritePartitions(req.GetDomainUUID(), *req.GetTaskList(), taskListType) + return lb.pickPartitionFn(req.GetDomainUUID(), *req.GetTaskList(), taskListType, nPartitions, lb.writeCache) } func (lb *roundRobinLoadBalancer) PickReadPartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, + _ string, ) string { - n := lb.provider.GetNumberOfReadPartitions(domainID, taskList, taskListType) - return lb.pickPartitionFn(domainID, taskList, taskListType, forwardedFrom, n, lb.readCache) + n := lb.provider.GetNumberOfReadPartitions(req.GetDomainUUID(), *req.GetTaskList(), taskListType) + return lb.pickPartitionFn(req.GetDomainUUID(), *req.GetTaskList(), taskListType, n, lb.readCache) } func pickPartition( domainID string, taskList types.TaskList, taskListType int, - forwardedFrom string, nPartitions int, partitionCache cache.Cache, ) string { taskListName := taskList.GetName() - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { - return taskListName - } - if strings.HasPrefix(taskListName, common.ReservedTaskListPrefix) { - // this should never happen when forwardedFrom is empty - return taskListName - } if nPartitions <= 1 { return taskListName } @@ -135,10 +122,8 @@ func pickPartition( } func (lb *roundRobinLoadBalancer) UpdateWeight( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, partition string, info *types.LoadBalancerHints, ) { diff --git a/client/matching/rr_loadbalancer_test.go b/client/matching/rr_loadbalancer_test.go index f8ece650239..bb34830071a 100644 --- a/client/matching/rr_loadbalancer_test.go +++ b/client/matching/rr_loadbalancer_test.go @@ -29,7 +29,6 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/types" ) @@ -56,36 +55,6 @@ func TestPickPartition(t *testing.T) { setupCache func(mockCache *cache.MockCache) expectedResult string }{ - { - name: "ForwardedFrom is not empty", - domainID: "testDomain", - taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindNormal.Ptr()}, - taskListType: 1, - forwardedFrom: "otherDomain", - nPartitions: 3, - setupCache: nil, - expectedResult: "testTaskList", - }, - { - name: "Sticky task list", - domainID: "testDomain", - taskList: types.TaskList{Name: "testTaskList", Kind: types.TaskListKindSticky.Ptr()}, - taskListType: 1, - forwardedFrom: "", - nPartitions: 3, - setupCache: nil, - expectedResult: "testTaskList", - }, - { - name: "Reserved task list prefix", - domainID: "testDomain", - taskList: types.TaskList{Name: fmt.Sprintf("%vTest", common.ReservedTaskListPrefix), Kind: types.TaskListKindNormal.Ptr()}, - taskListType: 1, - forwardedFrom: "", - nPartitions: 3, - setupCache: nil, - expectedResult: fmt.Sprintf("%vTest", common.ReservedTaskListPrefix), - }, { name: "nPartitions <= 1", domainID: "testDomain", @@ -176,7 +145,6 @@ func TestPickPartition(t *testing.T) { tt.domainID, tt.taskList, tt.taskListType, - tt.forwardedFrom, tt.nPartitions, mockCache, ) @@ -187,7 +155,7 @@ func TestPickPartition(t *testing.T) { } } -func setUpMocksForRoundRobinLoadBalancer(t *testing.T, pickPartitionFn func(domainID string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string) (*roundRobinLoadBalancer, *MockPartitionConfigProvider, *cache.MockCache) { +func setUpMocksForRoundRobinLoadBalancer(t *testing.T, pickPartitionFn func(domainID string, taskList types.TaskList, taskListType int, nPartitions int, partitionCache cache.Cache) string) (*roundRobinLoadBalancer, *MockPartitionConfigProvider, *cache.MockCache) { ctrl := gomock.NewController(t) mockProvider := NewMockPartitionConfigProvider(ctrl) mockCache := cache.NewMockCache(ctrl) @@ -209,14 +177,6 @@ func TestRoundRobinPickWritePartition(t *testing.T) { taskListKind types.TaskListKind expectedPartition string }{ - { - name: "single write partition, forwarded", - forwardedFrom: "parent-task-list", - taskListType: 0, - nPartitions: 1, - taskListKind: types.TaskListKindNormal, - expectedPartition: "test-task-list", - }, { name: "multiple write partitions, no forward", forwardedFrom: "", @@ -230,14 +190,13 @@ func TestRoundRobinPickWritePartition(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Fake pickPartitionFn behavior - fakePickPartitionFn := func(domainID string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string { + fakePickPartitionFn := func(domainID string, taskList types.TaskList, taskListType int, nPartitions int, partitionCache cache.Cache) string { assert.Equal(t, "test-domain-id", domainID) assert.Equal(t, "test-task-list", taskList.Name) assert.Equal(t, tc.taskListKind, taskList.GetKind()) assert.Equal(t, tc.taskListType, taskListType) - assert.Equal(t, tc.forwardedFrom, forwardedFrom) assert.Equal(t, tc.nPartitions, nPartitions) - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { + if taskList.GetKind() == types.TaskListKindSticky { return taskList.GetName() } return "custom-partition" @@ -253,7 +212,12 @@ func TestRoundRobinPickWritePartition(t *testing.T) { kind := tc.taskListKind taskList := types.TaskList{Name: "test-task-list", Kind: &kind} - partition := loadBalancer.PickWritePartition("test-domain-id", taskList, tc.taskListType, tc.forwardedFrom) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "test-domain-id", + TaskList: &taskList, + ForwardedFrom: tc.forwardedFrom, + } + partition := loadBalancer.PickWritePartition(tc.taskListType, req) // Validate result assert.Equal(t, tc.expectedPartition, partition) @@ -270,14 +234,6 @@ func TestRoundRobinPickReadPartition(t *testing.T) { taskListKind types.TaskListKind expectedPartition string }{ - { - name: "single read partition, forwarded", - forwardedFrom: "parent-task-list", - taskListType: 0, - nPartitions: 1, - taskListKind: types.TaskListKindNormal, - expectedPartition: "test-task-list", - }, { name: "multiple read partitions, no forward", forwardedFrom: "", @@ -291,14 +247,13 @@ func TestRoundRobinPickReadPartition(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Fake pickPartitionFn behavior - fakePickPartitionFn := func(domainID string, taskList types.TaskList, taskListType int, forwardedFrom string, nPartitions int, partitionCache cache.Cache) string { + fakePickPartitionFn := func(domainID string, taskList types.TaskList, taskListType int, nPartitions int, partitionCache cache.Cache) string { assert.Equal(t, "test-domain-id", domainID) assert.Equal(t, "test-task-list", taskList.Name) assert.Equal(t, tc.taskListKind, taskList.GetKind()) assert.Equal(t, tc.taskListType, taskListType) - assert.Equal(t, tc.forwardedFrom, forwardedFrom) assert.Equal(t, tc.nPartitions, nPartitions) - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { + if taskList.GetKind() == types.TaskListKindSticky { return taskList.GetName() } return "custom-partition" @@ -314,7 +269,12 @@ func TestRoundRobinPickReadPartition(t *testing.T) { kind := tc.taskListKind taskList := types.TaskList{Name: "test-task-list", Kind: &kind} - partition := loadBalancer.PickReadPartition("test-domain-id", taskList, tc.taskListType, tc.forwardedFrom) + req := &types.AddDecisionTaskRequest{ + DomainUUID: "test-domain-id", + TaskList: &taskList, + ForwardedFrom: tc.forwardedFrom, + } + partition := loadBalancer.PickReadPartition(tc.taskListType, req, "") // Validate result assert.Equal(t, tc.expectedPartition, partition) diff --git a/client/matching/weighted_loadbalancer.go b/client/matching/weighted_loadbalancer.go index 383143c1fcf..964f972ac47 100644 --- a/client/matching/weighted_loadbalancer.go +++ b/client/matching/weighted_loadbalancer.go @@ -28,10 +28,8 @@ import ( "path" "sort" "strconv" - "strings" "sync" - "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" @@ -133,61 +131,45 @@ func NewWeightedLoadBalancer( } func (lb *weightedLoadBalancer) PickWritePartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req WriteRequest, ) string { - return lb.fallbackLoadBalancer.PickWritePartition(domainID, taskList, taskListType, forwardedFrom) + return lb.fallbackLoadBalancer.PickWritePartition(taskListType, req) } func (lb *weightedLoadBalancer) PickReadPartition( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, + isolationGroup string, ) string { - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { - return taskList.GetName() - } - if strings.HasPrefix(taskList.GetName(), common.ReservedTaskListPrefix) { - return taskList.GetName() - } taskListKey := key{ - domainID: domainID, - taskListName: taskList.GetName(), + domainID: req.GetDomainUUID(), + taskListName: req.GetTaskList().GetName(), taskListType: taskListType, } wI := lb.weightCache.Get(taskListKey) if wI == nil { - return lb.fallbackLoadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return lb.fallbackLoadBalancer.PickReadPartition(taskListType, req, isolationGroup) } w, ok := wI.(*weightSelector) if !ok { - return lb.fallbackLoadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return lb.fallbackLoadBalancer.PickReadPartition(taskListType, req, isolationGroup) } p, cumulativeWeights := w.pick() - lb.logger.Debug("pick read partition", tag.WorkflowDomainID(domainID), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskListType), tag.Dynamic("cumulative-weights", cumulativeWeights), tag.Dynamic("task-list-partition", p)) + lb.logger.Debug("pick read partition", tag.WorkflowDomainID(req.GetDomainUUID()), tag.WorkflowTaskListName(req.GetTaskList().Name), tag.WorkflowTaskListType(taskListType), tag.Dynamic("cumulative-weights", cumulativeWeights), tag.Dynamic("task-list-partition", p)) if p < 0 { - return lb.fallbackLoadBalancer.PickReadPartition(domainID, taskList, taskListType, forwardedFrom) + return lb.fallbackLoadBalancer.PickReadPartition(taskListType, req, isolationGroup) } - return getPartitionTaskListName(taskList.GetName(), p) + return getPartitionTaskListName(req.GetTaskList().GetName(), p) } func (lb *weightedLoadBalancer) UpdateWeight( - domainID string, - taskList types.TaskList, taskListType int, - forwardedFrom string, + req ReadRequest, partition string, info *types.LoadBalancerHints, ) { - if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky { - return - } - if strings.HasPrefix(taskList.GetName(), common.ReservedTaskListPrefix) { - return - } + taskList := *req.GetTaskList() if info == nil { return } @@ -200,11 +182,11 @@ func (lb *weightedLoadBalancer) UpdateWeight( } } taskListKey := key{ - domainID: domainID, + domainID: req.GetDomainUUID(), taskListName: taskList.GetName(), taskListType: taskListType, } - n := lb.provider.GetNumberOfReadPartitions(domainID, taskList, taskListType) + n := lb.provider.GetNumberOfReadPartitions(req.GetDomainUUID(), taskList, taskListType) if n <= 1 { lb.weightCache.Delete(taskListKey) return @@ -223,7 +205,7 @@ func (lb *weightedLoadBalancer) UpdateWeight( return } weight := calcWeightFromLoadBalancerHints(info) - lb.logger.Debug("update task list partition weight", tag.WorkflowDomainID(domainID), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskListType), tag.Dynamic("task-list-partition", p), tag.Dynamic("weight", weight), tag.Dynamic("load-balancer-hints", info)) + lb.logger.Debug("update task list partition weight", tag.WorkflowDomainID(req.GetDomainUUID()), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskListType), tag.Dynamic("task-list-partition", p), tag.Dynamic("weight", weight), tag.Dynamic("load-balancer-hints", info)) w.update(n, p, weight) } diff --git a/client/matching/weighted_loadbalancer_test.go b/client/matching/weighted_loadbalancer_test.go index e3e8b01a016..6c910a56b41 100644 --- a/client/matching/weighted_loadbalancer_test.go +++ b/client/matching/weighted_loadbalancer_test.go @@ -134,8 +134,13 @@ func TestWeightedLoadBalancer_PickWritePartition(t *testing.T) { domainID: "domainA", taskList: types.TaskList{Name: "taskListA"}, setupMock: func(m *MockLoadBalancer) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: "domainA", + TaskList: &types.TaskList{Name: "taskListA"}, + ForwardedFrom: "", + } m.EXPECT(). - PickWritePartition("domainA", types.TaskList{Name: "taskListA"}, 0, ""). + PickWritePartition(0, req). Return("partitionA") }, expectedResult: "partitionA", @@ -154,7 +159,13 @@ func TestWeightedLoadBalancer_PickWritePartition(t *testing.T) { fallbackLoadBalancer: mockFallbackLB, } - result := lb.PickWritePartition(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom) + req := &types.AddDecisionTaskRequest{ + DomainUUID: tc.domainID, + TaskList: &tc.taskList, + ForwardedFrom: tc.forwardedFrom, + } + + result := lb.PickWritePartition(tc.taskListType, req) assert.Equal(t, tc.expectedResult, result) }) } @@ -216,6 +227,11 @@ func TestWeightedLoadBalancer_PickReadPartition(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: tc.domainID, + TaskList: &tc.taskList, + ForwardedFrom: tc.forwardedFrom, + } ctrl := gomock.NewController(t) // Create mocks. mockWeightCache := cache.NewMockCache(ctrl) @@ -233,7 +249,7 @@ func TestWeightedLoadBalancer_PickReadPartition(t *testing.T) { if tc.expectFallbackCall { mockFallbackLoadBalancer.EXPECT(). - PickReadPartition(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom). + PickReadPartition(tc.taskListType, req, ""). Return(tc.fallbackReturn) } @@ -247,7 +263,7 @@ func TestWeightedLoadBalancer_PickReadPartition(t *testing.T) { } // Call the method under test. - result := lb.PickReadPartition(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom) + result := lb.PickReadPartition(tc.taskListType, req, "") // Assert the result. assert.Equal(t, tc.expectedResult, result) @@ -347,6 +363,11 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + req := &types.AddDecisionTaskRequest{ + DomainUUID: tc.domainID, + TaskList: &tc.taskList, + ForwardedFrom: tc.forwardedFrom, + } ctrl := gomock.NewController(t) mockWeightCache := cache.NewMockCache(ctrl) mockPartitionConfigProvider := NewMockPartitionConfigProvider(ctrl) @@ -359,7 +380,7 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) { tc.setupMock(mockWeightCache, mockPartitionConfigProvider) } - lb.UpdateWeight(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom, tc.partition, tc.loadBalancerHints) + lb.UpdateWeight(tc.taskListType, req, tc.partition, tc.loadBalancerHints) }) } } diff --git a/common/resource/resource_impl.go b/common/resource/resource_impl.go index 7622554175d..3951bcb4f82 100644 --- a/common/resource/resource_impl.go +++ b/common/resource/resource_impl.go @@ -177,6 +177,7 @@ func New( params.MetricsClient, dynamicCollection, numShards, + params.GetIsolationGroups, logger, ), params.RPCFactory.GetDispatcher(), diff --git a/host/service.go b/host/service.go index 368e5be8a6a..3ddee45ab54 100644 --- a/host/service.go +++ b/host/service.go @@ -80,6 +80,7 @@ type ( clientBean client.Bean timeSource clock.TimeSource numberOfHistoryShards int + allIsolationGroups func() []string logger log.Logger throttledLogger log.Logger @@ -113,6 +114,7 @@ func NewService(params *resource.Params) Service { timeSource: clock.NewRealTimeSource(), metricsScope: params.MetricScope, numberOfHistoryShards: params.PersistenceConfig.NumHistoryShards, + allIsolationGroups: params.GetIsolationGroups, clusterMetadata: params.ClusterMetadata, metricsClient: params.MetricsClient, messagingClient: params.MessagingClient, @@ -164,7 +166,7 @@ func (h *serviceImpl) Start() { h.hostInfo = hostInfo h.clientBean, err = client.NewClientBean( - client.NewRPCClientFactory(h.rpcFactory, h.membershipResolver, h.metricsClient, h.dynamicCollection, h.numberOfHistoryShards, h.logger), + client.NewRPCClientFactory(h.rpcFactory, h.membershipResolver, h.metricsClient, h.dynamicCollection, h.numberOfHistoryShards, h.allIsolationGroups, h.logger), h.rpcFactory.GetDispatcher(), h.clusterMetadata, )