Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update upstream #15

Merged
merged 6 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions cfn/cfn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
"os"
"path"
"sync"
"time"

"github.com/aws-cloudformation/cloudformation-cli-go-plugin/cfn/cfnerr"
Expand Down Expand Up @@ -39,6 +40,8 @@ const (
listAction = "LIST"
)

var once sync.Once

// Handler is the interface that all resource providers must implement
//
// Each method of Handler maps directly to a CloudFormation action.
Expand Down Expand Up @@ -97,15 +100,21 @@ type handlerFunc func(request handler.Request) handler.ProgressEvent
// MakeEventFunc is the entry point to all invocations of a custom resource
func makeEventFunc(h Handler) eventFunc {
return func(ctx context.Context, event *event) (response, error) {
//pls := credentials.SessionFromCredentialsProvider(&event.RequestData.PlatformCredentials)
ps := credentials.SessionFromCredentialsProvider(&event.RequestData.ProviderCredentials)
l, err := logging.NewCloudWatchLogsProvider(
cloudwatchlogs.New(ps),
event.RequestData.ProviderLogGroupName,
)
// Set default logger to output to CWL in the provider account
logging.SetProviderLogOutput(l)
m := metrics.New(cloudwatch.New(ps), event.ResourceType)
once.Do(func() {
l, err := logging.NewCloudWatchLogsProvider(
cloudwatchlogs.New(ps),
event.RequestData.ProviderLogGroupName,
)
if err != nil {
log.Printf("Error: %v, Logging to Stdout", err)
m.PublishExceptionMetric(time.Now(), event.Action, err)
l = os.Stdout
}
// Set default logger to output to CWL in the provider account
logging.SetProviderLogOutput(l)
})
re := newReportErr(m)
if err := scrubFiles("/tmp"); err != nil {
log.Printf("Error: %v", err)
Expand All @@ -119,10 +128,18 @@ func makeEventFunc(h Handler) eventFunc {
if err := validateEvent(event); err != nil {
return re.report(event, "validation error", err, invalidRequestError)
}

rctx := handler.RequestContext{
StackID: event.StackID,
Region: event.Region,
AccountID: event.AWSAccountID,
StackTags: event.RequestData.StackTags,
SystemTags: event.RequestData.SystemTags,
NextToken: event.NextToken,
}
request := handler.NewRequest(
event.RequestData.LogicalResourceID,
event.CallbackContext,
rctx,
credentials.SessionFromCredentialsProvider(&event.RequestData.CallerCredentials),
event.RequestData.PreviousResourceProperties,
event.RequestData.ResourceProperties,
Expand Down Expand Up @@ -180,9 +197,17 @@ func makeTestEventFunc(h Handler) testEventFunc {
if err != nil {
return handler.NewFailedEvent(err), err
}
rctx := handler.RequestContext{
Region: event.Request.Region,
AccountID: event.Request.AWSAccountID,
StackTags: event.Request.DesiredResourceTags,
SystemTags: event.Request.SystemTags,
NextToken: event.Request.NextToken,
}
request := handler.NewRequest(
event.Request.LogicalResourceIdentifier,
event.CallbackContext,
rctx,
credentials.SessionFromCredentialsProvider(&event.Credentials),
event.Request.PreviousResourceState,
event.Request.DesiredResourceState,
Expand Down
9 changes: 5 additions & 4 deletions cfn/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ type Props struct {
}

func TestNewRequest(t *testing.T) {
rctx := RequestContext{}
t.Run("Happy Path", func(t *testing.T) {
prev := Props{}
curr := Props{}

req := NewRequest("foo", nil, nil, []byte(`{"color": "red"}`), []byte(`{"color": "green"}`))
req := NewRequest("foo", nil, rctx, nil, []byte(`{"color": "red"}`), []byte(`{"color": "green"}`))

if err := req.UnmarshalPrevious(&prev); err != nil {
t.Fatalf("Unable to unmarshal props: %v", err)
Expand All @@ -42,7 +43,7 @@ func TestNewRequest(t *testing.T) {

t.Run("ResourceProps", func(t *testing.T) {
t.Run("Invalid Body", func(t *testing.T) {
req := NewRequest("foo", nil, nil, []byte(``), []byte(``))
req := NewRequest("foo", nil, rctx, nil, []byte(``), []byte(``))

invalid := struct {
Color *int `json:"color"`
Expand All @@ -60,7 +61,7 @@ func TestNewRequest(t *testing.T) {
})

t.Run("Invalid Marshal", func(t *testing.T) {
req := NewRequest("foo", nil, nil, []byte(`{"color": "ref"}`), []byte(`---BAD JSON---`))
req := NewRequest("foo", nil, rctx, nil, []byte(`{"color": "ref"}`), []byte(`---BAD JSON---`))

var invalid Props

Expand All @@ -78,7 +79,7 @@ func TestNewRequest(t *testing.T) {

t.Run("PreviousResourceProps", func(t *testing.T) {
t.Run("Invalid Marshal", func(t *testing.T) {
req := NewRequest("foo", nil, nil, []byte(`---BAD JSON---`), []byte(`{"color": "green"}`))
req := NewRequest("foo", nil, rctx, nil, []byte(`---BAD JSON---`), []byte(`{"color": "green"}`))

var invalid Props

Expand Down
29 changes: 28 additions & 1 deletion cfn/handler/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,48 @@ type Request struct {
// identifier which can be used to continue polling for stabilization
CallbackContext map[string]interface{}

// The RequestContext is information about the current
// invocation.
RequestContext RequestContext

// An authenticated AWS session that can be used with the AWS Go SDK
Session *session.Session

previousResourcePropertiesBody []byte
resourcePropertiesBody []byte
}

// RequestContext represents information about the current
// invocation request of the handler.
type RequestContext struct {
// The stack ID of the CloudFormation stack
StackID string

// The Region of the requester
Region string

// The Account ID of the requester
AccountID string

// The stack tags associated with the cloudformation stack
StackTags map[string]string

// The SystemTags associated with the request
SystemTags map[string]string

// The NextToken provided in the request
NextToken string
}

// NewRequest returns a new Request based on the provided parameters
func NewRequest(id string, ctx map[string]interface{}, sess *session.Session, previousBody, body []byte) Request {
func NewRequest(id string, ctx map[string]interface{}, requestCTX RequestContext, sess *session.Session, previousBody, body []byte) Request {
return Request{
LogicalResourceID: id,
CallbackContext: ctx,
Session: sess,
previousResourcePropertiesBody: previousBody,
resourcePropertiesBody: body,
RequestContext: requestCTX,
}
}

Expand Down
16 changes: 3 additions & 13 deletions cfn/logging/cloudwatchlogs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package logging

import (
"context"
"io"
"log"
"os"
Expand Down Expand Up @@ -130,10 +129,7 @@ func (p *cloudWatchLogsProvider) Write(b []byte) (int, error) {
// // do something
// }
func CloudWatchLogGroupExists(client cloudwatchlogsiface.CloudWatchLogsAPI, logGroupName string) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
defer cancel()

resp, err := client.DescribeLogGroupsWithContext(ctx, &cloudwatchlogs.DescribeLogGroupsInput{
resp, err := client.DescribeLogGroups(&cloudwatchlogs.DescribeLogGroupsInput{
Limit: aws.Int64(1),
LogGroupNamePrefix: aws.String(logGroupName),
})
Expand Down Expand Up @@ -161,10 +157,7 @@ func CloudWatchLogGroupExists(client cloudwatchlogsiface.CloudWatchLogsAPI, logG
// panic("Unable to create log group", err)
// }
func CreateNewCloudWatchLogGroup(client cloudwatchlogsiface.CloudWatchLogsAPI, logGroupName string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
defer cancel()

if _, err := client.CreateLogGroupWithContext(ctx, &cloudwatchlogs.CreateLogGroupInput{
if _, err := client.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
LogGroupName: aws.String(logGroupName),
}); err != nil {
return err
Expand All @@ -175,10 +168,7 @@ func CreateNewCloudWatchLogGroup(client cloudwatchlogsiface.CloudWatchLogsAPI, l

// CreateNewLogStream creates a log stream inside of a LogGroup
func CreateNewLogStream(client cloudwatchlogsiface.CloudWatchLogsAPI, logGroupName string, logStreamName string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
defer cancel()

_, err := client.CreateLogStreamWithContext(ctx, &cloudwatchlogs.CreateLogStreamInput{
_, err := client.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
LogGroupName: aws.String(logGroupName),
LogStreamName: aws.String(logStreamName),
})
Expand Down
Loading