Skip to content

Commit

Permalink
Merge pull request #6596 from AndriiDiachuk/rest-msg-max-request-size…
Browse files Browse the repository at this point in the history
…-configurable

[Access] Make REST message size limit configurable
  • Loading branch information
peterargue authored Nov 6, 2024
2 parents 899e12e + 640a918 commit df95497
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 25 deletions.
17 changes: 13 additions & 4 deletions cmd/access/node_builder/access_node_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,11 @@ func DefaultAccessNodeConfig() *AccessNodeConfig {
TxResultQueryMode: backend.IndexQueryModeExecutionNodesOnly.String(), // default to ENs only for now
},
RestConfig: rest.Config{
ListenAddress: "",
WriteTimeout: rest.DefaultWriteTimeout,
ReadTimeout: rest.DefaultReadTimeout,
IdleTimeout: rest.DefaultIdleTimeout,
ListenAddress: "",
WriteTimeout: rest.DefaultWriteTimeout,
ReadTimeout: rest.DefaultReadTimeout,
IdleTimeout: rest.DefaultIdleTimeout,
MaxRequestSize: routes.DefaultMaxRequestSize,
},
MaxMsgSize: grpcutils.DefaultMaxMsgSize,
CompressorName: grpcutils.NoCompressor,
Expand Down Expand Up @@ -1190,6 +1191,10 @@ func (builder *FlowAccessNodeBuilder) extraFlags() {
defaultConfig.rpcConf.RestConfig.ReadTimeout,
"timeout to use when reading REST request headers")
flags.DurationVar(&builder.rpcConf.RestConfig.IdleTimeout, "rest-idle-timeout", defaultConfig.rpcConf.RestConfig.IdleTimeout, "idle timeout for REST connections")
flags.Int64Var(&builder.rpcConf.RestConfig.MaxRequestSize,
"rest-max-request-size",
defaultConfig.rpcConf.RestConfig.MaxRequestSize,
"the maximum request size in bytes for payload sent over REST server")
flags.StringVarP(&builder.rpcConf.CollectionAddr,
"static-collection-ingress-addr",
"",
Expand Down Expand Up @@ -1508,6 +1513,10 @@ func (builder *FlowAccessNodeBuilder) extraFlags() {
return errors.New("execution-data-indexing-enabled must be set if check-payer-balance is enabled")
}

if builder.rpcConf.RestConfig.MaxRequestSize <= 0 {
return errors.New("rest-max-request-size must be greater than 0")
}

return nil
})
}
Expand Down
17 changes: 13 additions & 4 deletions cmd/observer/node_builder/observer_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,11 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig {
TxResultQueryMode: backend.IndexQueryModeExecutionNodesOnly.String(), // default to ENs only for now
},
RestConfig: rest.Config{
ListenAddress: "",
WriteTimeout: rest.DefaultWriteTimeout,
ReadTimeout: rest.DefaultReadTimeout,
IdleTimeout: rest.DefaultIdleTimeout,
ListenAddress: "",
WriteTimeout: rest.DefaultWriteTimeout,
ReadTimeout: rest.DefaultReadTimeout,
IdleTimeout: rest.DefaultIdleTimeout,
MaxRequestSize: routes.DefaultMaxRequestSize,
},
MaxMsgSize: grpcutils.DefaultMaxMsgSize,
CompressorName: grpcutils.NoCompressor,
Expand Down Expand Up @@ -621,6 +622,10 @@ func (builder *ObserverServiceBuilder) extraFlags() {
defaultConfig.rpcConf.RestConfig.ReadTimeout,
"timeout to use when reading REST request headers")
flags.DurationVar(&builder.rpcConf.RestConfig.IdleTimeout, "rest-idle-timeout", defaultConfig.rpcConf.RestConfig.IdleTimeout, "idle timeout for REST connections")
flags.Int64Var(&builder.rpcConf.RestConfig.MaxRequestSize,
"rest-max-request-size",
defaultConfig.rpcConf.RestConfig.MaxRequestSize,
"the maximum request size in bytes for payload sent over REST server")
flags.UintVar(&builder.rpcConf.MaxMsgSize,
"rpc-max-message-size",
defaultConfig.rpcConf.MaxMsgSize,
Expand Down Expand Up @@ -851,6 +856,10 @@ func (builder *ObserverServiceBuilder) extraFlags() {
}
}

if builder.rpcConf.RestConfig.MaxRequestSize <= 0 {
return errors.New("rest-max-request-size must be greater than 0")
}

return nil
})
}
Expand Down
3 changes: 2 additions & 1 deletion engine/access/rest/routes/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ func NewHandler(
handlerFunc ApiHandlerFunc,
generator models.LinkGenerator,
chain flow.Chain,
maxRequestSize int64,
) *Handler {
handler := &Handler{
backend: backend,
apiHandlerFunc: handlerFunc,
linkGenerator: generator,
HttpHandler: NewHttpHandler(logger, chain),
HttpHandler: NewHttpHandler(logger, chain, maxRequestSize),
}

return handler
Expand Down
12 changes: 8 additions & 4 deletions engine/access/rest/routes/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,27 @@ import (
"github.com/onflow/flow-go/model/flow"
)

const MaxRequestSize = 2 << 20 // 2MB
const DefaultMaxRequestSize = 2 << 20 // 2MB

// HttpHandler is custom http handler implementing custom handler function.
// HttpHandler function allows easier handling of errors and responses as it
// wraps functionality for handling error and responses outside of endpoint handling.
type HttpHandler struct {
Logger zerolog.Logger
Chain flow.Chain

MaxRequestSize int64
}

func NewHttpHandler(
logger zerolog.Logger,
chain flow.Chain,
maxRequestSize int64,
) *HttpHandler {
return &HttpHandler{
Logger: logger,
Chain: chain,
Logger: logger,
Chain: chain,
MaxRequestSize: maxRequestSize,
}
}

Expand All @@ -43,7 +47,7 @@ func (h *HttpHandler) VerifyRequest(w http.ResponseWriter, r *http.Request) erro
errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger()

// limit requested body size
r.Body = http.MaxBytesReader(w, r.Body, MaxRequestSize)
r.Body = http.MaxBytesReader(w, r.Body, h.MaxRequestSize)
err := r.ParseForm()
if err != nil {
h.errorHandler(w, err, errLog)
Expand Down
11 changes: 8 additions & 3 deletions engine/access/rest/routes/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,14 @@ func NewRouterBuilder(
}

// AddRestRoutes adds rest routes to the router.
func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *RouterBuilder {
func (b *RouterBuilder) AddRestRoutes(
backend access.API,
chain flow.Chain,
maxRequestSize int64,
) *RouterBuilder {
linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter)
for _, r := range Routes {
h := NewHandler(b.logger, backend, r.Handler, linkGenerator, chain)
h := NewHandler(b.logger, backend, r.Handler, linkGenerator, chain, maxRequestSize)
b.v1SubRouter.
Methods(r.Method).
Path(r.Pattern).
Expand All @@ -64,10 +68,11 @@ func (b *RouterBuilder) AddWsRoutes(
stateStreamApi state_stream.API,
chain flow.Chain,
stateStreamConfig backend.Config,
maxRequestSize int64,
) *RouterBuilder {

for _, r := range WSRoutes {
h := NewWSHandler(b.logger, stateStreamApi, r.Handler, chain, stateStreamConfig)
h := NewWSHandler(b.logger, stateStreamApi, r.Handler, chain, stateStreamConfig, maxRequestSize)
b.v1SubRouter.
Methods(r.Method).
Path(r.Pattern).
Expand Down
3 changes: 2 additions & 1 deletion engine/access/rest/routes/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func executeRequest(req *http.Request, backend access.API) *httptest.ResponseRec
).AddRestRoutes(
backend,
flow.Testnet.Chain(),
DefaultMaxRequestSize,
).Build()

rr := httptest.NewRecorder()
Expand All @@ -144,7 +145,7 @@ func executeWsRequest(req *http.Request, stateStreamApi state_stream.API, respon

router := NewRouterBuilder(unittest.Logger(), restCollector).AddWsRoutes(
stateStreamApi,
chain, config).Build()
chain, config, DefaultMaxRequestSize).Build()
router.ServeHTTP(responseRecorder, req)
}

Expand Down
3 changes: 2 additions & 1 deletion engine/access/rest/routes/websocket_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ func NewWSHandler(
subscribeFunc SubscribeHandlerFunc,
chain flow.Chain,
stateStreamConfig backend.Config,
maxRequestSize int64,
) *WSHandler {
handler := &WSHandler{
subscribeFunc: subscribeFunc,
Expand All @@ -261,7 +262,7 @@ func NewWSHandler(
maxStreams: int32(stateStreamConfig.MaxGlobalStreams),
defaultHeartbeatInterval: stateStreamConfig.HeartbeatInterval,
activeStreamCount: atomic.NewInt32(0),
HttpHandler: NewHttpHandler(logger, chain),
HttpHandler: NewHttpHandler(logger, chain, maxRequestSize),
}

return handler
Expand Down
13 changes: 7 additions & 6 deletions engine/access/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ const (
)

type Config struct {
ListenAddress string
WriteTimeout time.Duration
ReadTimeout time.Duration
IdleTimeout time.Duration
ListenAddress string
WriteTimeout time.Duration
ReadTimeout time.Duration
IdleTimeout time.Duration
MaxRequestSize int64
}

// NewServer returns an HTTP server initialized with the REST API handler
Expand All @@ -42,9 +43,9 @@ func NewServer(serverAPI access.API,
stateStreamApi state_stream.API,
stateStreamConfig backend.Config,
) (*http.Server, error) {
builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain)
builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain, config.MaxRequestSize)
if stateStreamApi != nil {
builder.AddWsRoutes(stateStreamApi, chain, stateStreamConfig)
builder.AddWsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize)
}

c := cors.New(cors.Options{
Expand Down
2 changes: 1 addition & 1 deletion engine/access/rest_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ func (suite *RestAPITestSuite) TestRequestSizeRestriction() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
// make a request of size larger than the max permitted size
requestBytes := make([]byte, routes.MaxRequestSize+1)
requestBytes := make([]byte, routes.DefaultMaxRequestSize+1)
script := restclient.ScriptsBody{
Script: string(requestBytes),
}
Expand Down

0 comments on commit df95497

Please sign in to comment.