Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for type-safe Start* function #468

Merged
merged 10 commits into from
Dec 4, 2022
8 changes: 4 additions & 4 deletions lambda/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ import (
// - handler must be a function
// - handler may take between 0 and two arguments.
// - if there are two arguments, the first argument must satisfy the "context.Context" interface.
// - handler may return between 0 and two arguments.
// - if there are two return values, the second argument must be an error.
// - handler may return between 0 and two values.
// - if there are two return values, the second return value must be an error.
// - if there is one return value it must be an error.
//
// Valid function signatures:
//
// func ()
// func (TIn)
// func () error
// func (TIn) error
// func () (TOut, error)
// func (TIn) (TOut, error)
// func (context.Context) error
// func (context.Context, TIn)
// func (context.Context, TIn) error
// func (context.Context) (TOut, error)
// func (context.Context, TIn) (TOut, error)
//
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
Expand Down
52 changes: 52 additions & 0 deletions lambda/entry_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build go1.18
// +build go1.18
logandavies181 marked this conversation as resolved.
Show resolved Hide resolved

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved

package lambda

import (
"context"
)

// HandlerFunc represents a valid input as described by Start
type HandlerFunc[TIn, TOut any] interface {
func() |
func(TIn) |
func() error |
func(TIn) error |
func() (TOut, error) |
func(TIn) (TOut, error) |
func(context.Context, TIn) |
func(context.Context, TIn) error |
func(context.Context, TIn) (TOut, error)
}

// StartWithOptionsTypeSafe is the same as StartWithOptions except that it takes a generic input
// so that the function signature can be validated at compile time.
// The caller can supply "any" for TIn or TOut if the input function does not use that argument or return value.
//
// Examples:
//
// TIn and TOut ignored
//
// StartWithOptionsTypeSafe[any, any](func() {
// fmt.Println("Hello world")
// })
//
// TIn used and TOut ignored
//
// type event events.APIGatewayV2HTTPRequest
// StartWithOptionsTypeSafe[event, any](func(e event) {
// fmt.Printf("Version: %s", e.Version)
// })
//
// TIn ignored, TOut used and an error returned
//
// type resp events.APIGatewayV2HTTPResponse
// StartWithOptionsTypeSafe[any, resp](func() (resp, error) {
// return resp{Body: "hello, world"}, nil
// })
func StartWithOptionsTypeSafe[TIn any, TOut any, H HandlerFunc[TIn, TOut]](handler H, options ...Option) {
start(newHandler(handler, options...))
}
112 changes: 112 additions & 0 deletions lambda/entry_generic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//go:build go1.18
// +build go1.18

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved

package lambda

import (
"context"
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStartWithOptionsTypeSafe(t *testing.T) {
testCases := []struct {
name string
handler any
takesContext bool
}{
{
name: "0 arg, 0 returns",
handler: func() {},
takesContext: false,
},
{
name: "0 arg, 1 returns",
handler: func() error { return nil },
takesContext: false,
},
{
name: "1 arg, 0 returns",
handler: func(any) {},
takesContext: false,
},
{
name: "1 arg, 1 returns",
handler: func(any) error { return nil },
takesContext: false,
},
{
name: "0 arg, 2 returns",
handler: func() (any, error) { return 1, nil },
takesContext: false,
},
{
name: "1 arg, 2 returns",
handler: func(any) (any, error) { return 1, nil },
takesContext: false,
},
{
name: "2 arg, 0 returns",
handler: func(context.Context, any) {},
takesContext: true,
},
{
name: "2 arg, 1 returns",
handler: func(context.Context, any) error { return nil },
takesContext: true,
},
{
name: "2 arg, 2 returns",
handler: func(context.Context, any) (any, error) { return 1, nil },
takesContext: true,
},
}

for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}
switch h := testCase.handler.(type) {
case func():
StartWithOptionsTypeSafe[any, any](h)
case func() error:
StartWithOptionsTypeSafe[any, any](h)
case func(any):
StartWithOptionsTypeSafe[any, any](h)
case func(any) error:
StartWithOptionsTypeSafe[any, any](h)
case func() (any, error):
StartWithOptionsTypeSafe[any, any](h)
case func(any) (any, error):
StartWithOptionsTypeSafe[any, any](h)
case func(context.Context, any):
StartWithOptionsTypeSafe[any, any](h)
case func(context.Context, any) error:
StartWithOptionsTypeSafe[any, any](h)
case func(context.Context, any) (any, error):
StartWithOptionsTypeSafe[any, any](h)
default:
assert.Fail(t, "Unexpected type: %T for test case: %s", h, testCase.name)
}

assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)

handlerType := reflect.TypeOf(testCase.handler)

handlerTakesContext, err := validateArguments(handlerType)
assert.NoError(t, err)
assert.Equal(t, testCase.takesContext, handlerTakesContext)

err = validateReturns(handlerType)
assert.NoError(t, err)
})
}
}