Skip to content

Commit

Permalink
Add support for workload identity
Browse files Browse the repository at this point in the history
Signed-off-by: anton.lysina <[email protected]>
  • Loading branch information
toniiiik committed Jan 15, 2024
1 parent 3e12e20 commit d3b1793
Show file tree
Hide file tree
Showing 4 changed files with 438 additions and 45 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Here is an overview of all new **experimental** features:
### Improvements

- **General**: TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX))
- **Azure pipeline Scaler**: Add support for workload identity authentication ([#5013](https://github.com/kedacore/keda/issues/5013))

### Fixes

Expand Down Expand Up @@ -165,7 +166,6 @@ New deprecation(s):
- **General**: Support profiling for KEDA components ([#4789](https://github.com/kedacore/keda/issues/4789))
- **CPU scaler**: Wait for metrics window during CPU scaler tests ([#5294](https://github.com/kedacore/keda/pull/5294))
- **Hashicorp Vault**: Improve test coverage in `pkg/scaling/resolver/hashicorpvault_handler` ([#5195](https://github.com/kedacore/keda/issues/5195))
- **Kafka Scaler**: Add more test cases for large value of LagThreshold ([#5354](https://github.com/kedacore/keda/issues/5354))
- **Openstack Scaler**: Use Gophercloud SDK ([#3439](https://github.com/kedacore/keda/issues/3439))

## v2.12.1
Expand Down
148 changes: 109 additions & 39 deletions pkg/scalers/azure_pipelines_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ import (
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/go-logr/logr"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

Expand All @@ -27,6 +31,11 @@ type JobRequests struct {
Value []JobRequest `json:"value"`
}

const (
// Azure storage resource is "https://storage.azure.com/" in all cloud environments
devopsResource = "499b84ac-1321-427f-aa17-267ca6975798/.default"
)

type JobRequest struct {
RequestID int `json:"requestId"`
QueueTime time.Time `json:"queueTime"`
Expand Down Expand Up @@ -119,16 +128,17 @@ type azurePipelinesPoolIDResponse struct {
}

type azurePipelinesScaler struct {
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
logger logr.Logger
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
podIdentity kedav1alpha1.AuthPodIdentity
logger logr.Logger
}

type azurePipelinesMetadata struct {
organizationURL string
organizationName string
personalAccessToken string
authContext authContext
parent string
demands string
poolID int
Expand All @@ -139,36 +149,68 @@ type azurePipelinesMetadata struct {
requireAllDemands bool
}

type authContext struct {
cred *azidentity.ChainedTokenCredential
pat string
}

// NewAzurePipelinesScaler creates a new AzurePipelinesScaler
func NewAzurePipelinesScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) {
httpClient := kedautil.CreateHTTPClient(config.GlobalHTTPTimeout, false)

logger := InitializeLogger(config, "azure_pipelines_scaler")
metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
}

meta, err := parseAzurePipelinesMetadata(ctx, config, httpClient)
meta, podIdentity, err := parseAzurePipelinesMetadata(ctx, logger, config, httpClient)
if err != nil {
return nil, fmt.Errorf("error parsing azure Pipelines metadata: %w", err)
}

return &azurePipelinesScaler{
metricType: metricType,
metadata: meta,
httpClient: httpClient,
logger: InitializeLogger(config, "azure_pipelines_scaler"),
metricType: metricType,
metadata: meta,
httpClient: httpClient,
podIdentity: podIdentity,
logger: logger,
}, nil
}

func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, error) {
func getAuthMethod(logger logr.Logger, config *ScalerConfig) (string, *azidentity.ChainedTokenCredential, kedav1alpha1.AuthPodIdentity, error) {
pat := ""
if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
pat = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
pat = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else {
switch config.PodIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
return "", nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no personalAccessToken given or PodIdentity provider configured")
// return "", kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderNone}, nil
case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
cred, err := azure.NewChainedCredential(logger, config.PodIdentity.GetIdentityID(), config.PodIdentity.Provider)
if err != nil {
return "", nil, kedav1alpha1.AuthPodIdentity{}, err
}
return "", cred, kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}, nil
default:
return "", nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("pod identity %s not supported for azure pipelines", config.PodIdentity.Provider)
}
}
return pat, nil, kedav1alpha1.AuthPodIdentity{}, nil
}

func parseAzurePipelinesMetadata(ctx context.Context, logger logr.Logger, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, kedav1alpha1.AuthPodIdentity, error) {
meta := azurePipelinesMetadata{}
meta.targetPipelinesQueueLength = defaultTargetPipelinesQueueLength

if val, ok := config.TriggerMetadata["targetPipelinesQueueLength"]; ok {
queueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
}

meta.targetPipelinesQueueLength = queueLength
Expand All @@ -178,7 +220,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["activationTargetPipelinesQueueLength"]; ok {
activationQueueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
}

meta.activationTargetPipelinesQueueLength = activationQueueLength
Expand All @@ -190,22 +232,23 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
} else if val, ok := config.TriggerMetadata["organizationURLFromEnv"]; ok && val != "" {
meta.organizationURL = config.ResolvedEnv[val]
} else {
return nil, fmt.Errorf("no organizationURL given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no organizationURL given")
}

if val := meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]; val != "" {
meta.organizationName = meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]
} else {
return nil, fmt.Errorf("failed to extract organization name from organizationURL")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("failed to extract organization name from organizationURL")
}

if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
meta.personalAccessToken = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
meta.personalAccessToken = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else {
return nil, fmt.Errorf("no personalAccessToken given")
pat, cred, podIdentity, err := getAuthMethod(logger, config)
if err != nil {
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
// // Trim any trailing new lines from the Azure Pipelines PAT
meta.authContext = authContext{
pat: strings.TrimSuffix(pat, "\n"),
cred: cred,
}

if val, ok := config.TriggerMetadata["parent"]; ok && val != "" {
Expand All @@ -224,7 +267,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["jobsToFetch"]; ok && val != "" {
jobsToFetch, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing jobsToFetch: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing jobsToFetch: %w", err)
}
meta.jobsToFetch = jobsToFetch
}
Expand All @@ -233,41 +276,40 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["requireAllDemands"]; ok && val != "" {
requireAllDemands, err := strconv.ParseBool(val)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.requireAllDemands = requireAllDemands
}

if val, ok := config.TriggerMetadata["poolName"]; ok && val != "" {
var err error
poolID, err := getPoolIDFromName(ctx, val, &meta, httpClient)
poolID, err := getPoolIDFromName(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
if val, ok := config.TriggerMetadata["poolID"]; ok && val != "" {
var err error
poolID, err := validatePoolID(ctx, val, &meta, httpClient)
poolID, err := validatePoolID(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
return nil, fmt.Errorf("no poolName or poolID given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no poolName or poolID given")
}
}

// Trim any trailing new lines from the Azure Pipelines PAT
meta.personalAccessToken = strings.TrimSuffix(meta.personalAccessToken, "\n")
meta.triggerIndex = config.TriggerIndex

return &meta, nil
return &meta, podIdentity, nil
}

func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func getPoolIDFromName(ctx context.Context, logger logr.Logger, poolName string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
urlString := fmt.Sprintf("%s/_apis/distributedtask/pools?poolName=%s", metadata.organizationURL, url.QueryEscape(poolName))
body, err := getAzurePipelineRequest(ctx, urlString, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, urlString, metadata, podIdentity, httpClient)

if err != nil {
return -1, err
}
Expand All @@ -290,9 +332,10 @@ func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipe
return result.Value[0].ID, nil
}

func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func validatePoolID(ctx context.Context, logger logr.Logger, poolID string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
urlString := fmt.Sprintf("%s/_apis/distributedtask/pools?poolID=%s", metadata.organizationURL, poolID)
body, err := getAzurePipelineRequest(ctx, urlString, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, urlString, metadata, podIdentity, httpClient)

if err != nil {
return -1, fmt.Errorf("agent pool with id `%s` not found: %w", poolID, err)
}
Expand All @@ -306,13 +349,40 @@ func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelines
return result.ID, nil
}

func getAzurePipelineRequest(ctx context.Context, urlString string, metadata *azurePipelinesMetadata, httpClient *http.Client) ([]byte, error) {
func getToken(ctx context.Context, metadata *azurePipelinesMetadata, scope string) (string, error) {
token, err := metadata.authContext.cred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{
scope,
},
})

if err != nil {
return "", err
}
return token.Token, nil
}

func getAzurePipelineRequest(ctx context.Context, logger logr.Logger, urlString string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlString, nil)
if err != nil {
return []byte{}, err
}

req.SetBasicAuth("", metadata.personalAccessToken)
switch podIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
//PAT
logger.V(1).Info("making request to ADO REST API using PAT")
req.SetBasicAuth("", metadata.authContext.pat)
case kedav1alpha1.PodIdentityProviderAzureWorkload:
//ADO Resource token
logger.V(1).Info("making request to ADO REST API using managed identity")
aadToken, err := getToken(ctx, metadata, devopsResource)
if err != nil {
return []byte{}, fmt.Errorf("cannot create workload identity credentials: %w", err)
}
logger.V(1).Info("token acquired setting auth header as 'bearer XXXXXX'")
req.Header.Set("Authorization", "Bearer "+aadToken)
}

r, err := httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -340,7 +410,7 @@ func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context)
} else {
urlString = fmt.Sprintf("%s/_apis/distributedtask/pools/%d/jobrequests?$top=%d", s.metadata.organizationURL, s.metadata.poolID, s.metadata.jobsToFetch)
}
body, err := getAzurePipelineRequest(ctx, urlString, s.metadata, s.httpClient)
body, err := getAzurePipelineRequest(ctx, s.logger, urlString, s.metadata, s.podIdentity, s.httpClient)
if err != nil {
return -1, err
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/scalers/azure_pipelines_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"net/http/httptest"
"strings"
"testing"

"github.com/go-logr/logr"
)

const loadCount = 1000 // the size of the pretend pool completed of job requests
Expand Down Expand Up @@ -68,7 +70,9 @@ func TestParseAzurePipelinesMetadata(t *testing.T) {
testData.authParams["organizationURL"] = apiStub.URL
}

_, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testData.resolvedEnv, AuthParams: testData.authParams}, http.DefaultClient)
logger := logr.Discard()

_, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testData.resolvedEnv, AuthParams: testData.authParams}, http.DefaultClient)
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
Expand Down Expand Up @@ -121,8 +125,8 @@ func TestValidateAzurePipelinesPool(t *testing.T) {
"organizationURL": apiStub.URL,
"personalAccessToken": "PAT",
}

_, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: nil, AuthParams: authParams}, http.DefaultClient)
logger := logr.Discard()
_, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: nil, AuthParams: authParams}, http.DefaultClient)
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
Expand Down Expand Up @@ -160,7 +164,9 @@ func TestAzurePipelinesGetMetricSpecForScaling(t *testing.T) {
"targetPipelinesQueueLength": "1",
}

meta, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: metadata, ResolvedEnv: nil, AuthParams: authParams, TriggerIndex: testData.triggerIndex}, http.DefaultClient)
logger := logr.Discard()

meta, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: metadata, ResolvedEnv: nil, AuthParams: authParams, TriggerIndex: testData.triggerIndex}, http.DefaultClient)
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
Expand All @@ -183,7 +189,7 @@ func getMatchedAgentMetaData(url string) *azurePipelinesMetadata {
meta.organizationName = "testOrg"
meta.organizationURL = url
meta.parent = "dotnet60-keda-template"
meta.personalAccessToken = "testPAT"
meta.authContext.pat = "testPAT"
meta.poolID = 1
meta.targetPipelinesQueueLength = 1

Expand Down
Loading

0 comments on commit d3b1793

Please sign in to comment.