From bc1e02e53e87ff7a7b199d40d05c0dd709328819 Mon Sep 17 00:00:00 2001 From: logandavies181 Date: Tue, 29 Nov 2022 09:03:09 +0000 Subject: [PATCH] add ValidateHandlerFunc --- lambda/handler.go | 25 ++++++++++++++++++ lambda/handler_test.go | 59 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/lambda/handler.go b/lambda/handler.go index 0fc82d6e..02f10cf3 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -99,6 +99,31 @@ func WithEnableSIGTERM(callbacks ...func()) Option { }) } +// ValidateHandlerFunc validates the handler against the criteria for Start and returns an error +// if the criteria are not met +func ValidateHandlerFunc(handlerFunc interface{}) error { + if handlerFunc == nil { + return errors.New("handler is nil") + } + + handlerType := reflect.TypeOf(handlerFunc) + if handlerType.Kind() != reflect.Func { + return fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func) + } + + _, err := validateArguments(handlerType) + if err != nil { + return err + } + + err = validateReturns(handlerType) + if err != nil { + return err + } + + return nil +} + func validateArguments(handler reflect.Type) (bool, error) { handlerTakesContext := false if handler.NumIn() > 2 { diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 3c3c51d4..c8dc80ec 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -75,11 +75,68 @@ func TestInvalidHandlers(t *testing.T) { } for i, testCase := range testCases { testCase := testCase - t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { + t.Run(fmt.Sprintf("testCase[%d] %s part 1", i, testCase.name), func(t *testing.T) { lambdaHandler := NewHandler(testCase.handler) _, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0)) assert.Equal(t, testCase.expected, err) }) + + t.Run(fmt.Sprintf("testCase[%d] %s part 2", i, testCase.name), func(t *testing.T) { + err := ValidateHandlerFunc(testCase.handler) + assert.Equal(t, testCase.expected, err) + }) + } +} + +func TestValidateHandlerFuncValidHandlers(t *testing.T) { + testCases := []struct { + name string + handler interface{} + }{ + { + name: "0 arg 0 return", + handler: func() {}, + }, + { + name: "0 arg, 1 returns", + handler: func() error { return nil }, + }, + { + name: "1 arg, 0 returns", + handler: func(any) {}, + }, + { + name: "1 arg, 1 returns", + handler: func(any) error { return nil }, + }, + { + name: "0 arg, 2 returns", + handler: func() (any, error) { return 1, nil }, + }, + { + name: "1 arg, 2 returns", + handler: func(any) (any, error) { return 1, nil }, + }, + { + name: "2 arg, 0 returns", + handler: func(context.Context, any) {}, + }, + { + name: "2 arg, 1 returns", + handler: func(context.Context, any) error { return nil }, + }, + { + name: "2 arg, 2 returns", + handler: func(context.Context, any) (any, error) { return 1, nil }, + }, + } + + for i, testCase := range testCases { + testCase := testCase + t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { + err := ValidateHandlerFunc(testCase.handler) + assert.Nil(t, err) + }) } }