Skip to content

Commit

Permalink
Extend RabbitMQ scaler to support count unacked messages
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Emelyanov <[email protected]>
  • Loading branch information
holyketzer committed Mar 25, 2020
1 parent 3692eca commit 6a2e827
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 29 deletions.
147 changes: 120 additions & 27 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package scalers

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"time"

"github.com/streadway/amqp"
v2beta1 "k8s.io/api/autoscaling/v2beta1"
Expand All @@ -17,6 +22,8 @@ import (
const (
rabbitQueueLengthMetricName = "queueLength"
rabbitMetricType = "External"
rabbitCountUnacked = "countUnacked"
defaultCountUnacked = false
)

type rabbitMQScaler struct {
Expand All @@ -26,9 +33,17 @@ type rabbitMQScaler struct {
}

type rabbitMQMetadata struct {
queueName string
host string
queueLength int
queueName string
host string // connection string for AMQP protocol
apiHost string // connection string for management API requests
queueLength int
countUnacked bool // if true uses HTTP API and requires apiHost, if false uses AMQP and requires host
}

type queueInfo struct {
Messages int `json:"messages"`
MessagesUnacknowledged int `json:"messages_unacknowledged"`
Name string `json:"name"`
}

var rabbitmqLog = logf.Log.WithName("rabbitmq_scaler")
Expand All @@ -40,33 +55,62 @@ func NewRabbitMQScaler(resolvedEnv, metadata, authParams map[string]string) (Sca
return nil, fmt.Errorf("error parsing rabbitmq metadata: %s", err)
}

conn, ch, err := getConnectionAndChannel(meta.host)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err)
}
if meta.countUnacked {
return &rabbitMQScaler{metadata: meta}, nil
} else {
conn, ch, err := getConnectionAndChannel(meta.host)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err)
}

return &rabbitMQScaler{
metadata: meta,
connection: conn,
channel: ch,
}, nil
return &rabbitMQScaler{
metadata: meta,
connection: conn,
channel: ch,
}, nil
}
}

func parseRabbitMQMetadata(resolvedEnv, metadata, authParams map[string]string) (*rabbitMQMetadata, error) {
meta := rabbitMQMetadata{}

if val, ok := authParams["host"]; ok {
meta.host = val
} else if val, ok := metadata["host"]; ok {
hostSetting := val
meta.countUnacked = defaultCountUnacked
if val, ok := metadata[rabbitCountUnacked]; ok {
countUnacked, err := strconv.ParseBool(val)
if err != nil {
return nil, fmt.Errorf("countUnacked parsing error %s", err.Error())
}
meta.countUnacked = countUnacked
}

if meta.countUnacked {
if val, ok := authParams["apiHost"]; ok {
meta.apiHost = val
} else if val, ok := metadata["apiHost"]; ok {
hostSetting := val

if val, ok := resolvedEnv[hostSetting]; ok {
if val, ok := resolvedEnv[hostSetting]; ok {
meta.apiHost = val
}
}

if meta.apiHost == "" {
return nil, fmt.Errorf("no apiHost setting given")
}
} else {
if val, ok := authParams["host"]; ok {
meta.host = val
} else if val, ok := metadata["host"]; ok {
hostSetting := val

if val, ok := resolvedEnv[hostSetting]; ok {
meta.host = val
}
}
}

if meta.host == "" {
return nil, fmt.Errorf("no host setting given")
if meta.host == "" {
return nil, fmt.Errorf("no host setting given")
}
}

if val, ok := metadata["queueName"]; ok {
Expand Down Expand Up @@ -105,10 +149,12 @@ func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, erro

// Close disposes of RabbitMQ connections
func (s *rabbitMQScaler) Close() error {
err := s.connection.Close()
if err != nil {
rabbitmqLog.Error(err, "Error closing rabbitmq connection")
return err
if s.metadata.countUnacked == false {
err := s.connection.Close()
if err != nil {
rabbitmqLog.Error(err, "Error closing rabbitmq connection")
return err
}
}
return nil
}
Expand All @@ -124,12 +170,59 @@ func (s *rabbitMQScaler) IsActive(ctx context.Context) (bool, error) {
}

func (s *rabbitMQScaler) getQueueMessages() (int, error) {
items, err := s.channel.QueueInspect(s.metadata.queueName)
if s.metadata.countUnacked {
info, err := s.getQueueInfoViaHttp()
if err != nil {
return -1, err
} else {
return info.Messages + info.MessagesUnacknowledged, nil
}
} else {
items, err := s.channel.QueueInspect(s.metadata.queueName)
if err != nil {
return -1, err
} else {
return items.Messages, nil
}
}
}

func getJson(url string, target interface{}) error {
var client = &http.Client{Timeout: 5 * time.Second}
r, err := client.Get(url)
if err != nil {
return -1, err
return err
}
defer r.Body.Close()

return items.Messages, nil
if r.StatusCode == 200 {
return json.NewDecoder(r.Body).Decode(target)
} else {
body, _ := ioutil.ReadAll(r.Body)
return fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url)
}
}

func (s *rabbitMQScaler) getQueueInfoViaHttp() (*queueInfo, error) {
parsedUrl, err := url.Parse(s.metadata.apiHost)

if err != nil {
return nil, err
}

vhost := parsedUrl.Path
parsedUrl.Path = ""

getQueueInfoManagementURI := fmt.Sprintf("%s/%s%s/%s", parsedUrl.String(), "api/queues", vhost, s.metadata.queueName)

info := queueInfo{}
err = getJson(getQueueInfoManagementURI, &info)

if err != nil {
return nil, err
} else {
return &info, nil
}
}

// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler
Expand Down
77 changes: 75 additions & 2 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package scalers

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

const (
host = "myHostSecret"
host = "myHostSecret"
apiHost = "myApiHostSecret"
)

type parseRabbitMQMetadataTestData struct {
Expand All @@ -15,7 +21,8 @@ type parseRabbitMQMetadataTestData struct {
}

var sampleRabbitMqResolvedEnv = map[string]string{
host: "none",
host: "amqp://user:[email protected]:5236/vhost",
apiHost: "https://user:[email protected]/vhost",
}

var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
Expand All @@ -31,6 +38,8 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
{map[string]string{"queueLength": "10", "host": host}, true, map[string]string{}},
// host defined in authParams
{map[string]string{"queueLength": "10"}, true, map[string]string{"host": host}},
// properly formed metadata with countUnacked
{map[string]string{"queueLength": "10", "queueName": "sample", "apiHost": apiHost, "countUnacked": "true"}, false, map[string]string{}},
}

func TestRabbitMQParseMetadata(t *testing.T) {
Expand All @@ -44,3 +53,67 @@ func TestRabbitMQParseMetadata(t *testing.T) {
}
}
}

type getQueueInfoTestData struct {
response string
responseStatus int
isActive bool
}

var testQueueInfoTestData = []getQueueInfoTestData{
{`{"messages": 4, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 0, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 1, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 0, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, false},
{`Password is incorrect`, http.StatusUnauthorized, false},
}

func TestGetQueueInfo(t *testing.T) {
for _, testData := range testQueueInfoTestData {
var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expeced_path := "/api/queues/myhost/evaluate_trials"
if r.RequestURI != expeced_path {
t.Error("Expect request path to =", expeced_path, "but it is", r.RequestURI)
}

w.WriteHeader(testData.responseStatus)
w.Write([]byte(testData.response))
}))

resolvedEnv := map[string]string{apiHost: fmt.Sprintf("%s/%s", apiStub.URL, "myhost")}

metadata := map[string]string{
"queueLength": "10",
"queueName": "evaluate_trials",
"apiHost": apiHost,
"countUnacked": "true",
}

s, err := NewRabbitMQScaler(resolvedEnv, metadata, map[string]string{})

if err != nil {
t.Error("Expect success", err)
}

ctx := context.TODO()
active, err := s.IsActive(ctx)

if testData.responseStatus == http.StatusOK {
if err != nil {
t.Error("Expect success", err)
}

if active != testData.isActive {
if testData.isActive {
t.Error("Expect to be active")
} else {
t.Error("Expect to not be active")
}
}
} else {
if !strings.Contains(err.Error(), testData.response) {
t.Error("Expect error to be like '", testData.response, "' but it's '", err, "'")
}
}
}
}

0 comments on commit 6a2e827

Please sign in to comment.