Skip to content

Commit

Permalink
feat: add support for aws ecr tokens (#2650)
Browse files Browse the repository at this point in the history
Signed-off-by: K Tamil Vanan <[email protected]>
  • Loading branch information
tamilhce committed Jan 9, 2025
1 parent e410f39 commit d3d08a7
Show file tree
Hide file tree
Showing 10 changed files with 470 additions and 32 deletions.
40 changes: 40 additions & 0 deletions examples/config-sync-ecr-credential-helper.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"distSpecVersion": "1.1.0",
"storage": {
"rootDirectory": "/tmp/zot",
"dedupe": false,
"storageDriver": {
"name": "s3",
"region": "REGION_NAME",
"bucket": "BUGKET_NAME",
"rootdirectory": "/ROOTDIR",
"secure": true,
"skipverify": false
}
},
"http": {
"address": "0.0.0.0",
"port": "8080"
},
"log": {
"level": "debug"
},
"extensions": {
"sync": {
"credentialsFile": "",
"DownloadDir": "/tmp/zot",
"registries": [
{
"urls": [
"https://ACCOUNTID.dkr.ecr.REGION.amazonaws.com"
],
"onDemand": true,
"maxRetries": 5,
"retryDelay": "2m",
"credentialHelper": "ecr"
}
]
}
}
}

2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.28.7
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.22
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.1
github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.8
github.com/aws/aws-secretsmanager-caching-go v1.2.0
github.com/aws/smithy-go v1.22.1
Expand Down Expand Up @@ -158,7 +159,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.10 // indirect
github.com/aws/aws-sdk-go-v2/service/ebs v1.25.3 // indirect
github.com/aws/aws-sdk-go-v2/service/ec2 v1.193.0 // indirect
github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 // indirect
github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.25.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.7 // indirect
Expand Down
19 changes: 10 additions & 9 deletions pkg/extensions/config/sync/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ type Config struct {
}

type RegistryConfig struct {
URLs []string
PollInterval time.Duration
Content []Content
TLSVerify *bool
OnDemand bool
CertDir string
MaxRetries *int
RetryDelay *time.Duration
OnlySigned *bool
URLs []string
PollInterval time.Duration
Content []Content
TLSVerify *bool
OnDemand bool
CertDir string
MaxRetries *int
RetryDelay *time.Duration
OnlySigned *bool
CredentialHelper string
}

type Content struct {
Expand Down
162 changes: 162 additions & 0 deletions pkg/extensions/sync/ecr_credential_helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//go:build sync
// +build sync

package sync

import (
"context"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecr"

syncconf "zotregistry.dev/zot/pkg/extensions/config/sync"
"zotregistry.dev/zot/pkg/log"
)

// ECR tokens are valid for 12 hours. The ExpiryWindow variable is set to 1 hour,
// meaning if the remaining validity of the token is less than 1 hour, it will be considered expired.
const (
ExpiryWindow int = 1
ECRURLSplitPartsCount int = 6
UsernameTokenParts int = 2
)

var (
ErrInvalidURLFormat = errors.New("invalid ECR URL is received")
ErrInvalidTokenFormat = errors.New("invalid token format received from ECR")
ErrUnableToLoadAWSConfig = errors.New("unable to load AWS config for region")
ErrUnableToGetECRAuthToken = errors.New("unable to get ECR authorization token for account")
ErrUnableToDecodeECRToken = errors.New("unable to decode ECR token")
ErrFailedToGetECRCredentials = errors.New("failed to get ECR credentials")
)

type ECRCredential struct {
username string
password string
expiry time.Time
account string
region string
}

type ECRCredentialsHelper struct {
credentials map[string]ECRCredential
log log.Logger
}

func NewECRCredentialHelper(log log.Logger) CredentialHelper {
return &ECRCredentialsHelper{
credentials: make(map[string]ECRCredential),
log: log,
}
}

// extractAccountAndRegion extracts the account ID and region from the given ECR URL.
// Example URL format: account.dkr.ecr.region.amazonaws.com.
func extractAccountAndRegion(url string) (string, string, error) {
parts := strings.Split(url, ".")
if len(parts) < ECRURLSplitPartsCount {
return "", "", fmt.Errorf("%w: %s", ErrInvalidURLFormat, url)
}

accountID := parts[0] // First part is the account ID
region := parts[3] // Fourth part is the region

return accountID, region, nil
}

func getECRCredentials(remoteAddress string) (ECRCredential, error) {
// Extract account ID and region from the URL.
accountID, region, err := extractAccountAndRegion(remoteAddress)
if err != nil {
return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrInvalidTokenFormat, remoteAddress, err)
}

// Load the AWS config for the specific region.
cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region))
if err != nil {
return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrUnableToLoadAWSConfig, region, err)
}

// Create an ECR client
ecrClient := ecr.NewFromConfig(cfg)

// Fetch the ECR authorization token.
ecrAuth, err := ecrClient.GetAuthorizationToken(context.TODO(), &ecr.GetAuthorizationTokenInput{
RegistryIds: []string{accountID}, // Filter by the account ID.
})
if err != nil {
return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrUnableToGetECRAuthToken, accountID, err)
}

// Decode the base64-encoded ECR token.
authToken := *ecrAuth.AuthorizationData[0].AuthorizationToken
decodedToken, err := base64.StdEncoding.DecodeString(authToken)

if err != nil {
return ECRCredential{}, fmt.Errorf("%w: %w", ErrUnableToDecodeECRToken, err)
}

// Split the decoded token into username and password (username is "AWS").
tokenParts := strings.Split(string(decodedToken), ":")
if len(tokenParts) != UsernameTokenParts {
return ECRCredential{}, fmt.Errorf("%w", ErrInvalidTokenFormat)
}

expiry := *ecrAuth.AuthorizationData[0].ExpiresAt
username := tokenParts[0]
password := tokenParts[1]

return ECRCredential{username: username, password: password, expiry: expiry, account: accountID, region: region}, nil
}

// GetECRCredentials retrieves the ECR credentials (username and password) from AWS ECR.
func (credHelper *ECRCredentialsHelper) GetCredentials(urls []string) (syncconf.CredentialsFile, error) {
ecrCredentials := make(syncconf.CredentialsFile)

for _, url := range urls {
remoteAddress := StripRegistryTransport(url)
ecrCred, err := getECRCredentials(remoteAddress)

if err != nil {
return syncconf.CredentialsFile{}, fmt.Errorf("%w %s: %w", ErrFailedToGetECRCredentials, url, err)
}
// Store the credentials in the map using the base URL as the key.
ecrCredentials[remoteAddress] = syncconf.Credentials{
Username: ecrCred.username,
Password: ecrCred.password,
}
credHelper.credentials[remoteAddress] = ecrCred
}

return ecrCredentials, nil
}

func (credHelper *ECRCredentialsHelper) IsCredentialsValid(remoteAddress string) bool {
expiry := credHelper.credentials[remoteAddress].expiry
expiryDuration := time.Duration(ExpiryWindow) * time.Hour

if time.Until(expiry) <= expiryDuration {
credHelper.log.Info().Str("url", remoteAddress).Msg("The credentials are close to expiring")

return false
}
credHelper.log.Info().Str("url", remoteAddress).Msg("The credentials are valid")

return true
}

func (credHelper *ECRCredentialsHelper) RefreshCredentials(remoteAddress string) (syncconf.Credentials, error) {
credHelper.log.Info().Str("url", remoteAddress).Msg("Refreshing the ECR credentials")
ecrCred, err := getECRCredentials(remoteAddress)

if err != nil {
return syncconf.Credentials{}, fmt.Errorf("%w %s: %w", ErrFailedToGetECRCredentials, remoteAddress, err)
}

return syncconf.Credentials{Username: ecrCred.username, Password: ecrCred.password}, nil
}
7 changes: 7 additions & 0 deletions pkg/extensions/sync/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func NewRemoteRegistry(client *client.Client, logger log.Logger) Remote {
return registry
}

func (registry *RemoteRegistry) SetUpstreamAuthConfig(username, password string) {
registry.context.DockerAuthConfig = &types.DockerAuthConfig{
Username: username,
Password: password,
}
}

func (registry *RemoteRegistry) GetContext() *types.SystemContext {
return registry.context
}
Expand Down
102 changes: 80 additions & 22 deletions pkg/extensions/sync/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,20 @@ import (
)

type BaseService struct {
config syncconf.RegistryConfig
credentials syncconf.CredentialsFile
clusterConfig *config.ClusterConfig
remote Remote
destination Destination
retryOptions *retry.RetryOptions
contentManager ContentManager
storeController storage.StoreController
metaDB mTypes.MetaDB
repositories []string
references references.References
client *client.Client
log log.Logger
config syncconf.RegistryConfig
credentials syncconf.CredentialsFile
credentialHelper CredentialHelper
clusterConfig *config.ClusterConfig
remote Remote
destination Destination
retryOptions *retry.RetryOptions
contentManager ContentManager
storeController storage.StoreController
metaDB mTypes.MetaDB
repositories []string
references references.References
client *client.Client
log log.Logger
}

func New(
Expand All @@ -60,16 +61,37 @@ func New(
var err error

var credentialsFile syncconf.CredentialsFile
if credentialsFilepath != "" {
credentialsFile, err = getFileCredentials(credentialsFilepath)
if err != nil {
log.Error().Str("errortype", common.TypeOf(err)).Str("path", credentialsFilepath).
Err(err).Msg("couldn't get registry credentials from configured path")
if service.config.CredentialHelper == "" {
// Only load credentials from file if CredentialHelper is not set
if credentialsFilepath != "" {
log.Info().Msgf("Using file-based credentials because CredentialHelper is not set")
credentialsFile, err = getFileCredentials(credentialsFilepath)
if err != nil {
log.Error().Str("errortype", common.TypeOf(err)).Str("path", credentialsFilepath).
Err(err).Msg("couldn't get registry credentials from configured path")
}
service.credentialHelper = nil
service.credentials = credentialsFile
}
} else {
log.Info().Msgf("Using credentials helper, because CredentialHelper is set to %s", service.config.CredentialHelper)

credentialHelper := service.config.CredentialHelper
switch credentialHelper {
case "ecr":
// Logic to fetch credentials for ECR
log.Info().Msg("Fetch the credentials using AWS ECR Auth Token.")
service.credentialHelper = NewECRCredentialHelper(log)
creds, err := service.credentialHelper.GetCredentials(service.config.URLs)
if err != nil {
log.Error().Err(err).Msg("Failed to retrieve credentials using ECR credentials helper.")
}
service.credentials = creds
default:
log.Warn().Msgf("Unsupported CredentialHelper: %s", credentialHelper)
}
}

service.credentials = credentialsFile

// load the cluster config into the object
// can be nil if the user did not configure cluster config
service.clusterConfig = clusterConfig
Expand Down Expand Up @@ -102,7 +124,6 @@ func New(

service.retryOptions = retryOptions
service.storeController = storeController

// try to set next client.
if err := service.SetNextAvailableClient(); err != nil {
// if it's a ping issue, it will be retried
Expand All @@ -126,9 +147,46 @@ func New(
return service, nil
}

// refreshRegistryTemporaryCredentials refreshes the temporary credentials for the registry if necessary.
// It checks whether a CredentialHelper is configured and if the current credentials have expired.
// If the credentials are expired, it attempts to refresh them and updates the service configuration.
func (service *BaseService) refreshRegistryTemporaryCredentials() error {
// If a CredentialHelper is configured, proceed to refresh the credentials if they are invalid or expired.
if service.config.CredentialHelper != "" {
// Strip the transport protocol (e.g., https:// or http://) from the remote address.
remoteAddress := StripRegistryTransport(service.client.GetHostname())

if !service.credentialHelper.IsCredentialsValid(remoteAddress) {
// Attempt to refresh the credentials using the CredentialHelper.
credentials, err := service.credentialHelper.RefreshCredentials(remoteAddress)
if err != nil {
service.log.Error().
Err(err).
Str("url", remoteAddress).
Msg("Failed to refresh the credentials")

return err
}

service.log.Info().
Str("url", remoteAddress).
Msg("Refreshing the upstream remote registry credentials")

// Update the service's credentials map with the new set of credentials.
service.credentials[remoteAddress] = credentials

// Set the upstream authentication context using the refreshed credentials.
service.remote.SetUpstreamAuthConfig(credentials.Username, credentials.Password)
}
}

// Return nil to indicate the operation completed successfully.
return nil
}

func (service *BaseService) SetNextAvailableClient() error {
if service.client != nil && service.client.Ping() {
return nil
return service.refreshRegistryTemporaryCredentials()
}

found := false
Expand Down
Loading

0 comments on commit d3d08a7

Please sign in to comment.