Skip to content
This repository has been archived by the owner on Nov 1, 2022. It is now read-only.

Commit

Permalink
Add ECR auth helpers
Browse files Browse the repository at this point in the history
This takes the essential workings of
#1455 and creates a helper for
using AWS authorization with ECR (the AWS container registry).
  • Loading branch information
squaremo committed Dec 27, 2018
1 parent 39901c6 commit cd451fc
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 42 deletions.
97 changes: 67 additions & 30 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 95 additions & 0 deletions registry/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package registry

// References:
// - https://github.com/bzon/ecr-k8s-secret-creator
// - https://github.com/kubernetes/kubernetes/blob/master/pkg/credentialprovider/aws/aws_credentials.go
// - https://github.com/weaveworks/flux/pull/1455

import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/go-kit/kit/log"
)

const (
// For recognising ECR hosts
ecrHostSuffix = ".amazonaws.com"
// How long AWS tokens remain valid
tokenValid = 12 * time.Hour
)

type AWSRegistryConfig struct {
Region string
RegistryIDs []string
}

func ImageCredsWithAWS(lookup func() ImageCreds, logger log.Logger, config AWSRegistryConfig) (func() ImageCreds, error) {
awsCreds := NoCredentials()
var credsExpire time.Time

refresh := func(now time.Time) error {
var err error
awsCreds, err = fetchAWSCreds(config)
if err != nil {
// bump this along so we don't spam the log
credsExpire = now.Add(time.Hour)
return err
}
credsExpire = now.Add(tokenValid)
return nil
}

// pre-flight check
if err := refresh(time.Now()); err != nil {
return nil, err
}

return func() ImageCreds {
imageCreds := lookup()

now := time.Now()
if now.After(credsExpire) {
if err := refresh(now); err != nil {
logger.Log("warning", "AWS token not refreshed", "err", err)
}
}

for name, creds := range imageCreds {
if strings.HasSuffix(name.Domain, ecrHostSuffix) {
newCreds := NoCredentials()
newCreds.Merge(awsCreds)
newCreds.Merge(creds)
imageCreds[name] = newCreds
}
}
return imageCreds
}, nil
}

func fetchAWSCreds(config AWSRegistryConfig) (Credentials, error) {
sess := session.Must(session.NewSession(&aws.Config{Region: &config.Region}))
svc := ecr.New(sess)
ecrToken, err := svc.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{
RegistryIds: aws.StringSlice(config.RegistryIDs),
})
if err != nil {
return Credentials{}, err
}
auths := make(map[string]creds)
for _, v := range ecrToken.AuthorizationData {
// Remove the https prefix
host := strings.TrimPrefix(*v.ProxyEndpoint, "https://")
creds, err := parseAuth(*v.AuthorizationToken)
if err != nil {
return Credentials{}, err
}
creds.provenance = "AWS API"
creds.registry = host
auths[host] = creds
}
return Credentials{m: auths}, nil
}
32 changes: 20 additions & 12 deletions registry/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ func NoCredentials() Credentials {
}
}

func parseAuth(auth string) (creds, error) {
decodedAuth, err := base64.StdEncoding.DecodeString(auth)
if err != nil {
return creds{}, err
}
authParts := strings.SplitN(string(decodedAuth), ":", 2)
if len(authParts) != 2 {
return creds{},
fmt.Errorf("decoded credential has wrong number of fields (expected 2, got %d)", len(authParts))
}
return creds{
username: authParts[0],
password: authParts[1],
}, nil
}

func ParseCredentials(from string, b []byte) (Credentials, error) {
var config struct {
Auths map[string]struct {
Expand All @@ -53,15 +69,10 @@ func ParseCredentials(from string, b []byte) (Credentials, error) {
}
m := map[string]creds{}
for host, entry := range config.Auths {
decodedAuth, err := base64.StdEncoding.DecodeString(entry.Auth)
creds, err := parseAuth(entry.Auth)
if err != nil {
return Credentials{}, err
}
authParts := strings.SplitN(string(decodedAuth), ":", 2)
if len(authParts) != 2 {
return Credentials{},
fmt.Errorf("decoded credential for %v has wrong number of fields (expected 2, got %d)", host, len(authParts))
}

// Some users were passing in credentials in the form of
// http://docker.io and http://docker.io/v1/, etc.
Expand All @@ -87,12 +98,9 @@ func ParseCredentials(from string, b []byte) (Credentials, error) {
}
host = u.Host

m[host] = creds{
registry: host,
provenance: from,
username: authParts[0],
password: authParts[1],
}
creds.registry = host
creds.provenance = from
m[host] = creds
}
return Credentials{m: m}, nil
}
Expand Down

0 comments on commit cd451fc

Please sign in to comment.