From 24ae7848f7022cc929743bfb5cb6b78fd4617a70 Mon Sep 17 00:00:00 2001 From: Dimitri Masson <30894448+d-masson@users.noreply.github.com> Date: Mon, 19 Oct 2020 11:39:24 +0200 Subject: [PATCH] reverse_proxy: Fix random_choose lb_policy --- .../reverseproxy/selectionpolicies.go | 16 +++--- .../reverseproxy/selectionpolicies_test.go | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 343140fc1eda..1d6aa5514cbb 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -139,14 +139,17 @@ func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstr if k > len(pool) { k = len(pool) } - choices := make([]*Upstream, k) + // Slice of length 0 and capacity k (if we use an array and have some unavailable + // upstream, we could have some empty element which may lead + // to an index out of range runtime error) + choices := make([]*Upstream, 0, k) for i, upstream := range pool { if !upstream.Available() { continue } j := weakrand.Intn(i + 1) if j < k { - choices[j] = upstream + choices = append(choices, upstream) } } return leastRequests(choices) @@ -397,13 +400,12 @@ func leastRequests(upstreams []*Upstream) *Upstream { return nil } var best []*Upstream - var bestReqs int + var bestReqs int = -1 for _, upstream := range upstreams { reqs := upstream.NumRequests() - if reqs == 0 { - return upstream - } - if reqs <= bestReqs { + // If bestReqs was just initialized to -1 + // we need to append upstream also + if reqs <= bestReqs || bestReqs == -1 { bestReqs = reqs best = append(best, upstream) } diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index e9939d6d14bf..3e8d4f51877e 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -271,3 +271,54 @@ func TestURIHashPolicy(t *testing.T) { t.Error("Expected uri policy policy host to be nil.") } } + +func TestLeastRequests(t *testing.T) { + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].SetHealthy(true) + pool[1].SetHealthy(true) + pool[2].SetHealthy(true) + pool[0].CountRequest(10) + pool[1].CountRequest(20) + pool[2].CountRequest(30) + + result := leastRequests(pool) + + if result == nil { + t.Error("Least request should not return nil") + } + + if result != pool[0] { + t.Error("Least request should return pool[0]") + } +} + +func TestRandomChoicePolicy(t *testing.T) { + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].SetHealthy(false) + pool[1].SetHealthy(true) + pool[2].SetHealthy(true) + pool[0].CountRequest(10) + pool[1].CountRequest(20) + pool[2].CountRequest(30) + + request := httptest.NewRequest(http.MethodGet, "/test", nil) + randomChoicePolicy := new(RandomChoiceSelection) + randomChoicePolicy.Choose = 2 + + h := randomChoicePolicy.Select(pool, request) + + if h == nil { + t.Error("RandomChoicePolicy should not return nil") + } + + if h != pool[1] { + t.Error("RandomChoicePolicy should choose pool[1]") + } + +}