Skip to content

Commit

Permalink
fix auth header for bearer authtype, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
toniiiik committed Aug 10, 2023
1 parent 65b9215 commit 2040b7d
Show file tree
Hide file tree
Showing 2 changed files with 374 additions and 48 deletions.
105 changes: 57 additions & 48 deletions pkg/scalers/azure_pipelines_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,17 @@ type azurePipelinesPoolIDResponse struct {
}

type azurePipelinesScaler struct {
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
logger logr.Logger
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
podIdentity kedav1alpha1.AuthPodIdentity
logger logr.Logger
}

type azurePipelinesMetadata struct {
organizationURL string
organizationName string
personalAccessToken string
podIdentityProvider kedav1alpha1.PodIdentityProvider
parent string
demands string
poolID int
Expand All @@ -145,48 +145,45 @@ type azurePipelinesMetadata struct {
func NewAzurePipelinesScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) {
httpClient := kedautil.CreateHTTPClient(config.GlobalHTTPTimeout, false)

logger := InitializeLogger(config, "azure_pipelines_scaler")
metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
}

meta, err := parseAzurePipelinesMetadata(ctx, config, httpClient)
meta, podIdentity, err := parseAzurePipelinesMetadata(ctx, logger, config, httpClient)
if err != nil {
return nil, fmt.Errorf("error parsing azure Pipelines metadata: %w", err)
}

return &azurePipelinesScaler{
metricType: metricType,
metadata: meta,
httpClient: httpClient,
logger: InitializeLogger(config, "azure_pipelines_scaler"),
metricType: metricType,
metadata: meta,
httpClient: httpClient,
podIdentity: podIdentity,
logger: logger,
}, nil
}

func parseAzureDevOpsAuthMethod(config *ScalerConfig, metadata *azurePipelinesMetadata) error {
if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
metadata.personalAccessToken = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
metadata.personalAccessToken = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else if config.PodIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload {
//use workload identity
metadata.podIdentityProvider = config.PodIdentity.Provider
} else {
return fmt.Errorf("no personalAccessToken given or PodIdentity provider configured")
func getAuthPodIdentity(config *ScalerConfig) (kedav1alpha1.AuthPodIdentity, error) {
switch config.PodIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
return kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderNone}, nil
case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
return kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderAzureWorkload}, nil
default:
return kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("pod identity %s not supported for azure storage blobs", config.PodIdentity)
}

return nil
}

func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, error) {
func parseAzurePipelinesMetadata(ctx context.Context, logger logr.Logger, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, kedav1alpha1.AuthPodIdentity, error) {
meta := azurePipelinesMetadata{}
meta.targetPipelinesQueueLength = defaultTargetPipelinesQueueLength

if val, ok := config.TriggerMetadata["targetPipelinesQueueLength"]; ok {
queueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
}

meta.targetPipelinesQueueLength = queueLength
Expand All @@ -196,7 +193,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["activationTargetPipelinesQueueLength"]; ok {
activationQueueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
}

meta.activationTargetPipelinesQueueLength = activationQueueLength
Expand All @@ -208,18 +205,27 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
} else if val, ok := config.TriggerMetadata["organizationURLFromEnv"]; ok && val != "" {
meta.organizationURL = config.ResolvedEnv[val]
} else {
return nil, fmt.Errorf("no organizationURL given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no organizationURL given")
}

if val := meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]; val != "" {
meta.organizationName = meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]
} else {
return nil, fmt.Errorf("failed to extract organization name from organizationURL")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("failed to extract organization name from organizationURL")
}

err := parseAzureDevOpsAuthMethod(config, &meta)
podIdentity, err := getAuthPodIdentity(config)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}

if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
meta.personalAccessToken = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
meta.personalAccessToken = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else if podIdentity.Provider == kedav1alpha1.PodIdentityProviderNone {
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no personalAccessToken given or PodIdentity provider configured")
}

if val, ok := config.TriggerMetadata["parent"]; ok && val != "" {
Expand All @@ -238,7 +244,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["jobsToFetch"]; ok && val != "" {
jobsToFetch, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing jobsToFetch: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing jobsToFetch: %w", err)
}
meta.jobsToFetch = jobsToFetch
}
Expand All @@ -247,41 +253,41 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["requireAllDemands"]; ok && val != "" {
requireAllDemands, err := strconv.ParseBool(val)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.requireAllDemands = requireAllDemands
}

if val, ok := config.TriggerMetadata["poolName"]; ok && val != "" {
var err error
poolID, err := getPoolIDFromName(ctx, val, &meta, httpClient)
poolID, err := getPoolIDFromName(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
if val, ok := config.TriggerMetadata["poolID"]; ok && val != "" {
var err error
poolID, err := validatePoolID(ctx, val, &meta, httpClient)
poolID, err := validatePoolID(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
return nil, fmt.Errorf("no poolName or poolID given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no poolName or poolID given")
}
}

// Trim any trailing new lines from the Azure Pipelines PAT
// // Trim any trailing new lines from the Azure Pipelines PAT
meta.personalAccessToken = strings.TrimSuffix(meta.personalAccessToken, "\n")
meta.scalerIndex = config.ScalerIndex

return &meta, nil
return &meta, podIdentity, nil
}

func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func getPoolIDFromName(ctx context.Context, logger logr.Logger, poolName string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
url := fmt.Sprintf("%s/_apis/distributedtask/pools?poolName=%s", metadata.organizationURL, poolName)
body, err := getAzurePipelineRequest(ctx, url, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, url, metadata, podIdentity, httpClient)
if err != nil {
return -1, err
}
Expand All @@ -304,9 +310,9 @@ func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipe
return result.Value[0].ID, nil
}

func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func validatePoolID(ctx context.Context, logger logr.Logger, poolID string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
url := fmt.Sprintf("%s/_apis/distributedtask/pools?poolID=%s", metadata.organizationURL, poolID)
body, err := getAzurePipelineRequest(ctx, url, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, url, metadata, podIdentity, httpClient)
if err != nil {
return -1, fmt.Errorf("agent pool with id `%s` not found: %w", poolID, err)
}
Expand All @@ -320,24 +326,27 @@ func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelines
return result.ID, nil
}

func getAzurePipelineRequest(ctx context.Context, url string, metadata *azurePipelinesMetadata, httpClient *http.Client) ([]byte, error) {
func getAzurePipelineRequest(ctx context.Context, logger logr.Logger, url string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return []byte{}, err
}

switch metadata.podIdentityProvider {
switch podIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
//PAT
logger.V(0).Info("making request to ADO REST API using PAT")
req.SetBasicAuth("", metadata.personalAccessToken)
case kedav1alpha1.PodIdentityProviderAzureWorkload:
//ADO Resource token
resource := "499b84ac-1321-427f-aa17-267ca6975798"
logger.V(0).Info("making request to ADO REST API using managed identity")
aadToken, err := azure.GetAzureADWorkloadIdentityToken(ctx, "", resource)
if err != nil {
return []byte{}, fmt.Errorf("cannot create workload identity credentials: %s", err.Error())
return []byte{}, fmt.Errorf("cannot create workload identity credentials: %w", err)
}
req.Header.Set("Authentication", "Bearer "+aadToken.AccessToken)
logger.V(0).Info("token acquired setting auth header as 'bearer XXXXXX'")
req.Header.Set("Authorization", "Bearer "+aadToken.AccessToken)
}

r, err := httpClient.Do(req)
Expand Down Expand Up @@ -366,7 +375,7 @@ func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context)
} else {
url = fmt.Sprintf("%s/_apis/distributedtask/pools/%d/jobrequests?$top=%d", s.metadata.organizationURL, s.metadata.poolID, s.metadata.jobsToFetch)
}
body, err := getAzurePipelineRequest(ctx, url, s.metadata, s.httpClient)
body, err := getAzurePipelineRequest(ctx, s.logger, url, s.metadata, s.podIdentity, s.httpClient)
if err != nil {
return -1, err
}
Expand Down
Loading

0 comments on commit 2040b7d

Please sign in to comment.