diff --git a/blob/s3blob/s3blob.go b/blob/s3blob/s3blob.go index c7a5bd3c41..b85e4dfa4e 100644 --- a/blob/s3blob/s3blob.go +++ b/blob/s3blob/s3blob.go @@ -144,24 +144,58 @@ type URLOpener struct { Options Options } +const ( + sseTypeParamKey = "ssetype" + kmsKeyIdParamKey = "kmskeyid" +) + +func toServerSideEncryptionType(value string) (typesv2.ServerSideEncryption, error) { + for _, sseType := range typesv2.ServerSideEncryptionAes256.Values() { + if strings.ToLower(string(sseType)) == strings.ToLower(value) { + return sseType, nil + } + } + return "", fmt.Errorf("'%s' is not a valid value for '%s'", value, sseTypeParamKey) +} + // OpenBucketURL opens a blob.Bucket based on u. func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) { + q := u.Query() + + if sseTypeParam := q.Get(sseTypeParamKey); sseTypeParam != "" { + q.Del(sseTypeParamKey) + + sseType, err := toServerSideEncryptionType(sseTypeParam) + if err != nil { + return nil, err + } + + o.Options.EncryptionType = sseType + } + + if kmsKeyID := q.Get(kmsKeyIdParamKey); kmsKeyID != "" { + q.Del(kmsKeyIdParamKey) + o.Options.KMSEncryptionID = kmsKeyID + } + if o.UseV2 { - cfg, err := gcaws.V2ConfigFromURLParams(ctx, u.Query()) + cfg, err := gcaws.V2ConfigFromURLParams(ctx, q) if err != nil { return nil, fmt.Errorf("open bucket %v: %v", u, err) } clientV2 := s3v2.NewFromConfig(cfg) + return OpenBucketV2(ctx, clientV2, u.Host, &o.Options) } configProvider := &gcaws.ConfigOverrider{ Base: o.ConfigProvider, } - overrideCfg, err := gcaws.ConfigFromURLParams(u.Query()) + overrideCfg, err := gcaws.ConfigFromURLParams(q) if err != nil { return nil, fmt.Errorf("open bucket %v: %v", u, err) } configProvider.Configs = append(configProvider.Configs, overrideCfg) + return OpenBucket(ctx, configProvider, u.Host, &o.Options) } @@ -171,6 +205,16 @@ type Options struct { // Some S3-compatible services (like CEPH) do not currently support // ListObjectsV2. UseLegacyList bool + + // EncryptionType sets the encryption type headers when making write or + // copy calls. This is required if the bucket has a restrictive bucket + // policy that enforces a specific encryption type + EncryptionType typesv2.ServerSideEncryption + + // KMSEncryptionID sets the kms key id header for write or copy calls. + // This is required when a bucket policy enforces the use of a specific + // KMS key for uploads + KMSEncryptionID string } // openBucket returns an S3 Bucket. @@ -193,11 +237,13 @@ func openBucket(ctx context.Context, useV2 bool, sess client.ConfigProvider, cli client = s3.New(sess) } return &bucket{ - useV2: useV2, - name: bucketName, - client: client, - clientV2: clientV2, - useLegacyList: opts.UseLegacyList, + useV2: useV2, + name: bucketName, + client: client, + clientV2: clientV2, + useLegacyList: opts.UseLegacyList, + kmsKeyId: opts.KMSEncryptionID, + encryptionType: opts.EncryptionType, }, nil } @@ -365,6 +411,9 @@ type bucket struct { client *s3.S3 clientV2 *s3v2.Client useLegacyList bool + + encryptionType typesv2.ServerSideEncryption + kmsKeyId string } func (b *bucket) Close() error { @@ -973,6 +1022,12 @@ func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType str if len(opts.ContentMD5) > 0 { reqV2.ContentMD5 = aws.String(base64.StdEncoding.EncodeToString(opts.ContentMD5)) } + if b.encryptionType != "" { + reqV2.ServerSideEncryption = b.encryptionType + } + if b.kmsKeyId != "" { + reqV2.SSEKMSKeyId = aws.String(b.kmsKeyId) + } if opts.BeforeWrite != nil { asFunc := func(i interface{}) bool { // Note that since the Go CDK Blob @@ -1046,6 +1101,12 @@ func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType str if len(opts.ContentMD5) > 0 { req.ContentMD5 = aws.String(base64.StdEncoding.EncodeToString(opts.ContentMD5)) } + if b.encryptionType != "" { + req.ServerSideEncryption = aws.String(string(b.encryptionType)) + } + if b.kmsKeyId != "" { + req.SSEKMSKeyId = aws.String(b.kmsKeyId) + } if opts.BeforeWrite != nil { asFunc := func(i interface{}) bool { pu, ok := i.(**s3manager.Uploader) @@ -1083,6 +1144,12 @@ func (b *bucket) Copy(ctx context.Context, dstKey, srcKey string, opts *driver.C CopySource: aws.String(b.name + "/" + srcKey), Key: aws.String(dstKey), } + if b.encryptionType != "" { + input.ServerSideEncryption = b.encryptionType + } + if b.kmsKeyId != "" { + input.SSEKMSKeyId = aws.String(b.kmsKeyId) + } if opts.BeforeCopy != nil { asFunc := func(i interface{}) bool { switch v := i.(type) { @@ -1104,6 +1171,12 @@ func (b *bucket) Copy(ctx context.Context, dstKey, srcKey string, opts *driver.C CopySource: aws.String(b.name + "/" + srcKey), Key: aws.String(dstKey), } + if b.encryptionType != "" { + input.ServerSideEncryption = aws.String(string(b.encryptionType)) + } + if b.kmsKeyId != "" { + input.SSEKMSKeyId = aws.String(b.kmsKeyId) + } if opts.BeforeCopy != nil { asFunc := func(i interface{}) bool { switch v := i.(type) { diff --git a/blob/s3blob/s3blob_test.go b/blob/s3blob/s3blob_test.go index 865d920139..b8e5ccfac6 100644 --- a/blob/s3blob/s3blob_test.go +++ b/blob/s3blob/s3blob_test.go @@ -466,6 +466,10 @@ func TestOpenBucketFromURL(t *testing.T) { {"s3://mybucket?profile=main®ion=us-west-1", false}, // OK, use V2. {"s3://mybucket?awssdk=v2", false}, + // OK, use KMS Server Side Encryption + {"s3://mybucket?ssetype=aws:kms&kmskeyid=arn:aws:us-east-1:12345:key/1-a-2-b", false}, + // Invalid ssetype + {"s3://mybucket?ssetype=aws:notkmsoraes&kmskeyid=arn:aws:us-east-1:12345:key/1-a-2-b", true}, // Invalid parameter together with a valid one. {"s3://mybucket?profile=main¶m=value", true}, // Invalid parameter. @@ -483,3 +487,32 @@ func TestOpenBucketFromURL(t *testing.T) { } } } + +func TestToServerSideEncryptionType(t *testing.T) { + tests := []struct { + value string + sseType typesv2.ServerSideEncryption + expectedError error + }{ + // OK. + {"AES256", typesv2.ServerSideEncryptionAes256, nil}, + // OK, KMS + {"aws:kms", typesv2.ServerSideEncryptionAwsKms, nil}, + // OK, KMS + {"aws:kms:dsse", typesv2.ServerSideEncryptionAwsKmsDsse, nil}, + // OK, AES256 mixed case + {"Aes256", typesv2.ServerSideEncryptionAes256, nil}, + // Invalid SSE type + {"invalid", "", fmt.Errorf("'invalid' is not a valid value for '%s'", sseTypeParamKey)}, + } + + for _, test := range tests { + sseType, err := toServerSideEncryptionType(test.value) + if ((err != nil) != (test.expectedError != nil)) && err.Error() != test.expectedError.Error() { + t.Errorf("%s: got error \"%v\", want error \"%v\"", test.value, err, test.expectedError) + } + if sseType != test.sseType { + t.Errorf("%s: got type %v, want type %v", test.value, sseType, test.sseType) + } + } +}