From 246e1b58658cc1c754eb04f56875790e56e166a1 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Wed, 21 Dec 2022 07:31:47 -0500 Subject: [PATCH 01/10] routing/http: feat: add streaming support This adds streaming support to the routing/v1 client and server by changing the interfaces to use iterators instead of slices, and adding content type negotation to the client and server. --- routing/http/client/client.go | 100 +++++++++++--- routing/http/client/client_test.go | 19 ++- routing/http/client/measures.go | 34 ++--- routing/http/contentrouter/contentrouter.go | 55 +++++--- .../http/contentrouter/contentrouter_test.go | 8 +- routing/http/server/server.go | 130 ++++++++++++++++-- routing/http/server/server_test.go | 15 +- routing/http/types/iter/iter.go | 30 ++++ routing/http/types/iter/json.go | 49 +++++++ routing/http/types/iter/slice.go | 23 ++++ routing/http/types/json/provider.go | 116 ++++++++++++++++ routing/http/types/ndjson/provider.go | 96 +++++++++++++ routing/http/types/provider.go | 125 ----------------- 13 files changed, 595 insertions(+), 205 deletions(-) create mode 100644 routing/http/types/iter/iter.go create mode 100644 routing/http/types/iter/json.go create mode 100644 routing/http/types/iter/slice.go create mode 100644 routing/http/types/json/provider.go create mode 100644 routing/http/types/ndjson/provider.go diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 4e5cfb8f4..6fa7fa44b 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -6,7 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "mime" "net/http" + "strings" "time" "github.com/benbjohnson/clock" @@ -15,6 +17,9 @@ import ( "github.com/ipfs/boxo/routing/http/internal/drjson" "github.com/ipfs/boxo/routing/http/server" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" + jsontypes "github.com/ipfs/boxo/routing/http/types/json" + "github.com/ipfs/boxo/routing/http/types/ndjson" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" record "github.com/libp2p/go-libp2p-record" @@ -23,7 +28,15 @@ import ( "github.com/multiformats/go-multiaddr" ) -var logger = logging.Logger("service/delegatedrouting") +var ( + _ contentrouter.Client = &client{} + logger = logging.Logger("service/delegatedrouting") +) + +const ( + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" +) type client struct { baseURL string @@ -31,6 +44,8 @@ type client struct { validator record.Validator clock clock.Clock + accepts string + peerID peer.ID addrs []types.Multiaddr identity crypto.PrivKey @@ -90,6 +105,12 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { } } +func WithStreamResultsRequired() option { + return func(c *client) { + c.accepts = mediaTypeNDJSON + } +} + // New creates a content routing API client. // The Provider and identity parameters are option. If they are nil, the `Provide` method will not function. func New(baseURL string, opts ...option) (*client, error) { @@ -105,6 +126,7 @@ func New(baseURL string, opts ...option) (*client, error) { httpClient: defaultHTTPClient, validator: ipns.Validator{}, clock: clock.New(), + accepts: strings.Join([]string{mediaTypeNDJSON, mediaTypeJSON}, ","), } for _, opt := range opts { @@ -118,43 +140,87 @@ func New(baseURL string, opts ...option) (*client, error) { return client, nil } -func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs []types.ProviderResponse, err error) { - measurement := newMeasurement("FindProviders") - defer func() { - measurement.length = len(provs) - measurement.record(ctx) - }() +// measuringIter measures the length of the iter and then publishes metrics about the whole req once the iter is closed. +// Of course, if the caller forgets to close the iter, this won't publish anything. +type measuringIter[T any] struct { + iter.Iter[T] + ctx context.Context + m *measurement +} + +func (c *measuringIter[T]) Next() (T, bool, error) { + c.m.length++ + return c.Iter.Next() +} + +func (c *measuringIter[T]) Close() error { + c.m.record(c.ctx) + return c.Iter.Close() +} + +func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Iter[types.ProviderResponse], err error) { + // TODO test measurements + m := newMeasurement("FindProviders") url := c.baseURL + server.ProvidePath + key.String() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } - measurement.host = req.Host + + m.host = req.Host start := c.clock.Now() resp, err := c.httpClient.Do(req) - measurement.err = err - measurement.latency = c.clock.Since(start) + m.err = err + m.latency = c.clock.Since(start) if err != nil { + m.record(ctx) return nil, err } - defer resp.Body.Close() - measurement.statusCode = resp.StatusCode + m.statusCode = resp.StatusCode if resp.StatusCode == http.StatusNotFound { + resp.Body.Close() + m.record(ctx) return nil, nil } if resp.StatusCode != http.StatusOK { + resp.Body.Close() + m.record(ctx) return nil, httpError(resp.StatusCode, resp.Body) } - parsedResp := &types.ReadProvidersResponse{} - err = json.NewDecoder(resp.Body).Decode(parsedResp) - return parsedResp.Providers, err + respContentType := resp.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(respContentType) + if err != nil { + resp.Body.Close() + m.err = err + m.record(ctx) + return nil, fmt.Errorf("parsing Content-Type: %w", err) + } + + m.mediaType = mediaType + + var it iter.Iter[types.ProviderResponse] + switch mediaType { + case mediaTypeJSON: + defer resp.Body.Close() + parsedResp := &jsontypes.ReadProvidersResponse{} + err = json.NewDecoder(resp.Body).Decode(parsedResp) + it = iter.FromSlice(parsedResp.Providers) + case mediaTypeNDJSON: + it = ndjson.NewReadProvidersResponseIter(resp.Body) + default: + defer resp.Body.Close() + logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType) + return nil, errors.New("unknown content type") + } + + return &measuringIter[types.ProviderResponse]{Iter: it, ctx: ctx, m: m}, nil } func (c *client) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) { @@ -202,7 +268,7 @@ func (c *client) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Du // ProvideAsync makes a provide request to a delegated router func (c *client) provideSignedBitswapRecord(ctx context.Context, bswp *types.WriteBitswapProviderRecord) (time.Duration, error) { - req := types.WriteProvidersRequest{Providers: []types.WriteProviderRecord{bswp}} + req := jsontypes.WriteProvidersRequest{Providers: []types.WriteProviderRecord{bswp}} url := c.baseURL + server.ProvidePath @@ -225,7 +291,7 @@ func (c *client) provideSignedBitswapRecord(ctx context.Context, bswp *types.Wri if resp.StatusCode != http.StatusOK { return 0, httpError(resp.StatusCode, resp.Body) } - var provideResult types.WriteProvidersResponse + var provideResult jsontypes.WriteProvidersResponse err = json.NewDecoder(resp.Body).Decode(&provideResult) if err != nil { return 0, err diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index d8fc4abac..036c99518 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -12,7 +12,9 @@ import ( "github.com/benbjohnson/clock" "github.com/ipfs/boxo/routing/http/server" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" @@ -25,9 +27,9 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) @@ -180,6 +182,9 @@ func TestClient_FindProviders(t *testing.T) { client := deps.client router := deps.router + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + if c.httpStatusCode != 0 { deps.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(c.httpStatusCode) @@ -191,10 +196,12 @@ func TestClient_FindProviders(t *testing.T) { } cid := makeCID() + findProvsIter := iter.FromSlice(c.routerProvs) + router.On("FindProviders", mock.Anything, cid). - Return(c.routerProvs, c.routerErr) + Return(findProvsIter, c.routerErr) - provs, err := client.FindProviders(context.Background(), cid) + provsIter, err := client.FindProviders(ctx, cid) var errList []string if runtime.GOOS == "windows" && len(c.expWinErrContains) != 0 { @@ -210,6 +217,9 @@ func TestClient_FindProviders(t *testing.T) { require.NoError(t, err) } + provs, err := iter.ReadAll(provsIter) + require.NoError(t, err) + assert.Equal(t, c.expProvs, provs) }) } @@ -274,7 +284,6 @@ func TestClient_Provide(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - // deps := makeTestDeps(t) deps := makeTestDeps(t) client := deps.client router := deps.router diff --git a/routing/http/client/measures.go b/routing/http/client/measures.go index 942460518..9cc1911b3 100644 --- a/routing/http/client/measures.go +++ b/routing/http/client/measures.go @@ -23,6 +23,7 @@ var ( keyHost = tag.MustNewKey("host") keyStatusCode = tag.MustNewKey("code") keyError = tag.MustNewKey("error") + keyMediaType = tag.MustNewKey("mediatype") ViewLatency = &view.View{ Measure: measureLatency, @@ -42,6 +43,7 @@ var ( ) type measurement struct { + mediaType string operation string err error latency time.Duration @@ -51,32 +53,22 @@ type measurement struct { } func (m measurement) record(ctx context.Context) { - stats.RecordWithTags( - ctx, - []tag.Mutator{ - tag.Upsert(keyHost, m.host), - tag.Upsert(keyOperation, m.operation), - tag.Upsert(keyStatusCode, strconv.Itoa(m.statusCode)), - tag.Upsert(keyError, metricsErrStr(m.err)), - }, - measureLatency.M(m.latency.Milliseconds()), - ) - if m.err == nil { - stats.RecordWithTags( - ctx, - []tag.Mutator{ - tag.Upsert(keyHost, m.host), - tag.Upsert(keyOperation, m.operation), - }, - measureLength.M(int64(m.length)), - ) + muts := []tag.Mutator{ + tag.Upsert(keyHost, m.host), + tag.Upsert(keyOperation, m.operation), + tag.Upsert(keyStatusCode, strconv.Itoa(m.statusCode)), + tag.Upsert(keyError, metricsErrStr(m.err)), + tag.Upsert(keyMediaType, m.mediaType), } + stats.RecordWithTags(ctx, muts, measureLatency.M(m.latency.Milliseconds())) + stats.RecordWithTags(ctx, muts, measureLength.M(int64(m.length))) } -func newMeasurement(operation string) measurement { - return measurement{ +func newMeasurement(operation string) *measurement { + return &measurement{ operation: operation, host: "None", + mediaType: "None", } } diff --git a/routing/http/contentrouter/contentrouter.go b/routing/http/contentrouter/contentrouter.go index 572ac2655..dc729ca4f 100644 --- a/routing/http/contentrouter/contentrouter.go +++ b/routing/http/contentrouter/contentrouter.go @@ -7,6 +7,7 @@ import ( "github.com/ipfs/boxo/routing/http/internal" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/peer" @@ -21,7 +22,7 @@ const ttl = 24 * time.Hour type Client interface { ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) - FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) + FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) } type contentRouter struct { @@ -101,24 +102,34 @@ func (c *contentRouter) Ready() bool { return true } -func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, numResults int) <-chan peer.AddrInfo { - results, err := c.client.FindProviders(ctx, key) - if err != nil { - logger.Warnw("error finding providers", "CID", key, "Error", err) - ch := make(chan peer.AddrInfo) - close(ch) - return ch - } - - ch := make(chan peer.AddrInfo, len(results)) - for _, r := range results { - if r.GetSchema() == types.SchemaBitswap { - result, ok := r.(*types.ReadBitswapProviderRecord) +// readProviderResponses reads bitswap records from the iterator into the given channel, dropping non-bitswap records. +func readProviderResponses(iter iter.Iter[types.ProviderResponse], ch chan<- peer.AddrInfo) { + defer close(ch) + var ( + v types.ProviderResponse + ok bool + err error + ) + for { + v, ok, err = iter.Next() + if err != nil { + logger.Warnw( + "error iterating provider responses", + "Schema", v.GetSchema(), + "Type", reflect.TypeOf(v).String(), + ) + continue + } + if !ok { + return + } + if v.GetSchema() == types.SchemaBitswap { + result, ok := v.(*types.ReadBitswapProviderRecord) if !ok { logger.Errorw( "problem casting find providers result", - "Schema", r.GetSchema(), - "Type", reflect.TypeOf(r).String(), + "Schema", v.GetSchema(), + "Type", reflect.TypeOf(v).String(), ) continue } @@ -133,8 +144,18 @@ func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, num Addrs: addrs, } } + } +} +func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, numResults int) <-chan peer.AddrInfo { + resultsIter, err := c.client.FindProviders(ctx, key) + if err != nil { + logger.Warnw("error finding providers", "CID", key, "Error", err) + ch := make(chan peer.AddrInfo) + close(ch) + return ch } - close(ch) + ch := make(chan peer.AddrInfo) + go readProviderResponses(resultsIter, ch) return ch } diff --git a/routing/http/contentrouter/contentrouter_test.go b/routing/http/contentrouter/contentrouter_test.go index 643ad301f..809532739 100644 --- a/routing/http/contentrouter/contentrouter_test.go +++ b/routing/http/contentrouter/contentrouter_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multihash" @@ -21,9 +22,9 @@ func (m *mockClient) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl tim args := m.Called(ctx, keys, ttl) return args.Get(0).(time.Duration), args.Error(1) } -func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { +func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) } func (m *mockClient) Ready(ctx context.Context) (bool, error) { args := m.Called(ctx) @@ -120,8 +121,9 @@ func TestFindProvidersAsync(t *testing.T) { Protocol: "UNKNOWN", }, } + aisIter := iter.FromSlice(ais) - client.On("FindProviders", ctx, key).Return(ais, nil) + client.On("FindProviders", ctx, key).Return(aisIter, nil) aiChan := crc.FindProvidersAsync(ctx, key, 2) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index ca650ff28..876a06196 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -7,12 +7,15 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "time" "github.com/gorilla/mux" "github.com/ipfs/boxo/routing/http/internal/drjson" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" + jsontypes "github.com/ipfs/boxo/routing/http/types/json" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" @@ -20,13 +23,23 @@ import ( logging "github.com/ipfs/go-log/v2" ) +const ( + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" +) + var logger = logging.Logger("service/server/delegatedrouting") const ProvidePath = "/routing/v1/providers/" const FindProvidersPath = "/routing/v1/providers/{cid}" +type FindProvidersAsyncResponse struct { + ProviderResponse types.ProviderResponse + Error error +} + type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) + FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -47,6 +60,13 @@ type WriteProvideRequest struct { type serverOption func(s *server) +// WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. +func WithStreamingResultsDisabled() serverOption { + return func(s *server) { + s.disableNDJSON = true + } +} + func Handler(svc ContentRouter, opts ...serverOption) http.Handler { server := &server{ svc: svc, @@ -64,11 +84,12 @@ func Handler(svc ContentRouter, opts ...serverOption) http.Handler { } type server struct { - svc ContentRouter + svc ContentRouter + disableNDJSON bool } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { - req := types.WriteProvidersRequest{} + req := jsontypes.WriteProvidersRequest{} err := json.NewDecoder(httpReq.Body).Decode(&req) _ = httpReq.Body.Close() if err != nil { @@ -76,7 +97,7 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { return } - resp := types.WriteProvidersResponse{} + resp := jsontypes.WriteProvidersResponse{} for i, prov := range req.Providers { switch v := prov.(type) { @@ -131,7 +152,7 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { return } } - writeResult(w, "Provide", resp) + writeJSONResult(w, "Provide", resp) } func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { @@ -142,17 +163,106 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse CID: %w", err)) return } - providers, err := s.svc.FindProviders(httpReq.Context(), cid) + + var handlerFunc func(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) + + var supportsNDJSON bool + var supportsJSON bool + accepts := httpReq.Header.Values("Accept") + if len(accepts) == 0 { + handlerFunc = s.findProvidersJSON + } else { + for _, accept := range accepts { + mediaType, _, err := mime.ParseMediaType(accept) + if err != nil { + writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) + return + } + + switch mediaType { + case mediaTypeJSON: + supportsJSON = true + case mediaTypeNDJSON: + supportsNDJSON = true + } + } + + if supportsNDJSON && !s.disableNDJSON { + handlerFunc = s.findProvidersNDJSON + } else if supportsJSON { + handlerFunc = s.findProvidersJSON + } else { + writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types")) + return + } + } + + provIter, err := s.svc.FindProviders(httpReq.Context(), cid) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return } - response := types.ReadProvidersResponse{Providers: providers} - writeResult(w, "FindProviders", response) + + handlerFunc(w, provIter) +} + +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) { + defer provIter.Close() + + var ( + providers []types.ProviderResponse + i int + ) + + for { + v, ok, err := provIter.Next() + if err != nil { + writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error on result %d: %w", i, err)) + } + if !ok { + break + } + providers = append(providers, v) + i++ + } + response := jsontypes.ReadProvidersResponse{Providers: providers} + writeJSONResult(w, "FindProviders", response) +} + +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) { + defer provIter.Close() + + w.Header().Set("Content-Type", mediaTypeNDJSON) + w.WriteHeader(http.StatusOK) + for { + v, ok, err := provIter.Next() + if err != nil { + logger.Errorw("FindProviders ndjson iterator error", "Error", err) + return + } + if !ok { + break + } + // don't use an encoder because we can't easily differentiate writer errors from encoding errors + b, err := drjson.MarshalJSONBytes(v) + if err != nil { + logger.Errorw("FindProviders ndjson marshal error", "Error", err) + return + } + + _, err = w.Write(b) + if err != nil { + logger.Warn("FindProviders ndjson write error", "Error", err) + return + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } } -func writeResult(w http.ResponseWriter, method string, val any) { - w.Header().Add("Content-Type", "application/json") +func writeJSONResult(w http.ResponseWriter, method string, val any) { + w.Header().Add("Content-Type", mediaTypeJSON) // keep the marshaling separate from the writing, so we can distinguish bugs (which surface as 500) // from transient network issues (which surface as transport errors) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 5609b2ed3..12c9c2fdb 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -19,25 +20,25 @@ func TestHeaders(t *testing.T) { t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - result := []types.ProviderResponse{ + results := iter.FromSlice([]types.ProviderResponse{ &types.ReadBitswapProviderRecord{ Protocol: "transport-bitswap", Schema: types.SchemaBitswap, - }, - } + }}, + ) c := "baeabep4vu3ceru7nerjjbk37sxb7wmftteve4hcosmyolsbsiubw2vr6pqzj6mw7kv6tbn6nqkkldnklbjgm5tzbi4hkpkled4xlcr7xz4bq" cb, err := cid.Decode(c) require.NoError(t, err) router.On("FindProviders", mock.Anything, cb). - Return(result, nil) + Return(results, nil) resp, err := http.Get(serverAddr + ProvidePath + c) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) header := resp.Header.Get("Content-Type") - require.Equal(t, "application/json", header) + require.Equal(t, mediaTypeJSON, header) resp, err = http.Get(serverAddr + ProvidePath + "BAD_CID") require.NoError(t, err) @@ -48,9 +49,9 @@ func TestHeaders(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go new file mode 100644 index 000000000..d4ad1f5f3 --- /dev/null +++ b/routing/http/types/iter/iter.go @@ -0,0 +1,30 @@ +package iter + +// Iter is an iterator of aribtrary values. +// Iterators are generally not goroutine-safe, to make them safe just read from them into a channel. +// For our use cases, these usually have a single reader. This motivates iterators instead of channels, +// since the overhead of goroutines+channels has a significant performance cost. +// Using an iterator, you can read results directly without necessarily involving the Go scheduler. +type Iter[T any] interface { + // Next returns the next element, true if an attempt was made to get the next element, and any error that occurred. + // You should generally check err before ok. + Next() (val T, ok bool, err error) + Close() error +} + +func ReadAll[T any](iter Iter[T]) ([]T, error) { + if iter == nil { + return nil, nil + } + var vs []T + for { + v, ok, err := iter.Next() + if err != nil { + return vs, err + } + if !ok { + return vs, nil + } + vs = append(vs, v) + } +} diff --git a/routing/http/types/iter/json.go b/routing/http/types/iter/json.go new file mode 100644 index 000000000..40e283f9a --- /dev/null +++ b/routing/http/types/iter/json.go @@ -0,0 +1,49 @@ +package iter + +import ( + "encoding/json" + "errors" + "fmt" + "io" +) + +// FromReaderJSON returns an iterator over the given reader that reads whitespace-delimited JSON values. +func FromReaderJSON[T any](r io.Reader) *JSONIter[T] { + return &JSONIter[T]{Decoder: json.NewDecoder(r), Reader: r} +} + +// JSONIter iterates over whitespace-delimited JSON values of a byte stream. +// This closes the reader if it is a closer, to faciliate easy reading of HTTP responses. +type JSONIter[T any] struct { + Decoder *json.Decoder + Reader io.Reader + + done bool +} + +func (j *JSONIter[T]) Next() (T, bool, error) { + var val T + + if j.done { + return val, false, nil + } + + err := j.Decoder.Decode(&val) + if errors.Is(err, io.EOF) { + return val, false, j.Close() + } + if err != nil { + j.Close() + return val, false, fmt.Errorf("json iterator: %w", err) + } + + return val, true, nil +} + +func (j *JSONIter[T]) Close() error { + j.done = true + if closer, ok := j.Reader.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/routing/http/types/iter/slice.go b/routing/http/types/iter/slice.go new file mode 100644 index 000000000..7d348e756 --- /dev/null +++ b/routing/http/types/iter/slice.go @@ -0,0 +1,23 @@ +package iter + +// FromSlice returns an iterator over the given slice. +func FromSlice[T any](s []T) *SliceIter[T] { + return &SliceIter[T]{Slice: s} +} + +type SliceIter[T any] struct { + Slice []T + i int +} + +func (s *SliceIter[T]) Next() (T, bool, error) { + var val T + if s.i >= len(s.Slice) { + return val, false, nil + } + val = s.Slice[s.i] + s.i++ + return val, true, nil +} + +func (s *SliceIter[T]) Close() error { return nil } diff --git a/routing/http/types/json/provider.go b/routing/http/types/json/provider.go new file mode 100644 index 000000000..7ad344565 --- /dev/null +++ b/routing/http/types/json/provider.go @@ -0,0 +1,116 @@ +package json + +import ( + "encoding/json" + + "github.com/ipfs/go-libipfs/routing/http/types" +) + +// ReadProvidersResponse is the result of a Provide request +type ReadProvidersResponse struct { + Providers []types.ProviderResponse +} + +func (r *ReadProvidersResponse) UnmarshalJSON(b []byte) error { + var tempFPR struct{ Providers []json.RawMessage } + err := json.Unmarshal(b, &tempFPR) + if err != nil { + return err + } + + for _, provBytes := range tempFPR.Providers { + var readProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &readProv) + if err != nil { + return err + } + + switch readProv.Schema { + case types.SchemaBitswap: + var prov types.ReadBitswapProviderRecord + err := json.Unmarshal(readProv.Bytes, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + default: + r.Providers = append(r.Providers, &readProv) + } + + } + return nil +} + +type WriteProvidersRequest struct { + Providers []types.WriteProviderRecord +} + +func (r *WriteProvidersRequest) UnmarshalJSON(b []byte) error { + type wpr struct{ Providers []json.RawMessage } + var tempWPR wpr + err := json.Unmarshal(b, &tempWPR) + if err != nil { + return err + } + + for _, provBytes := range tempWPR.Providers { + var rawProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &rawProv) + if err != nil { + return err + } + + switch rawProv.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecord + err := json.Unmarshal(rawProv.Bytes, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + default: + var prov types.UnknownProviderRecord + err := json.Unmarshal(b, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + } + } + return nil +} + +// WriteProvidersResponse is the result of a Provide operation +type WriteProvidersResponse struct { + ProvideResults []types.ProviderResponse +} + +func (r *WriteProvidersResponse) UnmarshalJSON(b []byte) error { + var tempWPR struct{ ProvideResults []json.RawMessage } + err := json.Unmarshal(b, &tempWPR) + if err != nil { + return err + } + + for _, provBytes := range tempWPR.ProvideResults { + var rawProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &rawProv) + if err != nil { + return err + } + + switch rawProv.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecordResponse + err := json.Unmarshal(rawProv.Bytes, &prov) + if err != nil { + return err + } + r.ProvideResults = append(r.ProvideResults, &prov) + default: + r.ProvideResults = append(r.ProvideResults, &rawProv) + } + } + + return nil +} diff --git a/routing/http/types/ndjson/provider.go b/routing/http/types/ndjson/provider.go new file mode 100644 index 000000000..7643f6360 --- /dev/null +++ b/routing/http/types/ndjson/provider.go @@ -0,0 +1,96 @@ +package ndjson + +import ( + "encoding/json" + "io" + + "github.com/ipfs/go-libipfs/routing/http/types" + "github.com/ipfs/go-libipfs/routing/http/types/iter" +) + +type readProvidersResponseIter struct { + iter.Iter[types.UnknownProviderRecord] +} + +func NewReadProvidersResponseIter(r io.Reader) *readProvidersResponseIter { + return &readProvidersResponseIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} +} + +func (p *readProvidersResponseIter) Next() (types.ProviderResponse, bool, error) { + v, ok, err := p.Iter.Next() + if err != nil { + return nil, false, err + } + if !ok { + return nil, false, nil + } + switch v.Schema { + case types.SchemaBitswap: + var prov types.ReadBitswapProviderRecord + err := json.Unmarshal(v.Bytes, &prov) + if err != nil { + return nil, false, err + } + return &prov, true, nil + default: + return &v, true, nil + } +} + +func NewWriteProvidersRequestIter(r io.Reader) *writeProvidersRequestIter { + return &writeProvidersRequestIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} +} + +type writeProvidersRequestIter struct { + iter.Iter[types.UnknownProviderRecord] +} + +func (p *writeProvidersRequestIter) Next() (types.WriteProviderRecord, bool, error) { + v, ok, err := p.Iter.Next() + if err != nil { + return nil, false, err + } + if !ok { + return nil, false, nil + } + switch v.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecord + err := json.Unmarshal(v.Bytes, &prov) + if err != nil { + return nil, false, err + } + return &prov, true, nil + default: + return &v, true, nil + } +} + +func NewWriteProvidersResponseIter(r io.Reader) *writeProvidersResponseIter { + return &writeProvidersResponseIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} +} + +type writeProvidersResponseIter struct { + iter.Iter[types.UnknownProviderRecord] +} + +func (p *writeProvidersResponseIter) Next() (types.ProviderResponse, bool, error) { + v, ok, err := p.Iter.Next() + if err != nil { + return nil, false, err + } + if !ok { + return nil, false, nil + } + switch v.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecordResponse + err := json.Unmarshal(v.Bytes, &prov) + if err != nil { + return nil, false, err + } + return &prov, true, nil + default: + return &v, true, nil + } +} diff --git a/routing/http/types/provider.go b/routing/http/types/provider.go index ef9e95ada..6e8e303f7 100644 --- a/routing/http/types/provider.go +++ b/routing/http/types/provider.go @@ -1,9 +1,5 @@ package types -import ( - "encoding/json" -) - // WriteProviderRecord is a type that enforces structs to imlement it to avoid confusion type WriteProviderRecord interface { IsWriteProviderRecord() @@ -14,129 +10,8 @@ type ReadProviderRecord interface { IsReadProviderRecord() } -type WriteProvidersRequest struct { - Providers []WriteProviderRecord -} - -func (r *WriteProvidersRequest) UnmarshalJSON(b []byte) error { - type wpr struct { - Providers []json.RawMessage - } - var tempWPR wpr - err := json.Unmarshal(b, &tempWPR) - if err != nil { - return err - } - - for _, provBytes := range tempWPR.Providers { - var rawProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &rawProv) - if err != nil { - return err - } - - switch rawProv.Schema { - case SchemaBitswap: - var prov WriteBitswapProviderRecord - err := json.Unmarshal(rawProv.Bytes, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - default: - var prov UnknownProviderRecord - err := json.Unmarshal(b, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - } - } - return nil -} - // ProviderResponse is implemented for any ProviderResponse. It needs to have a Protocol field. type ProviderResponse interface { GetProtocol() string GetSchema() string } - -// WriteProvidersResponse is the result of a Provide operation -type WriteProvidersResponse struct { - ProvideResults []ProviderResponse -} - -// rawWriteProvidersResponse is a helper struct to make possible to parse WriteProvidersResponse's -type rawWriteProvidersResponse struct { - ProvideResults []json.RawMessage -} - -func (r *WriteProvidersResponse) UnmarshalJSON(b []byte) error { - var tempWPR rawWriteProvidersResponse - err := json.Unmarshal(b, &tempWPR) - if err != nil { - return err - } - - for _, provBytes := range tempWPR.ProvideResults { - var rawProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &rawProv) - if err != nil { - return err - } - - switch rawProv.Schema { - case SchemaBitswap: - var prov WriteBitswapProviderRecordResponse - err := json.Unmarshal(rawProv.Bytes, &prov) - if err != nil { - return err - } - r.ProvideResults = append(r.ProvideResults, &prov) - default: - r.ProvideResults = append(r.ProvideResults, &rawProv) - } - } - - return nil -} - -// ReadProvidersResponse is the result of a Provide request -type ReadProvidersResponse struct { - Providers []ProviderResponse -} - -// rawReadProvidersResponse is a helper struct to make possible to parse ReadProvidersResponse's -type rawReadProvidersResponse struct { - Providers []json.RawMessage -} - -func (r *ReadProvidersResponse) UnmarshalJSON(b []byte) error { - var tempFPR rawReadProvidersResponse - err := json.Unmarshal(b, &tempFPR) - if err != nil { - return err - } - - for _, provBytes := range tempFPR.Providers { - var readProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &readProv) - if err != nil { - return err - } - - switch readProv.Schema { - case SchemaBitswap: - var prov ReadBitswapProviderRecord - err := json.Unmarshal(readProv.Bytes, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - default: - r.Providers = append(r.Providers, &readProv) - } - - } - return nil -} From dbee4946d3f9bf7d717a9b759fb1598849e820c5 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Wed, 8 Feb 2023 11:26:13 -0500 Subject: [PATCH 02/10] Change iter interface to not include errors Errors can just be packed inside of a struct, it over-complicates the interface to require errors as separate return values everywhere. --- routing/http/client/client.go | 35 ++++--- routing/http/client/client_test.go | 60 +++++++----- routing/http/contentrouter/contentrouter.go | 25 ++--- .../http/contentrouter/contentrouter_test.go | 7 +- routing/http/server/server.go | 63 ++++++------ routing/http/server/server_test.go | 11 ++- routing/http/types/iter/iter.go | 58 +++++++---- routing/http/types/iter/json.go | 27 ++++-- routing/http/types/iter/json_test.go | 73 ++++++++++++++ routing/http/types/iter/map.go | 44 +++++++++ routing/http/types/iter/map_test.go | 73 ++++++++++++++ routing/http/types/iter/slice.go | 18 ++-- routing/http/types/iter/slice_test.go | 37 +++++++ routing/http/types/jsonseq/provider.go | 36 +++++++ routing/http/types/ndjson/provider.go | 96 ------------------- 15 files changed, 436 insertions(+), 227 deletions(-) create mode 100644 routing/http/types/iter/json_test.go create mode 100644 routing/http/types/iter/map.go create mode 100644 routing/http/types/iter/map_test.go create mode 100644 routing/http/types/iter/slice_test.go create mode 100644 routing/http/types/jsonseq/provider.go delete mode 100644 routing/http/types/ndjson/provider.go diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 6fa7fa44b..6dfd9b48d 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "mime" "net/http" "strings" @@ -19,7 +20,7 @@ import ( "github.com/ipfs/boxo/routing/http/types" "github.com/ipfs/boxo/routing/http/types/iter" jsontypes "github.com/ipfs/boxo/routing/http/types/json" - "github.com/ipfs/boxo/routing/http/types/ndjson" + "github.com/ipfs/boxo/routing/http/types/jsonseq" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" record "github.com/libp2p/go-libp2p-record" @@ -34,8 +35,8 @@ var ( ) const ( - mediaTypeJSON = "application/json" - mediaTypeNDJSON = "application/x-ndjson" + mediaTypeJSON = "application/json" + mediaTypeJSONSeq = "application/json-seq" ) type client struct { @@ -107,7 +108,7 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { func WithStreamResultsRequired() option { return func(c *client) { - c.accepts = mediaTypeNDJSON + c.accepts = mediaTypeJSONSeq } } @@ -126,7 +127,7 @@ func New(baseURL string, opts ...option) (*client, error) { httpClient: defaultHTTPClient, validator: ipns.Validator{}, clock: clock.New(), - accepts: strings.Join([]string{mediaTypeNDJSON, mediaTypeJSON}, ","), + accepts: strings.Join([]string{mediaTypeJSONSeq, mediaTypeJSON}, ","), } for _, opt := range opts { @@ -148,17 +149,24 @@ type measuringIter[T any] struct { m *measurement } -func (c *measuringIter[T]) Next() (T, bool, error) { +func (c *measuringIter[T]) Next() bool { c.m.length++ return c.Iter.Next() } +func (c *measuringIter[T]) Val() T { + return c.Iter.Val() +} + func (c *measuringIter[T]) Close() error { c.m.record(c.ctx) - return c.Iter.Close() + if closer, ok := c.Iter.(io.Closer); ok { + return closer.Close() + } + return nil } -func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Iter[types.ProviderResponse], err error) { +func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.ResultIter[types.ProviderResponse], err error) { // TODO test measurements m := newMeasurement("FindProviders") @@ -205,22 +213,23 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Ite m.mediaType = mediaType - var it iter.Iter[types.ProviderResponse] + var it iter.ResultIter[types.ProviderResponse] switch mediaType { case mediaTypeJSON: defer resp.Body.Close() parsedResp := &jsontypes.ReadProvidersResponse{} err = json.NewDecoder(resp.Body).Decode(parsedResp) - it = iter.FromSlice(parsedResp.Providers) - case mediaTypeNDJSON: - it = ndjson.NewReadProvidersResponseIter(resp.Body) + var sliceIt iter.Iter[types.ProviderResponse] = iter.FromSlice(parsedResp.Providers) + it = iter.ToResultIter(sliceIt) + case mediaTypeJSONSeq: + it = jsonseq.NewReadProvidersResponseIter(resp.Body) default: defer resp.Body.Close() logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType) return nil, errors.New("unknown content type") } - return &measuringIter[types.ProviderResponse]{Iter: it, ctx: ctx, m: m}, nil + return &measuringIter[iter.Result[types.ProviderResponse]]{Iter: it, ctx: ctx, m: m}, nil } func (c *client) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) { diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 036c99518..802b78bb5 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,9 +27,9 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ClosingResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) @@ -144,20 +144,34 @@ func makeProviderAndIdentity() (peer.ID, []multiaddr.Multiaddr, crypto.PrivKey) return peerID, []multiaddr.Multiaddr{ma1, ma2}, priv } +type osErrContains struct { + expContains string + expContainsWin string +} + +func (e *osErrContains) errContains(t *testing.T, err error) { + if runtime.GOOS == "windows" && len(e.expContainsWin) != 0 { + assert.ErrorContains(t, err, e.expContainsWin) + } else { + assert.ErrorContains(t, err, e.expContains) + } +} + func TestClient_FindProviders(t *testing.T) { bsReadProvResp := makeBSReadProviderResp() - bitswapProvs := []types.ProviderResponse{&bsReadProvResp} + bitswapProvs := []iter.Result[types.ProviderResponse]{ + {Val: &bsReadProvResp}, + } cases := []struct { name string httpStatusCode int stopServer bool - routerProvs []types.ProviderResponse + routerProvs []iter.Result[types.ProviderResponse] routerErr error - expProvs []types.ProviderResponse - expErrContains []string - expWinErrContains []string + expProvs []iter.Result[types.ProviderResponse] + expErrContains []osErrContains }{ { name: "happy case", @@ -167,13 +181,15 @@ func TestClient_FindProviders(t *testing.T) { { name: "returns an error if there's a non-200 response", httpStatusCode: 500, - expErrContains: []string{"HTTP error with StatusCode=500: "}, + expErrContains: []osErrContains{{expContains: "HTTP error with StatusCode=500: "}}, }, { - name: "returns an error if the HTTP client returns a non-HTTP error", - stopServer: true, - expErrContains: []string{"connect: connection refused"}, - expWinErrContains: []string{"connectex: No connection could be made because the target machine actively refused it."}, + name: "returns an error if the HTTP client returns a non-HTTP error", + stopServer: true, + expErrContains: []osErrContains{{ + expContains: "connect: connection refused", + expContainsWin: "connectex: No connection could be made because the target machine actively refused it.", + }}, }, } for _, c := range cases { @@ -196,30 +212,22 @@ func TestClient_FindProviders(t *testing.T) { } cid := makeCID() - findProvsIter := iter.FromSlice(c.routerProvs) + sliceIter := iter.FromSlice(c.routerProvs) + findProvsIter := &iter.NoopClosingIter[iter.Result[types.ProviderResponse]]{Iter: sliceIter} router.On("FindProviders", mock.Anything, cid). Return(findProvsIter, c.routerErr) provsIter, err := client.FindProviders(ctx, cid) - var errList []string - if runtime.GOOS == "windows" && len(c.expWinErrContains) != 0 { - errList = c.expWinErrContains - } else { - errList = c.expErrContains - } - - for _, exp := range errList { - require.ErrorContains(t, err, exp) + for _, exp := range c.expErrContains { + exp.errContains(t, err) } - if len(errList) == 0 { + if len(c.expErrContains) == 0 { require.NoError(t, err) } - provs, err := iter.ReadAll(provsIter) - require.NoError(t, err) - + provs := iter.ReadAll[iter.Result[types.ProviderResponse]](provsIter) assert.Equal(t, c.expProvs, provs) }) } diff --git a/routing/http/contentrouter/contentrouter.go b/routing/http/contentrouter/contentrouter.go index dc729ca4f..2daee07f0 100644 --- a/routing/http/contentrouter/contentrouter.go +++ b/routing/http/contentrouter/contentrouter.go @@ -22,7 +22,7 @@ const ttl = 24 * time.Hour type Client interface { ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) - FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) } type contentRouter struct { @@ -103,26 +103,15 @@ func (c *contentRouter) Ready() bool { } // readProviderResponses reads bitswap records from the iterator into the given channel, dropping non-bitswap records. -func readProviderResponses(iter iter.Iter[types.ProviderResponse], ch chan<- peer.AddrInfo) { +func readProviderResponses(iter iter.ResultIter[types.ProviderResponse], ch chan<- peer.AddrInfo) { defer close(ch) - var ( - v types.ProviderResponse - ok bool - err error - ) - for { - v, ok, err = iter.Next() - if err != nil { - logger.Warnw( - "error iterating provider responses", - "Schema", v.GetSchema(), - "Type", reflect.TypeOf(v).String(), - ) + for iter.Next() { + res := iter.Val() + if res.Err != nil { + logger.Warnw("error iterating provider responses: %s", res.Err) continue } - if !ok { - return - } + v := res.Val if v.GetSchema() == types.SchemaBitswap { result, ok := v.(*types.ReadBitswapProviderRecord) if !ok { diff --git a/routing/http/contentrouter/contentrouter_test.go b/routing/http/contentrouter/contentrouter_test.go index 809532739..4ca620c5d 100644 --- a/routing/http/contentrouter/contentrouter_test.go +++ b/routing/http/contentrouter/contentrouter_test.go @@ -22,9 +22,10 @@ func (m *mockClient) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl tim args := m.Called(ctx, keys, ttl) return args.Get(0).(time.Duration), args.Error(1) } -func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { + +func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockClient) Ready(ctx context.Context) (bool, error) { args := m.Called(ctx) @@ -121,7 +122,7 @@ func TestFindProvidersAsync(t *testing.T) { Protocol: "UNKNOWN", }, } - aisIter := iter.FromSlice(ais) + aisIter := iter.ToResultIter[types.ProviderResponse](iter.FromSlice(ais)) client.On("FindProviders", ctx, key).Return(aisIter, nil) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 876a06196..08b0a1a62 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -24,8 +24,8 @@ import ( ) const ( - mediaTypeJSON = "application/json" - mediaTypeNDJSON = "application/x-ndjson" + mediaTypeJSON = "application/json" + mediaTypeJSONSeq = "application/json-seq" ) var logger = logging.Logger("service/server/delegatedrouting") @@ -39,7 +39,7 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -60,10 +60,10 @@ type WriteProvideRequest struct { type serverOption func(s *server) -// WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. +// WithStreamingResultsDisabled disables jsonseq responses, so that the server only supports JSON responses. func WithStreamingResultsDisabled() serverOption { return func(s *server) { - s.disableNDJSON = true + s.disableJSONSeq = true } } @@ -84,8 +84,8 @@ func Handler(svc ContentRouter, opts ...serverOption) http.Handler { } type server struct { - svc ContentRouter - disableNDJSON bool + svc ContentRouter + disableJSONSeq bool } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -164,9 +164,9 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { return } - var handlerFunc func(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) + var handlerFunc func(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) - var supportsNDJSON bool + var supportsJSONSeq bool var supportsJSON bool accepts := httpReq.Header.Values("Accept") if len(accepts) == 0 { @@ -182,13 +182,13 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { switch mediaType { case mediaTypeJSON: supportsJSON = true - case mediaTypeNDJSON: - supportsNDJSON = true + case mediaTypeJSONSeq: + supportsJSONSeq = true } } - if supportsNDJSON && !s.disableNDJSON { - handlerFunc = s.findProvidersNDJSON + if supportsJSONSeq && !s.disableJSONSeq { + handlerFunc = s.findProvidersJSONSeq } else if supportsJSON { handlerFunc = s.findProvidersJSON } else { @@ -206,7 +206,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { handlerFunc(w, provIter) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) { defer provIter.Close() var ( @@ -214,45 +214,40 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.Iter[typ i int ) - for { - v, ok, err := provIter.Next() - if err != nil { - writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error on result %d: %w", i, err)) - } - if !ok { - break + for provIter.Next() { + res := provIter.Val() + if res.Err != nil { + writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error on result %d: %w", i, res.Err)) + return } - providers = append(providers, v) + providers = append(providers, res.Val) i++ } response := jsontypes.ReadProvidersResponse{Providers: providers} writeJSONResult(w, "FindProviders", response) } -func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.Iter[types.ProviderResponse]) { +func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) { defer provIter.Close() - w.Header().Set("Content-Type", mediaTypeNDJSON) + w.Header().Set("Content-Type", mediaTypeJSONSeq) w.WriteHeader(http.StatusOK) - for { - v, ok, err := provIter.Next() - if err != nil { - logger.Errorw("FindProviders ndjson iterator error", "Error", err) + for provIter.Next() { + res := provIter.Val() + if res.Err != nil { + logger.Errorw("FindProviders jsonseq iterator error", "Error", res.Err) return } - if !ok { - break - } // don't use an encoder because we can't easily differentiate writer errors from encoding errors - b, err := drjson.MarshalJSONBytes(v) + b, err := drjson.MarshalJSONBytes(res.Val) if err != nil { - logger.Errorw("FindProviders ndjson marshal error", "Error", err) + logger.Errorw("FindProviders jsonseq marshal error", "Error", err) return } _, err = w.Write(b) if err != nil { - logger.Warn("FindProviders ndjson write error", "Error", err) + logger.Warn("FindProviders jsonseq write error", "Error", err) return } if f, ok := w.(http.Flusher); ok { diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 12c9c2fdb..520084f17 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -20,12 +20,13 @@ func TestHeaders(t *testing.T) { t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - results := iter.FromSlice([]types.ProviderResponse{ - &types.ReadBitswapProviderRecord{ + sliceIter := iter.FromSlice([]iter.Result[types.ProviderResponse]{ + {Val: &types.ReadBitswapProviderRecord{ Protocol: "transport-bitswap", Schema: types.SchemaBitswap, - }}, + }}}, ) + results := &iter.NoopClosingIter[iter.Result[types.ProviderResponse]]{Iter: sliceIter} c := "baeabep4vu3ceru7nerjjbk37sxb7wmftteve4hcosmyolsbsiubw2vr6pqzj6mw7kv6tbn6nqkkldnklbjgm5tzbi4hkpkled4xlcr7xz4bq" cb, err := cid.Decode(c) @@ -49,9 +50,9 @@ func TestHeaders(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.Iter[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.Iter[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ClosingResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index d4ad1f5f3..7d4b18c11 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -1,30 +1,56 @@ package iter -// Iter is an iterator of aribtrary values. +import "io" + +// Iter is an iterator of arbitrary values. // Iterators are generally not goroutine-safe, to make them safe just read from them into a channel. // For our use cases, these usually have a single reader. This motivates iterators instead of channels, // since the overhead of goroutines+channels has a significant performance cost. // Using an iterator, you can read results directly without necessarily involving the Go scheduler. type Iter[T any] interface { - // Next returns the next element, true if an attempt was made to get the next element, and any error that occurred. - // You should generally check err before ok. - Next() (val T, ok bool, err error) - Close() error + // Next sets the iterator to the next value, returning true if an attempt was made to get the next value. + Next() bool + Val() T +} + +type ResultIter[T any] interface { + Next() bool + Val() Result[T] +} + +type Result[T any] struct { + Val T + Err error +} + +// ToResultIter returns an iterator that wraps each value in a Result. +func ToResultIter[T any](iter Iter[T]) Iter[Result[T]] { + return Map(iter, func(t T) Result[T] { + return Result[T]{Val: t} + }) } -func ReadAll[T any](iter Iter[T]) ([]T, error) { +type ClosingIter[T any] interface { + Iter[T] + io.Closer +} + +type ClosingResultIter[T any] interface { + ResultIter[T] + io.Closer +} + +func ReadAll[T any](iter Iter[T]) []T { if iter == nil { - return nil, nil + return nil } var vs []T - for { - v, ok, err := iter.Next() - if err != nil { - return vs, err - } - if !ok { - return vs, nil - } - vs = append(vs, v) + for iter.Next() { + vs = append(vs, iter.Val()) } + return vs } + +type NoopClosingIter[T any] struct{ Iter[T] } + +func (n *NoopClosingIter[T]) Close() error { return nil } diff --git a/routing/http/types/iter/json.go b/routing/http/types/iter/json.go index 40e283f9a..428331e28 100644 --- a/routing/http/types/iter/json.go +++ b/routing/http/types/iter/json.go @@ -3,7 +3,6 @@ package iter import ( "encoding/json" "errors" - "fmt" "io" ) @@ -19,25 +18,37 @@ type JSONIter[T any] struct { Reader io.Reader done bool + res Result[T] } -func (j *JSONIter[T]) Next() (T, bool, error) { +func (j *JSONIter[T]) Next() bool { var val T if j.done { - return val, false, nil + return false } err := j.Decoder.Decode(&val) + + j.res.Val, j.res.Err = val, err + + // EOF is not an error, it just marks the end of iteration if errors.Is(err, io.EOF) { - return val, false, j.Close() + j.done = true + j.res.Err = nil + return false } - if err != nil { - j.Close() - return val, false, fmt.Errorf("json iterator: %w", err) + + // stop iterating on an error + if j.res.Err != nil { + j.done = true } - return val, true, nil + return true +} + +func (j *JSONIter[T]) Val() Result[T] { + return j.res } func (j *JSONIter[T]) Close() error { diff --git a/routing/http/types/iter/json_test.go b/routing/http/types/iter/json_test.go new file mode 100644 index 000000000..99c3bde07 --- /dev/null +++ b/routing/http/types/iter/json_test.go @@ -0,0 +1,73 @@ +package iter + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONIter(t *testing.T) { + type obj struct { + A int + } + + type expResult struct { + val obj + errContains string + } + + for _, c := range []struct { + jsonStr string + expResults []expResult + }{ + { + jsonStr: "{\"a\":1}\n{\"a\":2}", + expResults: []expResult{ + {val: obj{A: 1}}, + {val: obj{A: 2}}, + }, + }, + { + jsonStr: "", + expResults: nil, + }, + { + jsonStr: "\n", + expResults: nil, + }, + { + jsonStr: "{\"a\":1}{\"a\":2}", + expResults: []expResult{ + {val: obj{A: 1}}, + {val: obj{A: 2}}, + }, + }, + { + jsonStr: "{\"a\":1}{\"a\":asdf}", + expResults: []expResult{ + {val: obj{A: 1}}, + {errContains: "invalid character"}, + }, + }, + } { + t.Run(c.jsonStr, func(t *testing.T) { + reader := bytes.NewReader([]byte(c.jsonStr)) + iter := FromReaderJSON[obj](reader) + results := ReadAll[Result[obj]](iter) + + require.Len(t, results, len(c.expResults)) + for i, res := range results { + expRes := c.expResults[i] + if expRes.errContains != "" { + assert.ErrorContains(t, res.Err, expRes.errContains) + } else { + assert.NoError(t, res.Err) + assert.Equal(t, expRes.val, res.Val) + } + } + }) + } + +} diff --git a/routing/http/types/iter/map.go b/routing/http/types/iter/map.go new file mode 100644 index 000000000..ebaea12cd --- /dev/null +++ b/routing/http/types/iter/map.go @@ -0,0 +1,44 @@ +package iter + +import "io" + +// Map invokes f on each element of iter. +func Map[T any, U any](iter Iter[T], f func(t T) U) Iter[U] { + return &mapIter[T, U]{iter: iter, f: f} +} + +type mapIter[T any, U any] struct { + iter Iter[T] + f func(T) U + + done bool + val U +} + +func (m *mapIter[T, U]) Next() bool { + if m.done { + return false + } + + ok := m.iter.Next() + m.done = !ok + + if m.done { + return false + } + + m.val = m.f(m.iter.Val()) + + return true +} + +func (m *mapIter[T, U]) Val() U { + return m.val +} + +func (m *mapIter[T, U]) Close() error { + if closer, ok := m.iter.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/routing/http/types/iter/map_test.go b/routing/http/types/iter/map_test.go new file mode 100644 index 000000000..881679c7a --- /dev/null +++ b/routing/http/types/iter/map_test.go @@ -0,0 +1,73 @@ +package iter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// type nthErrIter[T any] struct { +// Iter[T] +// i int +// n int +// err error +// } + +// func (n *nthErrIter[T]) Next() (T, bool) { +// v, ok := n.Iter.Next() +// n.i++ +// return v, ok +// } +// func (n *nthErrIter[T]) Err() error { +// if n.i-1 == n.n { +// return n.err +// } +// return n.Iter.Err() +// } + +// func nthErr[T any](iter Iter[T], n int, err error) Iter[T] { +// return &nthErrIter[T]{Iter: iter, n: n, err: err} +// } + +func TestMap(t *testing.T) { + for _, c := range []struct { + input Iter[int] + f func(int) int + expResults []int + }{ + { + input: FromSlice([]int{1, 2, 3}), + f: func(i int) int { return i + 1 }, + expResults: []int{2, 3, 4}, + }, + // { + // input: FromSlice([]int{}), + // f: func(i int) int { return i + 1 }, + // expResults: []int{}, + // }, + // { + // input: FromSlice([]int{1}), + // f: func(i int) int { return i + 1 }, + // expResults: []int{2}, + // }, + // { + // input: FromSlice([]int{1, 2, 3}), 2, errors.New("boom"), + // f: func(i int) (int, error) { return i + 1, nil }, + // expResults: []result{ + // {val: 2}, + // {val: 3}, + // {errContains: "boom"}, + // }, + // }, + } { + t.Run(fmt.Sprintf("%v", c.input), func(t *testing.T) { + iter := Map(c.input, c.f) + var res []int + for iter.Next() { + res = append(res, iter.Val()) + } + assert.Equal(t, c.expResults, res) + }) + } +} diff --git a/routing/http/types/iter/slice.go b/routing/http/types/iter/slice.go index 7d348e756..ec3619b7d 100644 --- a/routing/http/types/iter/slice.go +++ b/routing/http/types/iter/slice.go @@ -2,22 +2,24 @@ package iter // FromSlice returns an iterator over the given slice. func FromSlice[T any](s []T) *SliceIter[T] { - return &SliceIter[T]{Slice: s} + return &SliceIter[T]{Slice: s, i: -1} } type SliceIter[T any] struct { Slice []T i int + val T } -func (s *SliceIter[T]) Next() (T, bool, error) { - var val T +func (s *SliceIter[T]) Next() bool { + s.i++ if s.i >= len(s.Slice) { - return val, false, nil + return false } - val = s.Slice[s.i] - s.i++ - return val, true, nil + s.val = s.Slice[s.i] + return true } -func (s *SliceIter[T]) Close() error { return nil } +func (s *SliceIter[T]) Val() T { + return s.val +} diff --git a/routing/http/types/iter/slice_test.go b/routing/http/types/iter/slice_test.go new file mode 100644 index 000000000..a3c8ff619 --- /dev/null +++ b/routing/http/types/iter/slice_test.go @@ -0,0 +1,37 @@ +package iter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSliceIter(t *testing.T) { + for _, c := range []struct { + slice []int + expSlice []int + }{ + { + slice: []int{1, 2, 3}, + expSlice: []int{1, 2, 3}, + }, + { + slice: nil, + expSlice: nil, + }, + { + slice: []int{}, + expSlice: nil, + }, + } { + t.Run(fmt.Sprintf("%+v", c.slice), func(t *testing.T) { + iter := FromSlice(c.slice) + var vals []int + for iter.Next() { + vals = append(vals, iter.Val()) + } + require.Equal(t, c.expSlice, vals) + }) + } +} diff --git a/routing/http/types/jsonseq/provider.go b/routing/http/types/jsonseq/provider.go new file mode 100644 index 000000000..480b1fdd5 --- /dev/null +++ b/routing/http/types/jsonseq/provider.go @@ -0,0 +1,36 @@ +package jsonseq + +import ( + "encoding/json" + "io" + + "github.com/ipfs/go-libipfs/routing/http/types" + "github.com/ipfs/go-libipfs/routing/http/types/iter" +) + +// NewReadProvidersResponseIter returns an iterator that reads Read Provider Records from the given reader. +func NewReadProvidersResponseIter(r io.Reader) iter.ResultIter[types.ProviderResponse] { + jsonIter := iter.FromReaderJSON[types.UnknownProviderRecord](r) + mapFn := func(upr iter.Result[types.UnknownProviderRecord]) iter.Result[types.ProviderResponse] { + var result iter.Result[types.ProviderResponse] + if upr.Err != nil { + result.Err = upr.Err + return result + } + switch upr.Val.Schema { + case types.SchemaBitswap: + var prov types.ReadBitswapProviderRecord + err := json.Unmarshal(upr.Val.Bytes, &prov) + if err != nil { + result.Err = err + return result + } + result.Val = &prov + default: + result.Val = &upr.Val + } + return result + } + + return iter.Map[iter.Result[types.UnknownProviderRecord]](jsonIter, mapFn) +} diff --git a/routing/http/types/ndjson/provider.go b/routing/http/types/ndjson/provider.go deleted file mode 100644 index 7643f6360..000000000 --- a/routing/http/types/ndjson/provider.go +++ /dev/null @@ -1,96 +0,0 @@ -package ndjson - -import ( - "encoding/json" - "io" - - "github.com/ipfs/go-libipfs/routing/http/types" - "github.com/ipfs/go-libipfs/routing/http/types/iter" -) - -type readProvidersResponseIter struct { - iter.Iter[types.UnknownProviderRecord] -} - -func NewReadProvidersResponseIter(r io.Reader) *readProvidersResponseIter { - return &readProvidersResponseIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} -} - -func (p *readProvidersResponseIter) Next() (types.ProviderResponse, bool, error) { - v, ok, err := p.Iter.Next() - if err != nil { - return nil, false, err - } - if !ok { - return nil, false, nil - } - switch v.Schema { - case types.SchemaBitswap: - var prov types.ReadBitswapProviderRecord - err := json.Unmarshal(v.Bytes, &prov) - if err != nil { - return nil, false, err - } - return &prov, true, nil - default: - return &v, true, nil - } -} - -func NewWriteProvidersRequestIter(r io.Reader) *writeProvidersRequestIter { - return &writeProvidersRequestIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} -} - -type writeProvidersRequestIter struct { - iter.Iter[types.UnknownProviderRecord] -} - -func (p *writeProvidersRequestIter) Next() (types.WriteProviderRecord, bool, error) { - v, ok, err := p.Iter.Next() - if err != nil { - return nil, false, err - } - if !ok { - return nil, false, nil - } - switch v.Schema { - case types.SchemaBitswap: - var prov types.WriteBitswapProviderRecord - err := json.Unmarshal(v.Bytes, &prov) - if err != nil { - return nil, false, err - } - return &prov, true, nil - default: - return &v, true, nil - } -} - -func NewWriteProvidersResponseIter(r io.Reader) *writeProvidersResponseIter { - return &writeProvidersResponseIter{Iter: iter.FromReaderJSON[types.UnknownProviderRecord](r)} -} - -type writeProvidersResponseIter struct { - iter.Iter[types.UnknownProviderRecord] -} - -func (p *writeProvidersResponseIter) Next() (types.ProviderResponse, bool, error) { - v, ok, err := p.Iter.Next() - if err != nil { - return nil, false, err - } - if !ok { - return nil, false, nil - } - switch v.Schema { - case types.SchemaBitswap: - var prov types.WriteBitswapProviderRecordResponse - err := json.Unmarshal(v.Bytes, &prov) - if err != nil { - return nil, false, err - } - return &prov, true, nil - default: - return &v, true, nil - } -} From 66b5712e55f53db2eebbad2dd0931b056db2af1a Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Thu, 9 Feb 2023 17:57:11 -0500 Subject: [PATCH 03/10] Unexport noop iter closer, rename iter closer, and add more docs --- routing/http/client/client_test.go | 6 +++--- routing/http/server/server.go | 8 ++++---- routing/http/server/server_test.go | 6 +++--- routing/http/types/iter/iter.go | 19 +++++++++++++++---- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 802b78bb5..c4f44a35c 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,9 +27,9 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.ClosingResultIter[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ResultIterCloser[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) @@ -213,7 +213,7 @@ func TestClient_FindProviders(t *testing.T) { cid := makeCID() sliceIter := iter.FromSlice(c.routerProvs) - findProvsIter := &iter.NoopClosingIter[iter.Result[types.ProviderResponse]]{Iter: sliceIter} + findProvsIter := iter.IterCloserNoop[iter.Result[types.ProviderResponse]](sliceIter) router.On("FindProviders", mock.Anything, cid). Return(findProvsIter, c.routerErr) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 08b0a1a62..770bb3169 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -39,7 +39,7 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -164,7 +164,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { return } - var handlerFunc func(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) + var handlerFunc func(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) var supportsJSONSeq bool var supportsJSON bool @@ -206,7 +206,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { handlerFunc(w, provIter) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) { defer provIter.Close() var ( @@ -227,7 +227,7 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ClosingR writeJSONResult(w, "FindProviders", response) } -func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ClosingResultIter[types.ProviderResponse]) { +func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) { defer provIter.Close() w.Header().Set("Content-Type", mediaTypeJSONSeq) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 520084f17..1d4a17d24 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -26,7 +26,7 @@ func TestHeaders(t *testing.T) { Schema: types.SchemaBitswap, }}}, ) - results := &iter.NoopClosingIter[iter.Result[types.ProviderResponse]]{Iter: sliceIter} + results := iter.IterCloserNoop[iter.Result[types.ProviderResponse]](sliceIter) c := "baeabep4vu3ceru7nerjjbk37sxb7wmftteve4hcosmyolsbsiubw2vr6pqzj6mw7kv6tbn6nqkkldnklbjgm5tzbi4hkpkled4xlcr7xz4bq" cb, err := cid.Decode(c) @@ -50,9 +50,9 @@ func TestHeaders(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ClosingResultIter[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.ClosingResultIter[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ResultIterCloser[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index 7d4b18c11..82403e757 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -7,6 +7,12 @@ import "io" // For our use cases, these usually have a single reader. This motivates iterators instead of channels, // since the overhead of goroutines+channels has a significant performance cost. // Using an iterator, you can read results directly without necessarily involving the Go scheduler. +// +// There are a lot of options for an iterator interface, this one was picked for ease-of-use +// and for highest probability of consumers using it correctly. +// E.g. because there is a separate method for the value, it's easier to use in a loop but harder to implement. +// +// Hopefully in the future, Go will include an iterator in the language and we can remove this. type Iter[T any] interface { // Next sets the iterator to the next value, returning true if an attempt was made to get the next value. Next() bool @@ -30,12 +36,12 @@ func ToResultIter[T any](iter Iter[T]) Iter[Result[T]] { }) } -type ClosingIter[T any] interface { +type IterCloser[T any] interface { Iter[T] io.Closer } -type ClosingResultIter[T any] interface { +type ResultIterCloser[T any] interface { ResultIter[T] io.Closer } @@ -51,6 +57,11 @@ func ReadAll[T any](iter Iter[T]) []T { return vs } -type NoopClosingIter[T any] struct{ Iter[T] } +// iterCloserNoop creates an io.Closer from an Iter that does nothing on close. +type iterCloserNoop[T any] struct{ Iter[T] } -func (n *NoopClosingIter[T]) Close() error { return nil } +func (n *iterCloserNoop[T]) Close() error { return nil } + +func IterCloserNoop[T any](it Iter[T]) IterCloser[T] { + return &iterCloserNoop[T]{it} +} From 9143620fbfeb8562e67de07445d8b57476a91479 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Fri, 10 Feb 2023 09:19:19 -0500 Subject: [PATCH 04/10] Add Close() to iter interface, remove Closer interfaces This is always needed, so we can simplify things by just including it in the main iterator interface. --- routing/http/client/client.go | 6 +-- routing/http/client/client_test.go | 7 ++- routing/http/contentrouter/contentrouter.go | 1 + routing/http/server/server.go | 8 ++-- routing/http/server/server_test.go | 7 ++- routing/http/types/iter/iter.go | 27 +---------- routing/http/types/iter/map.go | 19 +++----- routing/http/types/iter/map_test.go | 52 ++++----------------- routing/http/types/iter/slice.go | 4 ++ routing/http/types/jsonseq/provider.go | 2 +- 10 files changed, 36 insertions(+), 97 deletions(-) diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 6dfd9b48d..1258141e4 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "mime" "net/http" "strings" @@ -160,10 +159,7 @@ func (c *measuringIter[T]) Val() T { func (c *measuringIter[T]) Close() error { c.m.record(c.ctx) - if closer, ok := c.Iter.(io.Closer); ok { - return closer.Close() - } - return nil + return c.Iter.Close() } func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.ResultIter[types.ProviderResponse], err error) { diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index c4f44a35c..ca724b1c2 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -27,9 +27,9 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.ResultIterCloser[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) @@ -212,8 +212,7 @@ func TestClient_FindProviders(t *testing.T) { } cid := makeCID() - sliceIter := iter.FromSlice(c.routerProvs) - findProvsIter := iter.IterCloserNoop[iter.Result[types.ProviderResponse]](sliceIter) + findProvsIter := iter.FromSlice(c.routerProvs) router.On("FindProviders", mock.Anything, cid). Return(findProvsIter, c.routerErr) diff --git a/routing/http/contentrouter/contentrouter.go b/routing/http/contentrouter/contentrouter.go index 2daee07f0..8318a3163 100644 --- a/routing/http/contentrouter/contentrouter.go +++ b/routing/http/contentrouter/contentrouter.go @@ -105,6 +105,7 @@ func (c *contentRouter) Ready() bool { // readProviderResponses reads bitswap records from the iterator into the given channel, dropping non-bitswap records. func readProviderResponses(iter iter.ResultIter[types.ProviderResponse], ch chan<- peer.AddrInfo) { defer close(ch) + defer iter.Close() for iter.Next() { res := iter.Val() if res.Err != nil { diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 770bb3169..cef122f9e 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -39,7 +39,7 @@ type FindProvidersAsyncResponse struct { } type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -164,7 +164,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { return } - var handlerFunc func(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) + var handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) var supportsJSONSeq bool var supportsJSON bool @@ -206,7 +206,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { handlerFunc(w, provIter) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { defer provIter.Close() var ( @@ -227,7 +227,7 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIt writeJSONResult(w, "FindProviders", response) } -func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ResultIterCloser[types.ProviderResponse]) { +func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { defer provIter.Close() w.Header().Set("Content-Type", mediaTypeJSONSeq) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 1d4a17d24..ab96afe03 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -20,13 +20,12 @@ func TestHeaders(t *testing.T) { t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - sliceIter := iter.FromSlice([]iter.Result[types.ProviderResponse]{ + results := iter.FromSlice([]iter.Result[types.ProviderResponse]{ {Val: &types.ReadBitswapProviderRecord{ Protocol: "transport-bitswap", Schema: types.SchemaBitswap, }}}, ) - results := iter.IterCloserNoop[iter.Result[types.ProviderResponse]](sliceIter) c := "baeabep4vu3ceru7nerjjbk37sxb7wmftteve4hcosmyolsbsiubw2vr6pqzj6mw7kv6tbn6nqkkldnklbjgm5tzbi4hkpkled4xlcr7xz4bq" cb, err := cid.Decode(c) @@ -50,9 +49,9 @@ func TestHeaders(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIterCloser[types.ProviderResponse], error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).(iter.ResultIterCloser[types.ProviderResponse]), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index 82403e757..d313d4614 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -1,7 +1,5 @@ package iter -import "io" - // Iter is an iterator of arbitrary values. // Iterators are generally not goroutine-safe, to make them safe just read from them into a channel. // For our use cases, these usually have a single reader. This motivates iterators instead of channels, @@ -17,12 +15,10 @@ type Iter[T any] interface { // Next sets the iterator to the next value, returning true if an attempt was made to get the next value. Next() bool Val() T + Close() error } -type ResultIter[T any] interface { - Next() bool - Val() Result[T] -} +type ResultIter[T any] interface{ Iter[Result[T]] } type Result[T any] struct { Val T @@ -36,16 +32,6 @@ func ToResultIter[T any](iter Iter[T]) Iter[Result[T]] { }) } -type IterCloser[T any] interface { - Iter[T] - io.Closer -} - -type ResultIterCloser[T any] interface { - ResultIter[T] - io.Closer -} - func ReadAll[T any](iter Iter[T]) []T { if iter == nil { return nil @@ -56,12 +42,3 @@ func ReadAll[T any](iter Iter[T]) []T { } return vs } - -// iterCloserNoop creates an io.Closer from an Iter that does nothing on close. -type iterCloserNoop[T any] struct{ Iter[T] } - -func (n *iterCloserNoop[T]) Close() error { return nil } - -func IterCloserNoop[T any](it Iter[T]) IterCloser[T] { - return &iterCloserNoop[T]{it} -} diff --git a/routing/http/types/iter/map.go b/routing/http/types/iter/map.go index ebaea12cd..b7e636f3f 100644 --- a/routing/http/types/iter/map.go +++ b/routing/http/types/iter/map.go @@ -1,13 +1,11 @@ package iter -import "io" - // Map invokes f on each element of iter. -func Map[T any, U any](iter Iter[T], f func(t T) U) Iter[U] { - return &mapIter[T, U]{iter: iter, f: f} +func Map[T any, U any](iter Iter[T], f func(t T) U) *MapIter[T, U] { + return &MapIter[T, U]{iter: iter, f: f} } -type mapIter[T any, U any] struct { +type MapIter[T any, U any] struct { iter Iter[T] f func(T) U @@ -15,7 +13,7 @@ type mapIter[T any, U any] struct { val U } -func (m *mapIter[T, U]) Next() bool { +func (m *MapIter[T, U]) Next() bool { if m.done { return false } @@ -32,13 +30,10 @@ func (m *mapIter[T, U]) Next() bool { return true } -func (m *mapIter[T, U]) Val() U { +func (m *MapIter[T, U]) Val() U { return m.val } -func (m *mapIter[T, U]) Close() error { - if closer, ok := m.iter.(io.Closer); ok { - return closer.Close() - } - return nil +func (m *MapIter[T, U]) Close() error { + return m.iter.Close() } diff --git a/routing/http/types/iter/map_test.go b/routing/http/types/iter/map_test.go index 881679c7a..d0605df85 100644 --- a/routing/http/types/iter/map_test.go +++ b/routing/http/types/iter/map_test.go @@ -7,29 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -// type nthErrIter[T any] struct { -// Iter[T] -// i int -// n int -// err error -// } - -// func (n *nthErrIter[T]) Next() (T, bool) { -// v, ok := n.Iter.Next() -// n.i++ -// return v, ok -// } -// func (n *nthErrIter[T]) Err() error { -// if n.i-1 == n.n { -// return n.err -// } -// return n.Iter.Err() -// } - -// func nthErr[T any](iter Iter[T], n int, err error) Iter[T] { -// return &nthErrIter[T]{Iter: iter, n: n, err: err} -// } - func TestMap(t *testing.T) { for _, c := range []struct { input Iter[int] @@ -41,25 +18,16 @@ func TestMap(t *testing.T) { f: func(i int) int { return i + 1 }, expResults: []int{2, 3, 4}, }, - // { - // input: FromSlice([]int{}), - // f: func(i int) int { return i + 1 }, - // expResults: []int{}, - // }, - // { - // input: FromSlice([]int{1}), - // f: func(i int) int { return i + 1 }, - // expResults: []int{2}, - // }, - // { - // input: FromSlice([]int{1, 2, 3}), 2, errors.New("boom"), - // f: func(i int) (int, error) { return i + 1, nil }, - // expResults: []result{ - // {val: 2}, - // {val: 3}, - // {errContains: "boom"}, - // }, - // }, + { + input: FromSlice([]int{}), + f: func(i int) int { return i + 1 }, + expResults: nil, + }, + { + input: FromSlice([]int{1}), + f: func(i int) int { return i + 1 }, + expResults: []int{2}, + }, } { t.Run(fmt.Sprintf("%v", c.input), func(t *testing.T) { iter := Map(c.input, c.f) diff --git a/routing/http/types/iter/slice.go b/routing/http/types/iter/slice.go index ec3619b7d..fc3ecae15 100644 --- a/routing/http/types/iter/slice.go +++ b/routing/http/types/iter/slice.go @@ -23,3 +23,7 @@ func (s *SliceIter[T]) Next() bool { func (s *SliceIter[T]) Val() T { return s.val } + +func (s *SliceIter[T]) Close() error { + return nil +} diff --git a/routing/http/types/jsonseq/provider.go b/routing/http/types/jsonseq/provider.go index 480b1fdd5..19ece72ac 100644 --- a/routing/http/types/jsonseq/provider.go +++ b/routing/http/types/jsonseq/provider.go @@ -9,7 +9,7 @@ import ( ) // NewReadProvidersResponseIter returns an iterator that reads Read Provider Records from the given reader. -func NewReadProvidersResponseIter(r io.Reader) iter.ResultIter[types.ProviderResponse] { +func NewReadProvidersResponseIter(r io.Reader) iter.Iter[iter.Result[types.ProviderResponse]] { jsonIter := iter.FromReaderJSON[types.UnknownProviderRecord](r) mapFn := func(upr iter.Result[types.UnknownProviderRecord]) iter.Result[types.ProviderResponse] { var result iter.Result[types.ProviderResponse] From a457e86e0deebfb173f897e9c414fabd614f4535 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Sat, 11 Feb 2023 09:42:40 -0500 Subject: [PATCH 05/10] Switch from jsonseq to ndjson These are practically the same, but cid.contact has already started using ndjson so we're just going to use that instead. --- routing/http/client/client.go | 14 ++++---- routing/http/server/server.go | 32 +++++++++---------- .../types/{jsonseq => ndjson}/provider.go | 2 +- 3 files changed, 24 insertions(+), 24 deletions(-) rename routing/http/types/{jsonseq => ndjson}/provider.go (98%) diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 1258141e4..13b5df5c8 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -19,7 +19,7 @@ import ( "github.com/ipfs/boxo/routing/http/types" "github.com/ipfs/boxo/routing/http/types/iter" jsontypes "github.com/ipfs/boxo/routing/http/types/json" - "github.com/ipfs/boxo/routing/http/types/jsonseq" + "github.com/ipfs/boxo/routing/http/types/ndjson" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" record "github.com/libp2p/go-libp2p-record" @@ -34,8 +34,8 @@ var ( ) const ( - mediaTypeJSON = "application/json" - mediaTypeJSONSeq = "application/json-seq" + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" ) type client struct { @@ -107,7 +107,7 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { func WithStreamResultsRequired() option { return func(c *client) { - c.accepts = mediaTypeJSONSeq + c.accepts = mediaTypeNDJSON } } @@ -126,7 +126,7 @@ func New(baseURL string, opts ...option) (*client, error) { httpClient: defaultHTTPClient, validator: ipns.Validator{}, clock: clock.New(), - accepts: strings.Join([]string{mediaTypeJSONSeq, mediaTypeJSON}, ","), + accepts: strings.Join([]string{mediaTypeNDJSON, mediaTypeJSON}, ","), } for _, opt := range opts { @@ -217,8 +217,8 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res err = json.NewDecoder(resp.Body).Decode(parsedResp) var sliceIt iter.Iter[types.ProviderResponse] = iter.FromSlice(parsedResp.Providers) it = iter.ToResultIter(sliceIt) - case mediaTypeJSONSeq: - it = jsonseq.NewReadProvidersResponseIter(resp.Body) + case mediaTypeNDJSON: + it = ndjson.NewReadProvidersResponseIter(resp.Body) default: defer resp.Body.Close() logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index cef122f9e..a9051f8b4 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -24,8 +24,8 @@ import ( ) const ( - mediaTypeJSON = "application/json" - mediaTypeJSONSeq = "application/json-seq" + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" ) var logger = logging.Logger("service/server/delegatedrouting") @@ -60,10 +60,10 @@ type WriteProvideRequest struct { type serverOption func(s *server) -// WithStreamingResultsDisabled disables jsonseq responses, so that the server only supports JSON responses. +// WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. func WithStreamingResultsDisabled() serverOption { return func(s *server) { - s.disableJSONSeq = true + s.disableNDJSON = true } } @@ -84,8 +84,8 @@ func Handler(svc ContentRouter, opts ...serverOption) http.Handler { } type server struct { - svc ContentRouter - disableJSONSeq bool + svc ContentRouter + disableNDJSON bool } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -166,7 +166,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) - var supportsJSONSeq bool + var supportsNDJSON bool var supportsJSON bool accepts := httpReq.Header.Values("Accept") if len(accepts) == 0 { @@ -182,13 +182,13 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { switch mediaType { case mediaTypeJSON: supportsJSON = true - case mediaTypeJSONSeq: - supportsJSONSeq = true + case mediaTypeNDJSON: + supportsNDJSON = true } } - if supportsJSONSeq && !s.disableJSONSeq { - handlerFunc = s.findProvidersJSONSeq + if supportsNDJSON && !s.disableNDJSON { + handlerFunc = s.findProvidersNDJSON } else if supportsJSON { handlerFunc = s.findProvidersJSON } else { @@ -227,27 +227,27 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIt writeJSONResult(w, "FindProviders", response) } -func (s *server) findProvidersJSONSeq(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { defer provIter.Close() - w.Header().Set("Content-Type", mediaTypeJSONSeq) + w.Header().Set("Content-Type", mediaTypeNDJSON) w.WriteHeader(http.StatusOK) for provIter.Next() { res := provIter.Val() if res.Err != nil { - logger.Errorw("FindProviders jsonseq iterator error", "Error", res.Err) + logger.Errorw("FindProviders ndjson iterator error", "Error", res.Err) return } // don't use an encoder because we can't easily differentiate writer errors from encoding errors b, err := drjson.MarshalJSONBytes(res.Val) if err != nil { - logger.Errorw("FindProviders jsonseq marshal error", "Error", err) + logger.Errorw("FindProviders ndjson marshal error", "Error", err) return } _, err = w.Write(b) if err != nil { - logger.Warn("FindProviders jsonseq write error", "Error", err) + logger.Warn("FindProviders ndjson write error", "Error", err) return } if f, ok := w.(http.Flusher); ok { diff --git a/routing/http/types/jsonseq/provider.go b/routing/http/types/ndjson/provider.go similarity index 98% rename from routing/http/types/jsonseq/provider.go rename to routing/http/types/ndjson/provider.go index 19ece72ac..1bf758ad2 100644 --- a/routing/http/types/jsonseq/provider.go +++ b/routing/http/types/ndjson/provider.go @@ -1,4 +1,4 @@ -package jsonseq +package ndjson import ( "encoding/json" From 4f44e873d22c08fbc6d3fc2352b676dd14b4087a Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Mon, 13 Feb 2023 17:03:02 -0500 Subject: [PATCH 06/10] Allow */* media type, defaulting to application/json --- routing/http/server/server.go | 7 ++++--- routing/http/types/iter/iter.go | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index a9051f8b4..38b2ec48d 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -24,8 +24,9 @@ import ( ) const ( - mediaTypeJSON = "application/json" - mediaTypeNDJSON = "application/x-ndjson" + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" + mediaTypeWildcard = "*/*" ) var logger = logging.Logger("service/server/delegatedrouting") @@ -180,7 +181,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } switch mediaType { - case mediaTypeJSON: + case mediaTypeJSON, mediaTypeWildcard: supportsJSON = true case mediaTypeNDJSON: supportsNDJSON = true diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index d313d4614..37ca634b7 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -15,6 +15,7 @@ type Iter[T any] interface { // Next sets the iterator to the next value, returning true if an attempt was made to get the next value. Next() bool Val() T + // Close closes the iterator and any underlying resources. Failure to close an iterator may result in resource leakage (goroutines, FDs, conns, etc.). Close() error } From 51b9d17c219580c613e272d4a57eba10071dd635 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Wed, 15 Feb 2023 11:23:15 -0500 Subject: [PATCH 07/10] Add some more tests --- routing/http/client/client.go | 38 +++--- routing/http/client/client_test.go | 186 +++++++++++++++++++++-------- routing/http/server/server.go | 37 +++--- routing/http/server/server_test.go | 1 + routing/http/types/iter/iter.go | 1 + 5 files changed, 176 insertions(+), 87 deletions(-) diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 13b5df5c8..14aa7aaa3 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -29,8 +29,15 @@ import ( ) var ( - _ contentrouter.Client = &client{} - logger = logging.Logger("service/delegatedrouting") + _ contentrouter.Client = &client{} + logger = logging.Logger("service/delegatedrouting") + defaultHTTPClient = &http.Client{ + Transport: &ResponseBodyLimitedTransport{ + RoundTripper: http.DefaultTransport, + LimitBytes: 1 << 20, + UserAgent: defaultUserAgent, + }, + } ) const ( @@ -65,21 +72,21 @@ type httpClient interface { Do(req *http.Request) (*http.Response, error) } -type option func(*client) +type Option func(*client) -func WithIdentity(identity crypto.PrivKey) option { +func WithIdentity(identity crypto.PrivKey) Option { return func(c *client) { c.identity = identity } } -func WithHTTPClient(h httpClient) option { +func WithHTTPClient(h httpClient) Option { return func(c *client) { c.httpClient = h } } -func WithUserAgent(ua string) option { +func WithUserAgent(ua string) Option { return func(c *client) { if ua == "" { return @@ -96,7 +103,7 @@ func WithUserAgent(ua string) option { } } -func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { +func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) Option { return func(c *client) { c.peerID = peerID for _, a := range addrs { @@ -105,7 +112,7 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { } } -func WithStreamResultsRequired() option { +func WithStreamResultsRequired() Option { return func(c *client) { c.accepts = mediaTypeNDJSON } @@ -113,14 +120,7 @@ func WithStreamResultsRequired() option { // New creates a content routing API client. // The Provider and identity parameters are option. If they are nil, the `Provide` method will not function. -func New(baseURL string, opts ...option) (*client, error) { - defaultHTTPClient := &http.Client{ - Transport: &ResponseBodyLimitedTransport{ - RoundTripper: http.DefaultTransport, - LimitBytes: 1 << 20, - UserAgent: defaultUserAgent, - }, - } +func New(baseURL string, opts ...Option) (*client, error) { client := &client{ baseURL: baseURL, httpClient: defaultHTTPClient, @@ -171,6 +171,7 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res if err != nil { return nil, err } + req.Header.Set("Accept", c.accepts) m.host = req.Host @@ -189,13 +190,14 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res if resp.StatusCode == http.StatusNotFound { resp.Body.Close() m.record(ctx) - return nil, nil + return iter.FromSlice[iter.Result[types.ProviderResponse]](nil), nil } if resp.StatusCode != http.StatusOK { + err := httpError(resp.StatusCode, resp.Body) resp.Body.Close() m.record(ctx) - return nil, httpError(resp.StatusCode, resp.Body) + return nil, err } respContentType := resp.Header.Get("Content-Type") diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index ca724b1c2..05ad997af 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -42,45 +42,76 @@ func (m *mockContentRouter) Provide(ctx context.Context, req *server.WriteProvid } type testDeps struct { - router *mockContentRouter - server *httptest.Server - peerID peer.ID - addrs []multiaddr.Multiaddr - client *client + // recordingHandler records requests received on the server side + recordingHandler *recordingHandler + // recordingHTTPClient records responses received on the client side + recordingHTTPClient *recordingHTTPClient + router *mockContentRouter + server *httptest.Server + peerID peer.ID + addrs []multiaddr.Multiaddr + client *client } -func makeTestDeps(t *testing.T) testDeps { +type recordingHandler struct { + http.Handler + f []func(*http.Request) +} + +func (h *recordingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + for _, f := range h.f { + f(r) + } + h.Handler.ServeHTTP(w, r) +} + +type recordingHTTPClient struct { + httpClient + f []func(*http.Response) +} + +func (c *recordingHTTPClient) Do(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + for _, f := range c.f { + f(resp) + } + return resp, err +} + +func makeTestDeps(t *testing.T, clientsOpts []Option, serverOpts []server.Option) testDeps { const testUserAgent = "testUserAgent" peerID, addrs, identity := makeProviderAndIdentity() router := &mockContentRouter{} - server := httptest.NewServer(server.Handler(router)) + recordingHandler := &recordingHandler{ + Handler: server.Handler(router, serverOpts...), + f: []func(*http.Request){ + func(r *http.Request) { + assert.Equal(t, testUserAgent, r.Header.Get("User-Agent")) + }, + }, + } + server := httptest.NewServer(recordingHandler) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - c, err := New(serverAddr, WithProviderInfo(peerID, addrs), WithIdentity(identity), WithUserAgent(testUserAgent)) + recordingHTTPClient := &recordingHTTPClient{httpClient: defaultHTTPClient} + defaultClientOpts := []Option{ + WithProviderInfo(peerID, addrs), + WithIdentity(identity), + WithUserAgent(testUserAgent), + WithHTTPClient(recordingHTTPClient), + } + c, err := New(serverAddr, append(defaultClientOpts, clientsOpts...)...) if err != nil { panic(err) } - assertUserAgentOverride(t, c, testUserAgent) return testDeps{ - router: router, - server: server, - peerID: peerID, - addrs: addrs, - client: c, - } -} - -func assertUserAgentOverride(t *testing.T, c *client, expected string) { - httpClient, ok := c.httpClient.(*http.Client) - if !ok { - t.Error("invalid c.httpClient") - } - transport, ok := httpClient.Transport.(*ResponseBodyLimitedTransport) - if !ok { - t.Error("invalid httpClient.Transport") - } - if transport.UserAgent != expected { - t.Error("invalid httpClient.Transport.UserAgent") + recordingHandler: recordingHandler, + recordingHTTPClient: recordingHTTPClient, + router: router, + server: server, + peerID: peerID, + addrs: addrs, + client: c, } } @@ -150,6 +181,10 @@ type osErrContains struct { } func (e *osErrContains) errContains(t *testing.T, err error) { + if e.expContains == "" && e.expContainsWin == "" { + assert.NoError(t, err) + return + } if runtime.GOOS == "windows" && len(e.expContainsWin) != 0 { assert.ErrorContains(t, err, e.expContainsWin) } else { @@ -164,37 +199,90 @@ func TestClient_FindProviders(t *testing.T) { } cases := []struct { - name string - httpStatusCode int - stopServer bool - routerProvs []iter.Result[types.ProviderResponse] - routerErr error - - expProvs []iter.Result[types.ProviderResponse] - expErrContains []osErrContains + name string + httpStatusCode int + stopServer bool + routerProvs []iter.Result[types.ProviderResponse] + routerErr error + clientRequiresStreaming bool + serverStreamingDisabled bool + + expErrContains osErrContains + expProvs []iter.Result[types.ProviderResponse] + expStreamingResponse bool + expJSONResponse bool }{ { - name: "happy case", - routerProvs: bitswapProvs, - expProvs: bitswapProvs, + name: "happy case", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + expStreamingResponse: true, + }, + { + name: "server doesn't support streaming", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + serverStreamingDisabled: true, + expJSONResponse: true, + }, + { + name: "client requires streaming but server doesn't support it", + serverStreamingDisabled: true, + clientRequiresStreaming: true, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=400: no supported content types"}, }, { name: "returns an error if there's a non-200 response", httpStatusCode: 500, - expErrContains: []osErrContains{{expContains: "HTTP error with StatusCode=500: "}}, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=500"}, }, { name: "returns an error if the HTTP client returns a non-HTTP error", stopServer: true, - expErrContains: []osErrContains{{ + expErrContains: osErrContains{ expContains: "connect: connection refused", expContainsWin: "connectex: No connection could be made because the target machine actively refused it.", - }}, + }, + }, + { + name: "returns no providers if the HTTP server returns a 404 respones", + httpStatusCode: 404, + expProvs: nil, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - deps := makeTestDeps(t) + var clientOpts []Option + var serverOpts []server.Option + var onRespReceived []func(*http.Response) + var onReqReceived []func(*http.Request) + + if c.serverStreamingDisabled { + serverOpts = append(serverOpts, server.WithStreamingResultsDisabled()) + } + if c.clientRequiresStreaming { + clientOpts = append(clientOpts, WithStreamResultsRequired()) + onReqReceived = append(onReqReceived, func(r *http.Request) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Accept")) + }) + } + + if c.expStreamingResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Content-Type")) + }) + } + if c.expJSONResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeJSON, r.Header.Get("Content-Type")) + }) + } + + deps := makeTestDeps(t, clientOpts, serverOpts) + + deps.recordingHTTPClient.f = append(deps.recordingHTTPClient.f, onRespReceived...) + deps.recordingHandler.f = append(deps.recordingHandler.f, onReqReceived...) + client := deps.client router := deps.router @@ -219,12 +307,7 @@ func TestClient_FindProviders(t *testing.T) { provsIter, err := client.FindProviders(ctx, cid) - for _, exp := range c.expErrContains { - exp.errContains(t, err) - } - if len(c.expErrContains) == 0 { - require.NoError(t, err) - } + c.expErrContains.errContains(t, err) provs := iter.ReadAll[iter.Result[types.ProviderResponse]](provsIter) assert.Equal(t, c.expProvs, provs) @@ -264,8 +347,7 @@ func TestClient_Provide(t *testing.T) { name: "should return a 403 if the payload signature verification fails", cids: []cid.Cid{}, mangleSignature: true, - - expErrContains: "HTTP error with StatusCode=403", + expErrContains: "HTTP error with StatusCode=403", }, { name: "should return error if identity is not provided", @@ -291,7 +373,7 @@ func TestClient_Provide(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - deps := makeTestDeps(t) + deps := makeTestDeps(t, nil, nil) client := deps.client router := deps.router diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 38b2ec48d..8ce7d063b 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -9,6 +9,7 @@ import ( "io" "mime" "net/http" + "strings" "time" "github.com/gorilla/mux" @@ -59,16 +60,16 @@ type WriteProvideRequest struct { Bytes []byte } -type serverOption func(s *server) +type Option func(s *server) // WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. -func WithStreamingResultsDisabled() serverOption { +func WithStreamingResultsDisabled() Option { return func(s *server) { s.disableNDJSON = true } } -func Handler(svc ContentRouter, opts ...serverOption) http.Handler { +func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ svc: svc, } @@ -169,22 +170,24 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var supportsNDJSON bool var supportsJSON bool - accepts := httpReq.Header.Values("Accept") - if len(accepts) == 0 { + acceptHeaders := httpReq.Header.Values("Accept") + if len(acceptHeaders) == 0 { handlerFunc = s.findProvidersJSON } else { - for _, accept := range accepts { - mediaType, _, err := mime.ParseMediaType(accept) - if err != nil { - writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) - return - } - - switch mediaType { - case mediaTypeJSON, mediaTypeWildcard: - supportsJSON = true - case mediaTypeNDJSON: - supportsNDJSON = true + for _, acceptHeader := range acceptHeaders { + for _, accept := range strings.Split(acceptHeader, ",") { + mediaType, _, err := mime.ParseMediaType(accept) + if err != nil { + writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) + return + } + + switch mediaType { + case mediaTypeJSON, mediaTypeWildcard: + supportsJSON = true + case mediaTypeNDJSON: + supportsNDJSON = true + } } } diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index ab96afe03..fec5eaf9a 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -42,6 +42,7 @@ func TestHeaders(t *testing.T) { resp, err = http.Get(serverAddr + ProvidePath + "BAD_CID") require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, 400, resp.StatusCode) header = resp.Header.Get("Content-Type") require.Equal(t, "text/plain; charset=utf-8", header) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index 37ca634b7..67c6dde00 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -37,6 +37,7 @@ func ReadAll[T any](iter Iter[T]) []T { if iter == nil { return nil } + defer iter.Close() var vs []T for iter.Next() { vs = append(vs, iter.Val()) From 296fcefd0a90962f460bd8e808156d9c60a1f77f Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Tue, 21 Mar 2023 09:15:37 -0400 Subject: [PATCH 08/10] Close response body in defer to prevent mistakes --- routing/http/client/client.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 14aa7aaa3..b3a74150c 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -211,18 +211,24 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res m.mediaType = mediaType + var skipBodyClose bool + defer func() { + if !skipBodyClose { + resp.Body.Close() + } + }() + var it iter.ResultIter[types.ProviderResponse] switch mediaType { case mediaTypeJSON: - defer resp.Body.Close() parsedResp := &jsontypes.ReadProvidersResponse{} err = json.NewDecoder(resp.Body).Decode(parsedResp) var sliceIt iter.Iter[types.ProviderResponse] = iter.FromSlice(parsedResp.Providers) it = iter.ToResultIter(sliceIt) case mediaTypeNDJSON: + skipBodyClose = true it = ndjson.NewReadProvidersResponseIter(resp.Body) default: - defer resp.Body.Close() logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType) return nil, errors.New("unknown content type") } From c3656e729a6bfee7ab8a4a02385e8e1f483d933a Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Tue, 9 May 2023 14:55:48 -0400 Subject: [PATCH 09/10] fix imports in ndjson pkg --- routing/http/types/ndjson/provider.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/routing/http/types/ndjson/provider.go b/routing/http/types/ndjson/provider.go index 1bf758ad2..38e28df9a 100644 --- a/routing/http/types/ndjson/provider.go +++ b/routing/http/types/ndjson/provider.go @@ -4,8 +4,8 @@ import ( "encoding/json" "io" - "github.com/ipfs/go-libipfs/routing/http/types" - "github.com/ipfs/go-libipfs/routing/http/types/iter" + "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" ) // NewReadProvidersResponseIter returns an iterator that reads Read Provider Records from the given reader. From d21f0c87bacb4d65084e17986339bc0d43b2cfa9 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Tue, 9 May 2023 14:57:22 -0400 Subject: [PATCH 10/10] fix imports in routing/http/types/json --- routing/http/types/json/provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routing/http/types/json/provider.go b/routing/http/types/json/provider.go index 7ad344565..351197338 100644 --- a/routing/http/types/json/provider.go +++ b/routing/http/types/json/provider.go @@ -3,7 +3,7 @@ package json import ( "encoding/json" - "github.com/ipfs/go-libipfs/routing/http/types" + "github.com/ipfs/boxo/routing/http/types" ) // ReadProvidersResponse is the result of a Provide request