Skip to content

Commit

Permalink
feat: add ErrTooManyRequests to support Retry-After and 429
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias committed Mar 3, 2023
1 parent 36918f4 commit 810eda9
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 92 deletions.
98 changes: 98 additions & 0 deletions gateway/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package gateway

import (
"context"
"errors"
"net/http"
"strconv"

ipld "github.com/ipfs/go-ipld-format"
"github.com/ipfs/go-path/resolver"
)

var (
ErrGatewayTimeout = errors.New(http.StatusText(http.StatusGatewayTimeout))
ErrBadGateway = errors.New(http.StatusText(http.StatusBadGateway))
)

type ErrTooManyRequests struct {
RetryAfter uint64
}

func (e *ErrTooManyRequests) Error() string {
return http.StatusText(http.StatusTooManyRequests)
}

func (e *ErrTooManyRequests) Is(err error) bool {
switch err.(type) {
case *ErrTooManyRequests:
return true
default:
return false
}
}

func webError(w http.ResponseWriter, err error, defaultCode int) {
code := defaultCode

switch {
case isErrNotFound(err):
code = http.StatusNotFound
case errors.Is(err, ErrGatewayTimeout),
errors.Is(err, context.DeadlineExceeded):
code = http.StatusGatewayTimeout
case errors.Is(err, ErrBadGateway):
code = http.StatusBadGateway
case errors.Is(err, &ErrTooManyRequests{}):
var tooManyRequests *ErrTooManyRequests
_ = errors.As(err, &tooManyRequests)
if tooManyRequests.RetryAfter > 0 {
w.Header().Set("Retry-After", strconv.FormatUint(tooManyRequests.RetryAfter, 10))
}

code = http.StatusTooManyRequests
}

http.Error(w, err.Error(), code)
if code >= 500 {
log.Warnf("server error: %s", err)
}
}

func isErrNotFound(err error) bool {
if ipld.IsNotFound(err) {
return true
}

// Checks if err is a resolver.ErrNoLink. resolver.ErrNoLink does not implement
// the .Is interface and cannot be directly compared to. Therefore, errors.Is
// always returns false with it.
for {
_, ok := err.(resolver.ErrNoLink)
if ok {
return true
}

err = errors.Unwrap(err)
if err == nil {
return false
}
}
}

func webRequestError(w http.ResponseWriter, err *requestError) {
webError(w, err.Err, err.StatusCode)
}

// Custom type for collecting error details to be handled by `webRequestError`
type requestError struct {
StatusCode int
Err error
}

func newRequestError(err error, statusCode int) *requestError {
return &requestError{
Err: err,
StatusCode: statusCode,
}
}
56 changes: 56 additions & 0 deletions gateway/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package gateway

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestErrTooManyRequestsIs(t *testing.T) {
var err error

err = &ErrTooManyRequests{RetryAfter: 10}
assert.True(t, errors.Is(err, &ErrTooManyRequests{}), "pointer to error must be error")

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.Is(err, &ErrTooManyRequests{}), "wrapped pointer to error must be error")
}

func TestErrTooManyRequestsAs(t *testing.T) {
var (
err error
errTMR *ErrTooManyRequests
)

err = &ErrTooManyRequests{RetryAfter: 25}
assert.True(t, errors.As(err, &errTMR), "pointer to error must be error")
assert.EqualValues(t, errTMR.RetryAfter, 25)

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.As(err, &errTMR), "wrapped pointer to error must be error")
assert.EqualValues(t, errTMR.RetryAfter, 25)
}

func TestWebError(t *testing.T) {
t.Parallel()

t.Run("429 Too Many Requests", func(t *testing.T) {
err := fmt.Errorf("wrapped for testing: %w", &ErrTooManyRequests{})
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Zero(t, len(w.Result().Header.Values("Retry-After")))
})

t.Run("429 Too Many Requests with Retry-After header", func(t *testing.T) {
err := &ErrTooManyRequests{RetryAfter: 25}
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Equal(t, "25", w.Result().Header.Get("Retry-After"))
})
}
70 changes: 0 additions & 70 deletions gateway/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"context"
"errors"
"fmt"
"html/template"
"io"
Expand All @@ -17,9 +16,7 @@ import (
"time"

cid "github.com/ipfs/go-cid"
ipld "github.com/ipfs/go-ipld-format"
logging "github.com/ipfs/go-log"
"github.com/ipfs/go-path/resolver"
coreiface "github.com/ipfs/interface-go-ipfs-core"
ipath "github.com/ipfs/interface-go-ipfs-core/path"
mc "github.com/multiformats/go-multicodec"
Expand All @@ -41,9 +38,6 @@ const (
var (
onlyASCII = regexp.MustCompile("[[:^ascii:]]")
noModtime = time.Unix(0, 0) // disables Last-Modified header if passed as modtime

ErrGatewayTimeout = errors.New(http.StatusText(http.StatusGatewayTimeout))
ErrBadGateway = errors.New(http.StatusText(http.StatusBadGateway))
)

// HTML-based redirect for errors which can be recovered from, but we want
Expand Down Expand Up @@ -95,23 +89,6 @@ type statusResponseWriter struct {
http.ResponseWriter
}

// Custom type for collecting error details to be handled by `webRequestError`
type requestError struct {
StatusCode int
Err error
}

func (r *requestError) Error() string {
return r.Err.Error()
}

func newRequestError(err error, statusCode int) *requestError {
return &requestError{
Err: err,
StatusCode: statusCode,
}
}

func (sw *statusResponseWriter) WriteHeader(code int) {
// Check if we need to adjust Status Code to account for scheduled redirect
// This enables us to return payload along with HTTP 301
Expand Down Expand Up @@ -555,53 +532,6 @@ func (i *handler) buildIpfsRootsHeader(contentPath string, r *http.Request) (str
return rootCidList, nil
}

func webRequestError(w http.ResponseWriter, err *requestError) {
webError(w, err.Err, err.StatusCode)
}

func webError(w http.ResponseWriter, err error, defaultCode int) {
switch {
case isErrNotFound(err):
webErrorWithCode(w, err, http.StatusNotFound)
case errors.Is(err, ErrGatewayTimeout):
webErrorWithCode(w, err, http.StatusGatewayTimeout)
case errors.Is(err, ErrBadGateway):
webErrorWithCode(w, err, http.StatusBadGateway)
case errors.Is(err, context.DeadlineExceeded):
webErrorWithCode(w, err, http.StatusGatewayTimeout)
default:
webErrorWithCode(w, err, defaultCode)
}
}

func isErrNotFound(err error) bool {
if ipld.IsNotFound(err) {
return true
}

// Checks if err is a resolver.ErrNoLink. resolver.ErrNoLink does not implement
// the .Is interface and cannot be directly compared to. Therefore, errors.Is
// always returns false with it.
for {
_, ok := err.(resolver.ErrNoLink)
if ok {
return true
}

err = errors.Unwrap(err)
if err == nil {
return false
}
}
}

func webErrorWithCode(w http.ResponseWriter, err error, code int) {
http.Error(w, err.Error(), code)
if code >= 500 {
log.Warnf("server error: %s", err)
}
}

func getFilename(contentPath ipath.Path) string {
s := contentPath.String()
if (strings.HasPrefix(s, ipfsPathPrefix) || strings.HasPrefix(s, ipnsPathPrefix)) && strings.Count(gopath.Clean(s), "/") <= 2 {
Expand Down
73 changes: 51 additions & 22 deletions gateway/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"testing"

cid "github.com/ipfs/go-cid"
ipld "github.com/ipfs/go-ipld-format"
"github.com/ipfs/go-libipfs/blocks"
"github.com/ipfs/go-libipfs/files"
"github.com/ipfs/go-path/resolver"
iface "github.com/ipfs/interface-go-ipfs-core"
ipath "github.com/ipfs/interface-go-ipfs-core/path"
"github.com/tj/assert"
Expand Down Expand Up @@ -85,28 +87,55 @@ func TestGatewayInternalServerErrorInvalidPath(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

func TestGatewayTimeoutBubblingFromAPI(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("the mock api has timed out: %w", ErrGatewayTimeout)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusGatewayTimeout, res.StatusCode)
}

func TestBadGatewayBubblingFromAPI(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("the mock api has a bad gateway: %w", ErrBadGateway)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)
func TestErrorBubblingFromAPI(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)
for _, test := range []struct {
name string
err error
status int
}{
{"404 Not Found from IPLD", &ipld.ErrNotFound{}, http.StatusNotFound},
{"404 Not Found from path resolver", resolver.ErrNoLink{}, http.StatusNotFound},
{"502 Bad Gateway", ErrBadGateway, http.StatusBadGateway},
{"504 Gateway Timeout", ErrGatewayTimeout, http.StatusGatewayTimeout},
} {
t.Run(test.name, func(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("wrapped for testing purposes: %w", test.err)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, test.status, res.StatusCode)
})
}

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusBadGateway, res.StatusCode)
for _, test := range []struct {
name string
err error
status int
headerName string
headerValue string
headerLength int // how many times was headerName set
}{
{"429 Too Many Requests without Retry-After header", &ErrTooManyRequests{}, http.StatusTooManyRequests, "Retry-After", "", 0},
{"429 Too Many Requests with Retry-After header", &ErrTooManyRequests{RetryAfter: 3600}, http.StatusTooManyRequests, "Retry-After", "3600", 1},
} {
api := &errorMockAPI{err: fmt.Errorf("wrapped for testing purposes: %w", test.err)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, test.status, res.StatusCode)
assert.Equal(t, test.headerValue, res.Header.Get(test.headerName))
assert.Equal(t, test.headerLength, len(res.Header.Values(test.headerName)))
}
}

0 comments on commit 810eda9

Please sign in to comment.