Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [#521] User can custom recover when a request panic #132

Merged
merged 19 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions middleware_timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"time"

"github.com/gofiber/fiber/v2"
contractshttp "github.com/goravel/framework/contracts/http"
)

Expand All @@ -23,22 +24,25 @@ func Timeout(timeout time.Duration) contractshttp.Middleware {

go func() {
defer func() {
if r := recover(); r != nil {
LogFacade.Request(ctx.Request()).Error(r)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The log is required, please keep it. I saw it was removed in gin as well, please add it back.

// TODO can be customized in https://github.com/goravel/goravel/issues/521
_ = ctx.Response().Status(http.StatusInternalServerError).String("Internal Server Error").Render()
if err := recover(); err != nil {
if globalRecoverCallback != nil {
globalRecoverCallback(ctx, err)
} else {
LogFacade.Error(err)
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"})
}
}

close(done)
}()

ctx.Request().Next()
}()

select {
case <-done:
case <-ctx.Context().Done():
case <-timeoutCtx.Done():
if errors.Is(ctx.Context().Err(), context.DeadlineExceeded) {
ctx.Request().AbortWithStatus(http.StatusGatewayTimeout)
ctx.Request().AbortWithStatusJson(http.StatusGatewayTimeout, fiber.Map{"error": "Request Timeout"})
}
}
}
Expand Down
85 changes: 51 additions & 34 deletions middleware_timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package fiber

import (
"fmt"
"io"
"net/http"
"testing"
"time"

"github.com/gofiber/fiber/v2"
contractshttp "github.com/goravel/framework/contracts/http"
mocksconfig "github.com/goravel/framework/mocks/config"
mockslog "github.com/goravel/framework/mocks/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand All @@ -26,48 +25,66 @@ func TestTimeoutMiddleware(t *testing.T) {

route.Middleware(Timeout(1*time.Second)).Get("/timeout", func(ctx contractshttp.Context) contractshttp.Response {
time.Sleep(2 * time.Second)

return ctx.Response().Success().String("timeout")
return nil
})

route.Middleware(Timeout(1*time.Second)).Get("/normal", func(ctx contractshttp.Context) contractshttp.Response {
return ctx.Response().Success().String("normal")
})
route.Middleware(Timeout(1*time.Second)).Get("/panic", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)

route.Middleware(Timeout(5*time.Second)).Get("/panic", func(ctx contractshttp.Context) contractshttp.Response {
panic("test panic")
})

req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}
route.Recover(globalRecover)

mockLog := mockslog.NewLog(t)

resp, err := route.instance.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode)
t.Run("timeout", func(t *testing.T) {
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)

req, err = http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)
resp, err := route.instance.Test(req, -1)
require.NoError(t, err)
require.NotNil(t, resp)

resp, err = route.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "normal", string(body))
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.JSONEq(t, `{"error":"Request Timeout"}`, string(body))
})

req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)
t.Run("normal", func(t *testing.T) {
req, err := http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)

mockLog := mockslog.NewLog(t)
mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Error(mock.Anything).Once()
LogFacade = mockLog

resp, err = route.instance.Test(req)
fmt.Printf("resp: %+v\n", resp)
assert.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

body, err = io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Internal Server Error", string(body))
resp, err := route.instance.Test(req, -1)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "normal", string(body))
})

t.Run("panic", func(t *testing.T) {
req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

resp, err := route.instance.Test(req, -1)
require.NoError(t, err)

assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"error":"Internal Panic"}`, string(body))
})

mockLog.AssertExpectations(t)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be removed since we are using mockslog.NewLog(t).

Suggested change
mockLog.AssertExpectations(t)

}
24 changes: 22 additions & 2 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
fiberrecover "github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/template/html/v2"
"github.com/goravel/framework/contracts/config"
httpcontract "github.com/goravel/framework/contracts/http"
Expand All @@ -26,6 +26,8 @@ import (
"github.com/savioxavier/termlink"
)

var globalRecoverCallback func(ctx httpcontract.Context, err any)

// Route fiber route
// Route fiber 路由
type Route struct {
Expand Down Expand Up @@ -109,7 +111,7 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
debug := r.config.GetBool("app.debug", false)
timeout := time.Duration(r.config.GetInt("http.request_timeout", 3)) * time.Second
fiberHandlers := []fiber.Handler{
recover.New(recover.Config{
fiberrecover.New(fiberrecover.Config{
EnableStackTrace: debug,
}),
}
Expand All @@ -130,6 +132,24 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
r.setMiddlewares(fiberHandlers)
}

func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) {
globalRecoverCallback = callback
middleware := middlewaresToFiberHandlers([]httpcontract.Middleware{
func(ctx httpcontract.Context) {
defer func() {
if err := recover(); err != nil {
if callback != nil {
callback(ctx, err)
} else {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback should not be nil here, so we can:

Suggested change
} else {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"})

}
}
}()
ctx.Request().Next()
},
})
r.setMiddlewares(middleware)
}
// Listen listen server
// Listen 监听服务器
func (r *Route) Listen(l net.Listener) error {
Expand Down
59 changes: 59 additions & 0 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
Expand All @@ -21,6 +22,64 @@ import (
"github.com/stretchr/testify/assert"
)

func TestRecoverWithDefaultCallback(t *testing.T) {
globalRecoverCallback = nil
mockConfig := configmocks.NewConfig(t)
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("GetInt", "http.drivers.fiber.body_limit", 4096).Return(4096).Once()
mockConfig.On("GetInt", "http.drivers.fiber.header_limit", 4096).Return(4096).Once()

route, err := NewRoute(mockConfig, nil)
assert.Nil(t, err)
route.Recover(globalRecoverCallback)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is unnecessary in this case, we don't need to pass nil into the Recover method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in goravel/gin we used gin.Recovery() and the default recovery test covered this case. Is it necessary to test the default recoverer in goravel/fiber?

Copy link
Contributor Author

@KlassnayaAfrodita KlassnayaAfrodita Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can initially set globalRecoverCallback
globalRecoverCallback := func(ctx contractshttp.Context, err any) { ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"}) }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is default recovery as well:

fiberHandlers := []fiber.Handler{
  fiberrecover.New(fiberrecover.Config{
	  EnableStackTrace: debug,
  }),
}

Setting a default globalRecoverCallbackis a good idea.

route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)
})

req := httptest.NewRequest("GET", "/recover", nil)
resp, err := route.Test(req)
assert.Nil(t, err)

body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, "{\"error\":\"Internal Server Error\"}", string(body))
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

mockConfig.AssertExpectations(t)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mockConfig.AssertExpectations(t)

}

func TestRecoverWithCustomCallback(t *testing.T) {
mockConfig := configmocks.NewConfig(t)

mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("GetInt", "http.drivers.fiber.body_limit", 4096).Return(4096).Once()
mockConfig.On("GetInt", "http.drivers.fiber.header_limit", 4096).Return(4096).Once()

route, err := NewRoute(mockConfig, nil)
assert.Nil(t, err)

globalRecoverCallback := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}

route.Recover(globalRecoverCallback)

route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)
})

req := httptest.NewRequest("GET", "/recover", nil)
resp, err := route.Test(req)
assert.Nil(t, err)

body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, "{\"error\":\"Internal Panic\"}", string(body))
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

mockConfig.AssertExpectations(t)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mockConfig.AssertExpectations(t)

}

func TestFallback(t *testing.T) {
mockConfig := configmocks.NewConfig(t)
mockConfig.EXPECT().GetBool("http.drivers.fiber.prefork", false).Return(false).Once()
Expand Down
Loading