Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support HTTP 429 with Retry-After #194

Merged
merged 7 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions gateway/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package gateway

import (
"context"
"errors"
"net/http"
"strconv"

ipld "github.com/ipfs/go-ipld-format"
"github.com/ipfs/go-path/resolver"
)

var (
ErrGatewayTimeout = errors.New(http.StatusText(http.StatusGatewayTimeout))
ErrBadGateway = errors.New(http.StatusText(http.StatusBadGateway))
)

type ErrTooManyRequests struct {
RetryAfter uint64
hacdias marked this conversation as resolved.
Show resolved Hide resolved
}

func (e *ErrTooManyRequests) Error() string {
return http.StatusText(http.StatusTooManyRequests)
}

func (e *ErrTooManyRequests) Is(err error) bool {
switch err.(type) {
case *ErrTooManyRequests:
return true
default:
return false
}
}

func webError(w http.ResponseWriter, err error, defaultCode int) {
code := defaultCode

switch {
case isErrNotFound(err):
code = http.StatusNotFound
case errors.Is(err, ErrGatewayTimeout),
errors.Is(err, context.DeadlineExceeded):
code = http.StatusGatewayTimeout
case errors.Is(err, ErrBadGateway):
code = http.StatusBadGateway
case errors.Is(err, &ErrTooManyRequests{}):
var tooManyRequests *ErrTooManyRequests
_ = errors.As(err, &tooManyRequests)
if tooManyRequests.RetryAfter > 0 {
w.Header().Set("Retry-After", strconv.FormatUint(tooManyRequests.RetryAfter, 10))
}

code = http.StatusTooManyRequests
}

http.Error(w, err.Error(), code)
if code >= 500 {
log.Warnf("server error: %s", err)
}
hacdias marked this conversation as resolved.
Show resolved Hide resolved
}

func isErrNotFound(err error) bool {
if ipld.IsNotFound(err) {
return true
}

// Checks if err is a resolver.ErrNoLink. resolver.ErrNoLink does not implement
// the .Is interface and cannot be directly compared to. Therefore, errors.Is
// always returns false with it.
for {
_, ok := err.(resolver.ErrNoLink)
if ok {
return true
}

err = errors.Unwrap(err)
if err == nil {
return false
}
}
}

func webRequestError(w http.ResponseWriter, err *requestError) {
webError(w, err.Err, err.StatusCode)
}

// Custom type for collecting error details to be handled by `webRequestError`
type requestError struct {
StatusCode int
Err error
}

func newRequestError(err error, statusCode int) *requestError {
return &requestError{
Err: err,
StatusCode: statusCode,
}
}
56 changes: 56 additions & 0 deletions gateway/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package gateway

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestErrTooManyRequestsIs(t *testing.T) {
var err error

err = &ErrTooManyRequests{RetryAfter: 10}
assert.True(t, errors.Is(err, &ErrTooManyRequests{}), "pointer to error must be error")

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.Is(err, &ErrTooManyRequests{}), "wrapped pointer to error must be error")
}

func TestErrTooManyRequestsAs(t *testing.T) {
var (
err error
errTMR *ErrTooManyRequests
)

err = &ErrTooManyRequests{RetryAfter: 25}
assert.True(t, errors.As(err, &errTMR), "pointer to error must be error")
assert.EqualValues(t, errTMR.RetryAfter, 25)

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.As(err, &errTMR), "wrapped pointer to error must be error")
assert.EqualValues(t, errTMR.RetryAfter, 25)
}

func TestWebError(t *testing.T) {
t.Parallel()

t.Run("429 Too Many Requests", func(t *testing.T) {
err := fmt.Errorf("wrapped for testing: %w", &ErrTooManyRequests{})
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Zero(t, len(w.Result().Header.Values("Retry-After")))
})

t.Run("429 Too Many Requests with Retry-After header", func(t *testing.T) {
err := &ErrTooManyRequests{RetryAfter: 25}
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Equal(t, "25", w.Result().Header.Get("Retry-After"))
})
}
70 changes: 0 additions & 70 deletions gateway/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"context"
"errors"
"fmt"
"html/template"
"io"
Expand All @@ -17,9 +16,7 @@ import (
"time"

cid "github.com/ipfs/go-cid"
ipld "github.com/ipfs/go-ipld-format"
logging "github.com/ipfs/go-log"
"github.com/ipfs/go-path/resolver"
coreiface "github.com/ipfs/interface-go-ipfs-core"
ipath "github.com/ipfs/interface-go-ipfs-core/path"
mc "github.com/multiformats/go-multicodec"
Expand All @@ -41,9 +38,6 @@ const (
var (
onlyASCII = regexp.MustCompile("[[:^ascii:]]")
noModtime = time.Unix(0, 0) // disables Last-Modified header if passed as modtime

ErrGatewayTimeout = errors.New(http.StatusText(http.StatusGatewayTimeout))
ErrBadGateway = errors.New(http.StatusText(http.StatusBadGateway))
)

// HTML-based redirect for errors which can be recovered from, but we want
Expand Down Expand Up @@ -95,23 +89,6 @@ type statusResponseWriter struct {
http.ResponseWriter
}

// Custom type for collecting error details to be handled by `webRequestError`
type requestError struct {
StatusCode int
Err error
}

func (r *requestError) Error() string {
return r.Err.Error()
}

func newRequestError(err error, statusCode int) *requestError {
return &requestError{
Err: err,
StatusCode: statusCode,
}
}

func (sw *statusResponseWriter) WriteHeader(code int) {
// Check if we need to adjust Status Code to account for scheduled redirect
// This enables us to return payload along with HTTP 301
Expand Down Expand Up @@ -555,53 +532,6 @@ func (i *handler) buildIpfsRootsHeader(contentPath string, r *http.Request) (str
return rootCidList, nil
}

func webRequestError(w http.ResponseWriter, err *requestError) {
webError(w, err.Err, err.StatusCode)
}

func webError(w http.ResponseWriter, err error, defaultCode int) {
switch {
case isErrNotFound(err):
webErrorWithCode(w, err, http.StatusNotFound)
case errors.Is(err, ErrGatewayTimeout):
webErrorWithCode(w, err, http.StatusGatewayTimeout)
case errors.Is(err, ErrBadGateway):
webErrorWithCode(w, err, http.StatusBadGateway)
case errors.Is(err, context.DeadlineExceeded):
webErrorWithCode(w, err, http.StatusGatewayTimeout)
default:
webErrorWithCode(w, err, defaultCode)
}
}

func isErrNotFound(err error) bool {
if ipld.IsNotFound(err) {
return true
}

// Checks if err is a resolver.ErrNoLink. resolver.ErrNoLink does not implement
// the .Is interface and cannot be directly compared to. Therefore, errors.Is
// always returns false with it.
for {
_, ok := err.(resolver.ErrNoLink)
if ok {
return true
}

err = errors.Unwrap(err)
if err == nil {
return false
}
}
}

func webErrorWithCode(w http.ResponseWriter, err error, code int) {
http.Error(w, err.Error(), code)
if code >= 500 {
log.Warnf("server error: %s", err)
}
}

func getFilename(contentPath ipath.Path) string {
s := contentPath.String()
if (strings.HasPrefix(s, ipfsPathPrefix) || strings.HasPrefix(s, ipnsPathPrefix)) && strings.Count(gopath.Clean(s), "/") <= 2 {
Expand Down
73 changes: 51 additions & 22 deletions gateway/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"testing"

cid "github.com/ipfs/go-cid"
ipld "github.com/ipfs/go-ipld-format"
"github.com/ipfs/go-libipfs/blocks"
"github.com/ipfs/go-libipfs/files"
"github.com/ipfs/go-path/resolver"
iface "github.com/ipfs/interface-go-ipfs-core"
ipath "github.com/ipfs/interface-go-ipfs-core/path"
"github.com/tj/assert"
Expand Down Expand Up @@ -85,28 +87,55 @@ func TestGatewayInternalServerErrorInvalidPath(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

func TestGatewayTimeoutBubblingFromAPI(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("the mock api has timed out: %w", ErrGatewayTimeout)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusGatewayTimeout, res.StatusCode)
}

func TestBadGatewayBubblingFromAPI(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("the mock api has a bad gateway: %w", ErrBadGateway)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)
func TestErrorBubblingFromAPI(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)
for _, test := range []struct {
name string
err error
status int
}{
{"404 Not Found from IPLD", &ipld.ErrNotFound{}, http.StatusNotFound},
{"404 Not Found from path resolver", resolver.ErrNoLink{}, http.StatusNotFound},
{"502 Bad Gateway", ErrBadGateway, http.StatusBadGateway},
{"504 Gateway Timeout", ErrGatewayTimeout, http.StatusGatewayTimeout},
} {
t.Run(test.name, func(t *testing.T) {
api := &errorMockAPI{err: fmt.Errorf("wrapped for testing purposes: %w", test.err)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, test.status, res.StatusCode)
})
}

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusBadGateway, res.StatusCode)
for _, test := range []struct {
name string
err error
status int
headerName string
headerValue string
headerLength int // how many times was headerName set
}{
{"429 Too Many Requests without Retry-After header", &ErrTooManyRequests{}, http.StatusTooManyRequests, "Retry-After", "", 0},
{"429 Too Many Requests with Retry-After header", &ErrTooManyRequests{RetryAfter: 3600}, http.StatusTooManyRequests, "Retry-After", "3600", 1},
} {
api := &errorMockAPI{err: fmt.Errorf("wrapped for testing purposes: %w", test.err)}
ts := newTestServer(t, api)
t.Logf("test server url: %s", ts.URL)

req, err := http.NewRequest(http.MethodGet, ts.URL+"/ipns/en.wikipedia-on-ipfs.org", nil)
assert.Nil(t, err)

res, err := ts.Client().Do(req)
assert.Nil(t, err)
assert.Equal(t, test.status, res.StatusCode)
assert.Equal(t, test.headerValue, res.Header.Get(test.headerName))
assert.Equal(t, test.headerLength, len(res.Header.Values(test.headerName)))
}
}