diff --git a/lambda/entry_generic_test.go b/lambda/entry_generic_test.go index a1bc5acc..fff471ff 100644 --- a/lambda/entry_generic_test.go +++ b/lambda/entry_generic_test.go @@ -27,7 +27,7 @@ func TestStartHandlerFunc(t *testing.T) { handlerType := reflect.TypeOf(f) - handlerTakesContext, err := validateArguments(handlerType) + handlerTakesContext, err := handlerTakesContext(handlerType) assert.NoError(t, err) assert.True(t, handlerTakesContext) diff --git a/lambda/handler.go b/lambda/handler.go index 0fc82d6e..73d38e20 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -99,20 +99,36 @@ func WithEnableSIGTERM(callbacks ...func()) Option { }) } -func validateArguments(handler reflect.Type) (bool, error) { - handlerTakesContext := false - if handler.NumIn() > 2 { - return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn()) - } else if handler.NumIn() > 0 { +// handlerTakesContext returns whether the handler takes a context.Context as its first argument. +func handlerTakesContext(handler reflect.Type) (bool, error) { + switch handler.NumIn() { + case 0: + return false, nil + case 1: contextType := reflect.TypeOf((*context.Context)(nil)).Elem() argumentType := handler.In(0) - handlerTakesContext = argumentType.Implements(contextType) - if handler.NumIn() > 1 && !handlerTakesContext { + if argumentType.Kind() != reflect.Interface { + return false, nil + } + + // handlers like func(event any) are valid. + if argumentType.NumMethod() == 0 { + return false, nil + } + + if !contextType.Implements(argumentType) || !argumentType.Implements(contextType) { + return false, fmt.Errorf("handler takes an interface, but it is not context.Context: %q", argumentType.Name()) + } + return true, nil + case 2: + contextType := reflect.TypeOf((*context.Context)(nil)).Elem() + argumentType := handler.In(0) + if argumentType.Kind() != reflect.Interface || !contextType.Implements(argumentType) || !argumentType.Implements(contextType) { return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind()) } + return true, nil } - - return handlerTakesContext, nil + return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn()) } func validateReturns(handler reflect.Type) error { @@ -198,7 +214,7 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func)) } - takesContext, err := validateArguments(handlerType) + takesContext, err := handlerTakesContext(handlerType) if err != nil { return errorHandler(err) } diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 3c3c51d4..610dbdf1 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "testing" + "time" "github.com/aws/aws-lambda-go/lambda/handlertrace" "github.com/aws/aws-lambda-go/lambda/messages" @@ -14,6 +15,21 @@ import ( ) func TestInvalidHandlers(t *testing.T) { + type valuer interface { + Value(key interface{}) interface{} + } + + type customContext interface { + context.Context + MyCustomMethod() + } + + type myContext interface { + Deadline() (deadline time.Time, ok bool) + Done() <-chan struct{} + Err() error + Value(key interface{}) interface{} + } testCases := []struct { name string @@ -72,12 +88,58 @@ func TestInvalidHandlers(t *testing.T) { handler: func() { }, }, + { + name: "the handler takes the empty interface", + expected: nil, + handler: func(v interface{}) error { + if _, ok := v.(context.Context); ok { + return errors.New("v should not be a Context") + } + return nil + }, + }, + { + name: "the handler takes a subset of context.Context", + expected: errors.New("handler takes an interface, but it is not context.Context: \"valuer\""), + handler: func(ctx valuer) { + }, + }, + { + name: "the handler takes a same interface with context.Context", + expected: nil, + handler: func(ctx myContext) { + }, + }, + { + name: "the handler takes a superset of context.Context", + expected: errors.New("handler takes an interface, but it is not context.Context: \"customContext\""), + handler: func(ctx customContext) { + }, + }, + { + name: "the handler takes two arguments and first argument is a subset of context.Context", + expected: errors.New("handler takes two arguments, but the first is not Context. got interface"), + handler: func(ctx valuer, v interface{}) { + }, + }, + { + name: "the handler takes two arguments and first argument is a same interface with context.Context", + expected: nil, + handler: func(ctx myContext, v interface{}) { + }, + }, + { + name: "the handler takes two arguments and first argument is a superset of context.Context", + expected: errors.New("handler takes two arguments, but the first is not Context. got interface"), + handler: func(ctx customContext, v interface{}) { + }, + }, } for i, testCase := range testCases { testCase := testCase t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { lambdaHandler := NewHandler(testCase.handler) - _, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0)) + _, err := lambdaHandler.Invoke(context.TODO(), []byte("{}")) assert.Equal(t, testCase.expected, err) }) }