diff --git a/go.mod b/go.mod index 7d06608b36e..cd3a84b96bf 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/go-logr/logr v0.1.0 github.com/go-openapi/spec v0.19.3 github.com/go-redis/redis v6.15.5+incompatible + github.com/go-sql-driver/mysql v1.4.1 github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.2 github.com/imdario/mergo v0.3.8 diff --git a/go.sum b/go.sum index b5a0528f869..414f6214c63 100644 --- a/go.sum +++ b/go.sum @@ -228,6 +228,7 @@ github.com/go-openapi/validate v0.18.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+ github.com/go-redis/redis v6.15.5+incompatible h1:pLky8I0rgiblWfa8C1EV7fPEUv0aH6vKRaYHc/YRHVk= github.com/go-redis/redis v6.15.5+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/envy v1.6.5/go.mod h1:N+GkhhZ/93bGZc6ZKhJLP6+m+tCNPKwgSpH9kaifseQ= diff --git a/pkg/handler/scale_handler.go b/pkg/handler/scale_handler.go index cbcc0702f85..b0c403552f7 100644 --- a/pkg/handler/scale_handler.go +++ b/pkg/handler/scale_handler.go @@ -298,6 +298,8 @@ func (h *ScaleHandler) getScaler(name, namespace, triggerType string, resolvedEn return scalers.NewAzureBlobScaler(resolvedEnv, triggerMetadata, authParams, podIdentity) case "postgres": return scalers.NewPostgresScaler(resolvedEnv, triggerMetadata, authParams) + case "mysql": + return scalers.NewMySQLScaler(resolvedEnv, triggerMetadata, authParams) default: return nil, fmt.Errorf("no scaler found for type: %s", triggerType) } diff --git a/pkg/scalers/mysql_scaler.go b/pkg/scalers/mysql_scaler.go new file mode 100644 index 00000000000..dc43a088c1a --- /dev/null +++ b/pkg/scalers/mysql_scaler.go @@ -0,0 +1,212 @@ +package scalers + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-sql-driver/mysql" + "k8s.io/api/autoscaling/v2beta1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/metrics/pkg/apis/external_metrics" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "strconv" +) + +const ( + mySQLMetricName = "MySQLQueryValue" + defaultMySQLPassword = "" +) + +type mySQLScaler struct { + metadata *mySQLMetadata + connection *sql.DB +} + +type mySQLMetadata struct { + connectionString string // Database connection string + username string + password string + host string + port string + dbName string + query string + queryValue int +} + +var mySQLLog = logf.Log.WithName("mysql_scaler") + +// NewMySQLScaler creates a new MySQL scaler +func NewMySQLScaler(resolvedEnv, metadata, authParams map[string]string) (Scaler, error) { + meta, err := parseMySQLMetadata(resolvedEnv, metadata, authParams) + if err != nil { + return nil, fmt.Errorf("error parsing MySQL metadata: %s", err) + } + + conn, err := newMySQLConnection(meta) + if err != nil { + return nil, fmt.Errorf("error establishing MySQL connectionString: %s", err) + } + return &mySQLScaler{ + metadata: meta, + connection: conn, + }, nil +} + +func parseMySQLMetadata(resolvedEnv, metadata, authParams map[string]string) (*mySQLMetadata, error) { + meta := mySQLMetadata{} + + if val, ok := metadata["query"]; ok { + meta.query = val + } else { + return nil, fmt.Errorf("no query given") + } + + if val, ok := metadata["queryValue"]; ok { + queryValue, err := strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("queryValue parsing error %s", err.Error()) + } + meta.queryValue = queryValue + } else { + return nil, fmt.Errorf("no queryValue given") + } + + if val, ok := authParams["connectionString"]; ok { + meta.connectionString = val + } else if val, ok := metadata["connectionString"]; ok { + hostSetting := val + if val, ok := resolvedEnv[hostSetting]; ok { + meta.connectionString = val + } + } else { + meta.connectionString = "" + if val, ok := metadata["host"]; ok { + meta.host = val + } else { + return nil, fmt.Errorf("no host given") + } + if val, ok := metadata["port"]; ok { + meta.port = val + } else { + return nil, fmt.Errorf("no port given") + } + + if val, ok := metadata["username"]; ok { + meta.username = val + } else { + return nil, fmt.Errorf("no username given") + } + if val, ok := metadata["dbName"]; ok { + meta.dbName = val + } else { + return nil, fmt.Errorf("no dbName given") + } + meta.password = defaultMySQLPassword + if val, ok := authParams["password"]; ok { + meta.password = val + } else if val, ok := metadata["password"]; ok && val != "" { + if pass, ok := resolvedEnv[val]; ok { + meta.password = pass + } + } + } + + return &meta, nil +} + +// metadataToConnectionStr builds new MySQL connection string +func metadataToConnectionStr(meta *mySQLMetadata) string { + var connStr string + + if meta.connectionString != "" { + connStr = meta.connectionString + } else { + // Build connection str + config := mysql.NewConfig() + config.Addr = fmt.Sprintf("%s:%s", meta.host, meta.port) + config.DBName = meta.dbName + config.Passwd = meta.password + config.User = meta.username + config.Net = "tcp" + connStr = config.FormatDSN() + } + return connStr +} + +// newMySQLConnection creates MySQL db connection +func newMySQLConnection(meta *mySQLMetadata) (*sql.DB, error) { + connStr := metadataToConnectionStr(meta) + db, err := sql.Open("mysql", connStr) + if err != nil { + mySQLLog.Error(err, fmt.Sprintf("Found error when opening connection: %s", err)) + return nil, err + } + err = db.Ping() + if err != nil { + mySQLLog.Error(err, fmt.Sprintf("Found error when pinging databse: %s", err)) + return nil, err + } + return db, nil +} + +// Close disposes of MySQL connections +func (s *mySQLScaler) Close() error { + err := s.connection.Close() + if err != nil { + mySQLLog.Error(err, "Error closing MySQL connection") + return err + } + return nil +} + +// IsActive returns true if there are pending messages to be processed +func (s *mySQLScaler) IsActive(ctx context.Context) (bool, error) { + messages, err := s.getQueryResult() + if err != nil { + mySQLLog.Error(err, fmt.Sprintf("Error inspecting MySQL: %s", err)) + return false, err + } + return messages > 0, nil +} + +// getQueryResult returns result of the scaler query +func (s *mySQLScaler) getQueryResult() (int, error) { + var value int + err := s.connection.QueryRow(s.metadata.query).Scan(&value) + if err != nil { + mySQLLog.Error(err, fmt.Sprintf("Could not query MySQL database: %s", err)) + return 0, err + } + return value, nil +} + +// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler +func (s *mySQLScaler) GetMetricSpecForScaling() []v2beta1.MetricSpec { + targetQueryValue := resource.NewQuantity(int64(s.metadata.queryValue), resource.DecimalSI) + externalMetric := &v2beta1.ExternalMetricSource{ + MetricName: mySQLMetricName, + TargetAverageValue: targetQueryValue, + } + metricSpec := v2beta1.MetricSpec{ + External: externalMetric, Type: externalMetricType, + } + return []v2beta1.MetricSpec{metricSpec} +} + +// GetMetrics returns value for a supported metric and an error if there is a problem getting the metric +func (s *mySQLScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { + num, err := s.getQueryResult() + if err != nil { + return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error inspecting MySQL: %s", err) + } + + metric := external_metrics.ExternalMetricValue{ + MetricName: mySQLMetricName, + Value: *resource.NewQuantity(int64(num), resource.DecimalSI), + Timestamp: metav1.Now(), + } + + return append([]external_metrics.ExternalMetricValue{}, metric), nil +} diff --git a/pkg/scalers/mysql_scaler_test.go b/pkg/scalers/mysql_scaler_test.go new file mode 100644 index 00000000000..2efcb0f0f70 --- /dev/null +++ b/pkg/scalers/mysql_scaler_test.go @@ -0,0 +1,58 @@ +package scalers + +import ( + "testing" +) + +var testMySQLResolvedEnv = map[string]string{ + "MYSQL_PASSWORD": "pass", + "MYSQL_CONN_STR": "test_conn_str", +} + +type parseMySQLMetadataTestData struct { + metdadata map[string] string + raisesError bool +} + +var testMySQLMetdata = []parseMySQLMetadataTestData{ + // No metadata + {metdadata: map[string]string{}, raisesError:true}, + // connectionString + {metdadata: map[string]string{"query": "query", "queryValue": "12", "connectionString": "test_value"}, raisesError:false}, + // Params instead of conn str + {metdadata: map[string]string{"query": "query", "queryValue": "12", "host": "test_host", "port": "test_port", "username": "test_username", "password": "test_password", "dbName": "test_dbname"}, raisesError:false}, +} + +func TestParseMySQLMetadata(t *testing.T) { + for _, testData := range testMySQLMetdata { + _, err := parseMySQLMetadata(testMySQLResolvedEnv, testData.metdadata, map[string]string{}) + if err != nil && !testData.raisesError { + t.Error("Expected success but got error", err) + } + if err == nil && testData.raisesError { + t.Error("Expected error but got success") + } + } +} + +func TestMetadataToConnectionStrUseConnStr(t *testing.T) { + // Use existing ConnStr + testMeta := map[string]string{"query": "query", "queryValue": "12", "connectionString": "MYSQL_CONN_STR"} + meta, _ := parseMySQLMetadata(testMySQLResolvedEnv, testMeta, map[string]string{}) + connStr := metadataToConnectionStr(meta) + if connStr != testMySQLResolvedEnv["MYSQL_CONN_STR"] { + t.Error("Expected success") + } +} + +func TestMetadataToConnectionStrBuildNew(t *testing.T) { + // Build new ConnStr + expected := "test_username:pass@tcp(test_host:test_port)/test_dbname" + testMeta := map[string]string{"query": "query", "queryValue": "12", "host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"} + meta, _ := parseMySQLMetadata(testMySQLResolvedEnv, testMeta, map[string]string{}) + connStr := metadataToConnectionStr(meta) + if connStr != expected { + t.Errorf("%s != %s", expected, connStr) + } +} +