Skip to content

Commit

Permalink
feat: [#521] User can custom recover when a request panic (#132)
Browse files Browse the repository at this point in the history
* Update route.go

* Update route.go

* Update route.go

* Update middleware_timeout.go

* Update middleware_timeout_test.go

* Update middleware_timeout.go

* Update route.go

* Update middleware_timeout_test.go

* Update middleware_timeout_test.go

* Update middleware_timeout_test.go

* Update route.go

* Update route_test.go

* Update route.go

* Update route_test.go

* Update route_test.go

* Update route_test.go

* Update middleware_timeout_test.go

* Update route_test.go

* Update route.go
  • Loading branch information
KlassnayaAfrodita authored Dec 29, 2024
1 parent 7ce6cfc commit 7328b59
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 42 deletions.
18 changes: 11 additions & 7 deletions middleware_timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"time"

"github.com/gofiber/fiber/v2"
contractshttp "github.com/goravel/framework/contracts/http"
)

Expand All @@ -23,22 +24,25 @@ func Timeout(timeout time.Duration) contractshttp.Middleware {

go func() {
defer func() {
if r := recover(); r != nil {
LogFacade.Request(ctx.Request()).Error(r)
// TODO can be customized in https://github.com/goravel/goravel/issues/521
_ = ctx.Response().Status(http.StatusInternalServerError).String("Internal Server Error").Render()
if err := recover(); err != nil {
if globalRecoverCallback != nil {
globalRecoverCallback(ctx, err)
} else {
LogFacade.Error(err)
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"})
}
}

close(done)
}()

ctx.Request().Next()
}()

select {
case <-done:
case <-ctx.Context().Done():
case <-timeoutCtx.Done():
if errors.Is(ctx.Context().Err(), context.DeadlineExceeded) {
ctx.Request().AbortWithStatus(http.StatusGatewayTimeout)
ctx.Request().AbortWithStatusJson(http.StatusGatewayTimeout, fiber.Map{"error": "Request Timeout"})
}
}
}
Expand Down
78 changes: 45 additions & 33 deletions middleware_timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package fiber

import (
"fmt"
"io"
"net/http"
"testing"
"time"

"github.com/gofiber/fiber/v2"
contractshttp "github.com/goravel/framework/contracts/http"
mocksconfig "github.com/goravel/framework/mocks/config"
mockslog "github.com/goravel/framework/mocks/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand All @@ -26,48 +24,62 @@ func TestTimeoutMiddleware(t *testing.T) {

route.Middleware(Timeout(1*time.Second)).Get("/timeout", func(ctx contractshttp.Context) contractshttp.Response {
time.Sleep(2 * time.Second)

return ctx.Response().Success().String("timeout")
return nil
})

route.Middleware(Timeout(1*time.Second)).Get("/normal", func(ctx contractshttp.Context) contractshttp.Response {
return ctx.Response().Success().String("normal")
})
route.Middleware(Timeout(1*time.Second)).Get("/panic", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)

route.Middleware(Timeout(5*time.Second)).Get("/panic", func(ctx contractshttp.Context) contractshttp.Response {
panic("test panic")
})

req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}
route.Recover(globalRecover)

resp, err := route.instance.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode)
t.Run("timeout", func(t *testing.T) {
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)

req, err = http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)
resp, err := route.instance.Test(req, -1)
require.NoError(t, err)
require.NotNil(t, resp)

resp, err = route.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "normal", string(body))
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.JSONEq(t, `{"error":"Request Timeout"}`, string(body))
})

req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)
t.Run("normal", func(t *testing.T) {
req, err := http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)

mockLog := mockslog.NewLog(t)
mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Error(mock.Anything).Once()
LogFacade = mockLog
resp, err := route.instance.Test(req, -1)
assert.NoError(t, err)

resp, err = route.instance.Test(req)
fmt.Printf("resp: %+v\n", resp)
assert.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)

body, err = io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Internal Server Error", string(body))
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "normal", string(body))
})

t.Run("panic", func(t *testing.T) {
req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

resp, err := route.instance.Test(req, -1)
require.NoError(t, err)

assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"error":"Internal Panic"}`, string(body))
})
}
23 changes: 21 additions & 2 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
fiberrecover "github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/template/html/v2"
"github.com/goravel/framework/contracts/config"
httpcontract "github.com/goravel/framework/contracts/http"
Expand All @@ -26,6 +26,8 @@ import (
"github.com/savioxavier/termlink"
)

var globalRecoverCallback func(ctx httpcontract.Context, err any)

// Route fiber route
// Route fiber 路由
type Route struct {
Expand Down Expand Up @@ -109,7 +111,7 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
debug := r.config.GetBool("app.debug", false)
timeout := time.Duration(r.config.GetInt("http.request_timeout", 3)) * time.Second
fiberHandlers := []fiber.Handler{
recover.New(recover.Config{
fiberrecover.New(fiberrecover.Config{
EnableStackTrace: debug,
}),
}
Expand All @@ -130,6 +132,23 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
r.setMiddlewares(fiberHandlers)
}

func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) {
globalRecoverCallback = callback
middleware := middlewaresToFiberHandlers([]httpcontract.Middleware{
func(ctx httpcontract.Context) {
defer func() {
if err := recover(); err != nil {
if callback != nil {
callback(ctx, err)
}
}
}()
ctx.Request().Next()
},
})
r.setMiddlewares(middleware)
}

// Listen listen server
// Listen 监听服务器
func (r *Route) Listen(l net.Listener) error {
Expand Down
31 changes: 31 additions & 0 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
Expand All @@ -21,6 +22,36 @@ import (
"github.com/stretchr/testify/assert"
)

func TestRecoverWithCustomCallback(t *testing.T) {
mockConfig := configmocks.NewConfig(t)

mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("GetInt", "http.drivers.fiber.body_limit", 4096).Return(4096).Once()
mockConfig.On("GetInt", "http.drivers.fiber.header_limit", 4096).Return(4096).Once()

route, err := NewRoute(mockConfig, nil)
assert.Nil(t, err)

globalRecoverCallback := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}

route.Recover(globalRecoverCallback)

route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)
})

req := httptest.NewRequest("GET", "/recover", nil)
resp, err := route.Test(req)
assert.Nil(t, err)

body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, "{\"error\":\"Internal Panic\"}", string(body))
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}

func TestFallback(t *testing.T) {
mockConfig := configmocks.NewConfig(t)
mockConfig.EXPECT().GetBool("http.drivers.fiber.prefork", false).Return(false).Once()
Expand Down

0 comments on commit 7328b59

Please sign in to comment.