diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 4e5cfb8f4..b3a74150c 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -6,7 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "mime" "net/http" + "strings" "time" "github.com/benbjohnson/clock" @@ -15,6 +17,9 @@ import ( "github.com/ipfs/boxo/routing/http/internal/drjson" "github.com/ipfs/boxo/routing/http/server" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" + jsontypes "github.com/ipfs/boxo/routing/http/types/json" + "github.com/ipfs/boxo/routing/http/types/ndjson" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" record "github.com/libp2p/go-libp2p-record" @@ -23,7 +28,22 @@ import ( "github.com/multiformats/go-multiaddr" ) -var logger = logging.Logger("service/delegatedrouting") +var ( + _ contentrouter.Client = &client{} + logger = logging.Logger("service/delegatedrouting") + defaultHTTPClient = &http.Client{ + Transport: &ResponseBodyLimitedTransport{ + RoundTripper: http.DefaultTransport, + LimitBytes: 1 << 20, + UserAgent: defaultUserAgent, + }, + } +) + +const ( + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" +) type client struct { baseURL string @@ -31,6 +51,8 @@ type client struct { validator record.Validator clock clock.Clock + accepts string + peerID peer.ID addrs []types.Multiaddr identity crypto.PrivKey @@ -50,21 +72,21 @@ type httpClient interface { Do(req *http.Request) (*http.Response, error) } -type option func(*client) +type Option func(*client) -func WithIdentity(identity crypto.PrivKey) option { +func WithIdentity(identity crypto.PrivKey) Option { return func(c *client) { c.identity = identity } } -func WithHTTPClient(h httpClient) option { +func WithHTTPClient(h httpClient) Option { return func(c *client) { c.httpClient = h } } -func WithUserAgent(ua string) option { +func WithUserAgent(ua string) Option { return func(c *client) { if ua == "" { return @@ -81,7 +103,7 @@ func WithUserAgent(ua string) option { } } -func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { +func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) Option { return func(c *client) { c.peerID = peerID for _, a := range addrs { @@ -90,21 +112,21 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { } } +func WithStreamResultsRequired() Option { + return func(c *client) { + c.accepts = mediaTypeNDJSON + } +} + // New creates a content routing API client. // The Provider and identity parameters are option. If they are nil, the `Provide` method will not function. -func New(baseURL string, opts ...option) (*client, error) { - defaultHTTPClient := &http.Client{ - Transport: &ResponseBodyLimitedTransport{ - RoundTripper: http.DefaultTransport, - LimitBytes: 1 << 20, - UserAgent: defaultUserAgent, - }, - } +func New(baseURL string, opts ...Option) (*client, error) { client := &client{ baseURL: baseURL, httpClient: defaultHTTPClient, validator: ipns.Validator{}, clock: clock.New(), + accepts: strings.Join([]string{mediaTypeNDJSON, mediaTypeJSON}, ","), } for _, opt := range opts { @@ -118,43 +140,100 @@ func New(baseURL string, opts ...option) (*client, error) { return client, nil } -func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs []types.ProviderResponse, err error) { - measurement := newMeasurement("FindProviders") - defer func() { - measurement.length = len(provs) - measurement.record(ctx) - }() +// measuringIter measures the length of the iter and then publishes metrics about the whole req once the iter is closed. +// Of course, if the caller forgets to close the iter, this won't publish anything. +type measuringIter[T any] struct { + iter.Iter[T] + ctx context.Context + m *measurement +} + +func (c *measuringIter[T]) Next() bool { + c.m.length++ + return c.Iter.Next() +} + +func (c *measuringIter[T]) Val() T { + return c.Iter.Val() +} + +func (c *measuringIter[T]) Close() error { + c.m.record(c.ctx) + return c.Iter.Close() +} + +func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.ResultIter[types.ProviderResponse], err error) { + // TODO test measurements + m := newMeasurement("FindProviders") url := c.baseURL + server.ProvidePath + key.String() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } - measurement.host = req.Host + req.Header.Set("Accept", c.accepts) + + m.host = req.Host start := c.clock.Now() resp, err := c.httpClient.Do(req) - measurement.err = err - measurement.latency = c.clock.Since(start) + m.err = err + m.latency = c.clock.Since(start) if err != nil { + m.record(ctx) return nil, err } - defer resp.Body.Close() - measurement.statusCode = resp.StatusCode + m.statusCode = resp.StatusCode if resp.StatusCode == http.StatusNotFound { - return nil, nil + resp.Body.Close() + m.record(ctx) + return iter.FromSlice[iter.Result[types.ProviderResponse]](nil), nil } if resp.StatusCode != http.StatusOK { - return nil, httpError(resp.StatusCode, resp.Body) + err := httpError(resp.StatusCode, resp.Body) + resp.Body.Close() + m.record(ctx) + return nil, err + } + + respContentType := resp.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(respContentType) + if err != nil { + resp.Body.Close() + m.err = err + m.record(ctx) + return nil, fmt.Errorf("parsing Content-Type: %w", err) + } + + m.mediaType = mediaType + + var skipBodyClose bool + defer func() { + if !skipBodyClose { + resp.Body.Close() + } + }() + + var it iter.ResultIter[types.ProviderResponse] + switch mediaType { + case mediaTypeJSON: + parsedResp := &jsontypes.ReadProvidersResponse{} + err = json.NewDecoder(resp.Body).Decode(parsedResp) + var sliceIt iter.Iter[types.ProviderResponse] = iter.FromSlice(parsedResp.Providers) + it = iter.ToResultIter(sliceIt) + case mediaTypeNDJSON: + skipBodyClose = true + it = ndjson.NewReadProvidersResponseIter(resp.Body) + default: + logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType) + return nil, errors.New("unknown content type") } - parsedResp := &types.ReadProvidersResponse{} - err = json.NewDecoder(resp.Body).Decode(parsedResp) - return parsedResp.Providers, err + return &measuringIter[iter.Result[types.ProviderResponse]]{Iter: it, ctx: ctx, m: m}, nil } func (c *client) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) { @@ -202,7 +281,7 @@ func (c *client) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Du // ProvideAsync makes a provide request to a delegated router func (c *client) provideSignedBitswapRecord(ctx context.Context, bswp *types.WriteBitswapProviderRecord) (time.Duration, error) { - req := types.WriteProvidersRequest{Providers: []types.WriteProviderRecord{bswp}} + req := jsontypes.WriteProvidersRequest{Providers: []types.WriteProviderRecord{bswp}} url := c.baseURL + server.ProvidePath @@ -225,7 +304,7 @@ func (c *client) provideSignedBitswapRecord(ctx context.Context, bswp *types.Wri if resp.StatusCode != http.StatusOK { return 0, httpError(resp.StatusCode, resp.Body) } - var provideResult types.WriteProvidersResponse + var provideResult jsontypes.WriteProvidersResponse err = json.NewDecoder(resp.Body).Decode(&provideResult) if err != nil { return 0, err diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index d8fc4abac..05ad997af 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -12,7 +12,9 @@ import ( "github.com/benbjohnson/clock" "github.com/ipfs/boxo/routing/http/server" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" @@ -25,9 +27,9 @@ import ( type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) @@ -40,45 +42,76 @@ func (m *mockContentRouter) Provide(ctx context.Context, req *server.WriteProvid } type testDeps struct { - router *mockContentRouter - server *httptest.Server - peerID peer.ID - addrs []multiaddr.Multiaddr - client *client + // recordingHandler records requests received on the server side + recordingHandler *recordingHandler + // recordingHTTPClient records responses received on the client side + recordingHTTPClient *recordingHTTPClient + router *mockContentRouter + server *httptest.Server + peerID peer.ID + addrs []multiaddr.Multiaddr + client *client +} + +type recordingHandler struct { + http.Handler + f []func(*http.Request) +} + +func (h *recordingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + for _, f := range h.f { + f(r) + } + h.Handler.ServeHTTP(w, r) +} + +type recordingHTTPClient struct { + httpClient + f []func(*http.Response) +} + +func (c *recordingHTTPClient) Do(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + for _, f := range c.f { + f(resp) + } + return resp, err } -func makeTestDeps(t *testing.T) testDeps { +func makeTestDeps(t *testing.T, clientsOpts []Option, serverOpts []server.Option) testDeps { const testUserAgent = "testUserAgent" peerID, addrs, identity := makeProviderAndIdentity() router := &mockContentRouter{} - server := httptest.NewServer(server.Handler(router)) + recordingHandler := &recordingHandler{ + Handler: server.Handler(router, serverOpts...), + f: []func(*http.Request){ + func(r *http.Request) { + assert.Equal(t, testUserAgent, r.Header.Get("User-Agent")) + }, + }, + } + server := httptest.NewServer(recordingHandler) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - c, err := New(serverAddr, WithProviderInfo(peerID, addrs), WithIdentity(identity), WithUserAgent(testUserAgent)) + recordingHTTPClient := &recordingHTTPClient{httpClient: defaultHTTPClient} + defaultClientOpts := []Option{ + WithProviderInfo(peerID, addrs), + WithIdentity(identity), + WithUserAgent(testUserAgent), + WithHTTPClient(recordingHTTPClient), + } + c, err := New(serverAddr, append(defaultClientOpts, clientsOpts...)...) if err != nil { panic(err) } - assertUserAgentOverride(t, c, testUserAgent) return testDeps{ - router: router, - server: server, - peerID: peerID, - addrs: addrs, - client: c, - } -} - -func assertUserAgentOverride(t *testing.T, c *client, expected string) { - httpClient, ok := c.httpClient.(*http.Client) - if !ok { - t.Error("invalid c.httpClient") - } - transport, ok := httpClient.Transport.(*ResponseBodyLimitedTransport) - if !ok { - t.Error("invalid httpClient.Transport") - } - if transport.UserAgent != expected { - t.Error("invalid httpClient.Transport.UserAgent") + recordingHandler: recordingHandler, + recordingHTTPClient: recordingHTTPClient, + router: router, + server: server, + peerID: peerID, + addrs: addrs, + client: c, } } @@ -142,44 +175,120 @@ func makeProviderAndIdentity() (peer.ID, []multiaddr.Multiaddr, crypto.PrivKey) return peerID, []multiaddr.Multiaddr{ma1, ma2}, priv } +type osErrContains struct { + expContains string + expContainsWin string +} + +func (e *osErrContains) errContains(t *testing.T, err error) { + if e.expContains == "" && e.expContainsWin == "" { + assert.NoError(t, err) + return + } + if runtime.GOOS == "windows" && len(e.expContainsWin) != 0 { + assert.ErrorContains(t, err, e.expContainsWin) + } else { + assert.ErrorContains(t, err, e.expContains) + } +} + func TestClient_FindProviders(t *testing.T) { bsReadProvResp := makeBSReadProviderResp() - bitswapProvs := []types.ProviderResponse{&bsReadProvResp} + bitswapProvs := []iter.Result[types.ProviderResponse]{ + {Val: &bsReadProvResp}, + } cases := []struct { - name string - httpStatusCode int - stopServer bool - routerProvs []types.ProviderResponse - routerErr error - - expProvs []types.ProviderResponse - expErrContains []string - expWinErrContains []string + name string + httpStatusCode int + stopServer bool + routerProvs []iter.Result[types.ProviderResponse] + routerErr error + clientRequiresStreaming bool + serverStreamingDisabled bool + + expErrContains osErrContains + expProvs []iter.Result[types.ProviderResponse] + expStreamingResponse bool + expJSONResponse bool }{ { - name: "happy case", - routerProvs: bitswapProvs, - expProvs: bitswapProvs, + name: "happy case", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + expStreamingResponse: true, + }, + { + name: "server doesn't support streaming", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + serverStreamingDisabled: true, + expJSONResponse: true, + }, + { + name: "client requires streaming but server doesn't support it", + serverStreamingDisabled: true, + clientRequiresStreaming: true, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=400: no supported content types"}, }, { name: "returns an error if there's a non-200 response", httpStatusCode: 500, - expErrContains: []string{"HTTP error with StatusCode=500: "}, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=500"}, }, { - name: "returns an error if the HTTP client returns a non-HTTP error", - stopServer: true, - expErrContains: []string{"connect: connection refused"}, - expWinErrContains: []string{"connectex: No connection could be made because the target machine actively refused it."}, + name: "returns an error if the HTTP client returns a non-HTTP error", + stopServer: true, + expErrContains: osErrContains{ + expContains: "connect: connection refused", + expContainsWin: "connectex: No connection could be made because the target machine actively refused it.", + }, + }, + { + name: "returns no providers if the HTTP server returns a 404 respones", + httpStatusCode: 404, + expProvs: nil, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - deps := makeTestDeps(t) + var clientOpts []Option + var serverOpts []server.Option + var onRespReceived []func(*http.Response) + var onReqReceived []func(*http.Request) + + if c.serverStreamingDisabled { + serverOpts = append(serverOpts, server.WithStreamingResultsDisabled()) + } + if c.clientRequiresStreaming { + clientOpts = append(clientOpts, WithStreamResultsRequired()) + onReqReceived = append(onReqReceived, func(r *http.Request) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Accept")) + }) + } + + if c.expStreamingResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Content-Type")) + }) + } + if c.expJSONResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeJSON, r.Header.Get("Content-Type")) + }) + } + + deps := makeTestDeps(t, clientOpts, serverOpts) + + deps.recordingHTTPClient.f = append(deps.recordingHTTPClient.f, onRespReceived...) + deps.recordingHandler.f = append(deps.recordingHandler.f, onReqReceived...) + client := deps.client router := deps.router + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + if c.httpStatusCode != 0 { deps.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(c.httpStatusCode) @@ -191,25 +300,16 @@ func TestClient_FindProviders(t *testing.T) { } cid := makeCID() + findProvsIter := iter.FromSlice(c.routerProvs) + router.On("FindProviders", mock.Anything, cid). - Return(c.routerProvs, c.routerErr) + Return(findProvsIter, c.routerErr) - provs, err := client.FindProviders(context.Background(), cid) + provsIter, err := client.FindProviders(ctx, cid) - var errList []string - if runtime.GOOS == "windows" && len(c.expWinErrContains) != 0 { - errList = c.expWinErrContains - } else { - errList = c.expErrContains - } - - for _, exp := range errList { - require.ErrorContains(t, err, exp) - } - if len(errList) == 0 { - require.NoError(t, err) - } + c.expErrContains.errContains(t, err) + provs := iter.ReadAll[iter.Result[types.ProviderResponse]](provsIter) assert.Equal(t, c.expProvs, provs) }) } @@ -247,8 +347,7 @@ func TestClient_Provide(t *testing.T) { name: "should return a 403 if the payload signature verification fails", cids: []cid.Cid{}, mangleSignature: true, - - expErrContains: "HTTP error with StatusCode=403", + expErrContains: "HTTP error with StatusCode=403", }, { name: "should return error if identity is not provided", @@ -274,8 +373,7 @@ func TestClient_Provide(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - // deps := makeTestDeps(t) - deps := makeTestDeps(t) + deps := makeTestDeps(t, nil, nil) client := deps.client router := deps.router diff --git a/routing/http/client/measures.go b/routing/http/client/measures.go index 942460518..9cc1911b3 100644 --- a/routing/http/client/measures.go +++ b/routing/http/client/measures.go @@ -23,6 +23,7 @@ var ( keyHost = tag.MustNewKey("host") keyStatusCode = tag.MustNewKey("code") keyError = tag.MustNewKey("error") + keyMediaType = tag.MustNewKey("mediatype") ViewLatency = &view.View{ Measure: measureLatency, @@ -42,6 +43,7 @@ var ( ) type measurement struct { + mediaType string operation string err error latency time.Duration @@ -51,32 +53,22 @@ type measurement struct { } func (m measurement) record(ctx context.Context) { - stats.RecordWithTags( - ctx, - []tag.Mutator{ - tag.Upsert(keyHost, m.host), - tag.Upsert(keyOperation, m.operation), - tag.Upsert(keyStatusCode, strconv.Itoa(m.statusCode)), - tag.Upsert(keyError, metricsErrStr(m.err)), - }, - measureLatency.M(m.latency.Milliseconds()), - ) - if m.err == nil { - stats.RecordWithTags( - ctx, - []tag.Mutator{ - tag.Upsert(keyHost, m.host), - tag.Upsert(keyOperation, m.operation), - }, - measureLength.M(int64(m.length)), - ) + muts := []tag.Mutator{ + tag.Upsert(keyHost, m.host), + tag.Upsert(keyOperation, m.operation), + tag.Upsert(keyStatusCode, strconv.Itoa(m.statusCode)), + tag.Upsert(keyError, metricsErrStr(m.err)), + tag.Upsert(keyMediaType, m.mediaType), } + stats.RecordWithTags(ctx, muts, measureLatency.M(m.latency.Milliseconds())) + stats.RecordWithTags(ctx, muts, measureLength.M(int64(m.length))) } -func newMeasurement(operation string) measurement { - return measurement{ +func newMeasurement(operation string) *measurement { + return &measurement{ operation: operation, host: "None", + mediaType: "None", } } diff --git a/routing/http/contentrouter/contentrouter.go b/routing/http/contentrouter/contentrouter.go index 572ac2655..8318a3163 100644 --- a/routing/http/contentrouter/contentrouter.go +++ b/routing/http/contentrouter/contentrouter.go @@ -7,6 +7,7 @@ import ( "github.com/ipfs/boxo/routing/http/internal" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/peer" @@ -21,7 +22,7 @@ const ttl = 24 * time.Hour type Client interface { ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl time.Duration) (time.Duration, error) - FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) } type contentRouter struct { @@ -101,24 +102,24 @@ func (c *contentRouter) Ready() bool { return true } -func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, numResults int) <-chan peer.AddrInfo { - results, err := c.client.FindProviders(ctx, key) - if err != nil { - logger.Warnw("error finding providers", "CID", key, "Error", err) - ch := make(chan peer.AddrInfo) - close(ch) - return ch - } - - ch := make(chan peer.AddrInfo, len(results)) - for _, r := range results { - if r.GetSchema() == types.SchemaBitswap { - result, ok := r.(*types.ReadBitswapProviderRecord) +// readProviderResponses reads bitswap records from the iterator into the given channel, dropping non-bitswap records. +func readProviderResponses(iter iter.ResultIter[types.ProviderResponse], ch chan<- peer.AddrInfo) { + defer close(ch) + defer iter.Close() + for iter.Next() { + res := iter.Val() + if res.Err != nil { + logger.Warnw("error iterating provider responses: %s", res.Err) + continue + } + v := res.Val + if v.GetSchema() == types.SchemaBitswap { + result, ok := v.(*types.ReadBitswapProviderRecord) if !ok { logger.Errorw( "problem casting find providers result", - "Schema", r.GetSchema(), - "Type", reflect.TypeOf(r).String(), + "Schema", v.GetSchema(), + "Type", reflect.TypeOf(v).String(), ) continue } @@ -133,8 +134,18 @@ func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, num Addrs: addrs, } } + } +} +func (c *contentRouter) FindProvidersAsync(ctx context.Context, key cid.Cid, numResults int) <-chan peer.AddrInfo { + resultsIter, err := c.client.FindProviders(ctx, key) + if err != nil { + logger.Warnw("error finding providers", "CID", key, "Error", err) + ch := make(chan peer.AddrInfo) + close(ch) + return ch } - close(ch) + ch := make(chan peer.AddrInfo) + go readProviderResponses(resultsIter, ch) return ch } diff --git a/routing/http/contentrouter/contentrouter_test.go b/routing/http/contentrouter/contentrouter_test.go index 643ad301f..4ca620c5d 100644 --- a/routing/http/contentrouter/contentrouter_test.go +++ b/routing/http/contentrouter/contentrouter_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multihash" @@ -21,9 +22,10 @@ func (m *mockClient) ProvideBitswap(ctx context.Context, keys []cid.Cid, ttl tim args := m.Called(ctx, keys, ttl) return args.Get(0).(time.Duration), args.Error(1) } -func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { + +func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockClient) Ready(ctx context.Context) (bool, error) { args := m.Called(ctx) @@ -120,8 +122,9 @@ func TestFindProvidersAsync(t *testing.T) { Protocol: "UNKNOWN", }, } + aisIter := iter.ToResultIter[types.ProviderResponse](iter.FromSlice(ais)) - client.On("FindProviders", ctx, key).Return(ais, nil) + client.On("FindProviders", ctx, key).Return(aisIter, nil) aiChan := crc.FindProvidersAsync(ctx, key, 2) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index ca650ff28..8ce7d063b 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -7,12 +7,16 @@ import ( "errors" "fmt" "io" + "mime" "net/http" + "strings" "time" "github.com/gorilla/mux" "github.com/ipfs/boxo/routing/http/internal/drjson" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" + jsontypes "github.com/ipfs/boxo/routing/http/types/json" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" @@ -20,13 +24,24 @@ import ( logging "github.com/ipfs/go-log/v2" ) +const ( + mediaTypeJSON = "application/json" + mediaTypeNDJSON = "application/x-ndjson" + mediaTypeWildcard = "*/*" +) + var logger = logging.Logger("service/server/delegatedrouting") const ProvidePath = "/routing/v1/providers/" const FindProvidersPath = "/routing/v1/providers/{cid}" +type FindProvidersAsyncResponse struct { + ProviderResponse types.ProviderResponse + Error error +} + type ContentRouter interface { - FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) + FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error) } @@ -45,9 +60,16 @@ type WriteProvideRequest struct { Bytes []byte } -type serverOption func(s *server) +type Option func(s *server) + +// WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. +func WithStreamingResultsDisabled() Option { + return func(s *server) { + s.disableNDJSON = true + } +} -func Handler(svc ContentRouter, opts ...serverOption) http.Handler { +func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ svc: svc, } @@ -64,11 +86,12 @@ func Handler(svc ContentRouter, opts ...serverOption) http.Handler { } type server struct { - svc ContentRouter + svc ContentRouter + disableNDJSON bool } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { - req := types.WriteProvidersRequest{} + req := jsontypes.WriteProvidersRequest{} err := json.NewDecoder(httpReq.Body).Decode(&req) _ = httpReq.Body.Close() if err != nil { @@ -76,7 +99,7 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { return } - resp := types.WriteProvidersResponse{} + resp := jsontypes.WriteProvidersResponse{} for i, prov := range req.Providers { switch v := prov.(type) { @@ -131,7 +154,7 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { return } } - writeResult(w, "Provide", resp) + writeJSONResult(w, "Provide", resp) } func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { @@ -142,17 +165,103 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse CID: %w", err)) return } - providers, err := s.svc.FindProviders(httpReq.Context(), cid) + + var handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) + + var supportsNDJSON bool + var supportsJSON bool + acceptHeaders := httpReq.Header.Values("Accept") + if len(acceptHeaders) == 0 { + handlerFunc = s.findProvidersJSON + } else { + for _, acceptHeader := range acceptHeaders { + for _, accept := range strings.Split(acceptHeader, ",") { + mediaType, _, err := mime.ParseMediaType(accept) + if err != nil { + writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) + return + } + + switch mediaType { + case mediaTypeJSON, mediaTypeWildcard: + supportsJSON = true + case mediaTypeNDJSON: + supportsNDJSON = true + } + } + } + + if supportsNDJSON && !s.disableNDJSON { + handlerFunc = s.findProvidersNDJSON + } else if supportsJSON { + handlerFunc = s.findProvidersJSON + } else { + writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types")) + return + } + } + + provIter, err := s.svc.FindProviders(httpReq.Context(), cid) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return } - response := types.ReadProvidersResponse{Providers: providers} - writeResult(w, "FindProviders", response) + + handlerFunc(w, provIter) +} + +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { + defer provIter.Close() + + var ( + providers []types.ProviderResponse + i int + ) + + for provIter.Next() { + res := provIter.Val() + if res.Err != nil { + writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error on result %d: %w", i, res.Err)) + return + } + providers = append(providers, res.Val) + i++ + } + response := jsontypes.ReadProvidersResponse{Providers: providers} + writeJSONResult(w, "FindProviders", response) +} + +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.ProviderResponse]) { + defer provIter.Close() + + w.Header().Set("Content-Type", mediaTypeNDJSON) + w.WriteHeader(http.StatusOK) + for provIter.Next() { + res := provIter.Val() + if res.Err != nil { + logger.Errorw("FindProviders ndjson iterator error", "Error", res.Err) + return + } + // don't use an encoder because we can't easily differentiate writer errors from encoding errors + b, err := drjson.MarshalJSONBytes(res.Val) + if err != nil { + logger.Errorw("FindProviders ndjson marshal error", "Error", err) + return + } + + _, err = w.Write(b) + if err != nil { + logger.Warn("FindProviders ndjson write error", "Error", err) + return + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } } -func writeResult(w http.ResponseWriter, method string, val any) { - w.Header().Add("Content-Type", "application/json") +func writeJSONResult(w http.ResponseWriter, method string, val any) { + w.Header().Add("Content-Type", mediaTypeJSON) // keep the marshaling separate from the writing, so we can distinguish bugs (which surface as 500) // from transient network issues (which surface as transport errors) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 5609b2ed3..fec5eaf9a 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -19,28 +20,29 @@ func TestHeaders(t *testing.T) { t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - result := []types.ProviderResponse{ - &types.ReadBitswapProviderRecord{ + results := iter.FromSlice([]iter.Result[types.ProviderResponse]{ + {Val: &types.ReadBitswapProviderRecord{ Protocol: "transport-bitswap", Schema: types.SchemaBitswap, - }, - } + }}}, + ) c := "baeabep4vu3ceru7nerjjbk37sxb7wmftteve4hcosmyolsbsiubw2vr6pqzj6mw7kv6tbn6nqkkldnklbjgm5tzbi4hkpkled4xlcr7xz4bq" cb, err := cid.Decode(c) require.NoError(t, err) router.On("FindProviders", mock.Anything, cb). - Return(result, nil) + Return(results, nil) resp, err := http.Get(serverAddr + ProvidePath + c) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) header := resp.Header.Get("Content-Type") - require.Equal(t, "application/json", header) + require.Equal(t, mediaTypeJSON, header) resp, err = http.Get(serverAddr + ProvidePath + "BAD_CID") require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, 400, resp.StatusCode) header = resp.Header.Get("Content-Type") require.Equal(t, "text/plain; charset=utf-8", header) @@ -48,9 +50,9 @@ func TestHeaders(t *testing.T) { type mockContentRouter struct{ mock.Mock } -func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) ([]types.ProviderResponse, error) { +func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) { args := m.Called(ctx, key) - return args.Get(0).([]types.ProviderResponse), args.Error(1) + return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1) } func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) { args := m.Called(ctx, req) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go new file mode 100644 index 000000000..67c6dde00 --- /dev/null +++ b/routing/http/types/iter/iter.go @@ -0,0 +1,46 @@ +package iter + +// Iter is an iterator of arbitrary values. +// Iterators are generally not goroutine-safe, to make them safe just read from them into a channel. +// For our use cases, these usually have a single reader. This motivates iterators instead of channels, +// since the overhead of goroutines+channels has a significant performance cost. +// Using an iterator, you can read results directly without necessarily involving the Go scheduler. +// +// There are a lot of options for an iterator interface, this one was picked for ease-of-use +// and for highest probability of consumers using it correctly. +// E.g. because there is a separate method for the value, it's easier to use in a loop but harder to implement. +// +// Hopefully in the future, Go will include an iterator in the language and we can remove this. +type Iter[T any] interface { + // Next sets the iterator to the next value, returning true if an attempt was made to get the next value. + Next() bool + Val() T + // Close closes the iterator and any underlying resources. Failure to close an iterator may result in resource leakage (goroutines, FDs, conns, etc.). + Close() error +} + +type ResultIter[T any] interface{ Iter[Result[T]] } + +type Result[T any] struct { + Val T + Err error +} + +// ToResultIter returns an iterator that wraps each value in a Result. +func ToResultIter[T any](iter Iter[T]) Iter[Result[T]] { + return Map(iter, func(t T) Result[T] { + return Result[T]{Val: t} + }) +} + +func ReadAll[T any](iter Iter[T]) []T { + if iter == nil { + return nil + } + defer iter.Close() + var vs []T + for iter.Next() { + vs = append(vs, iter.Val()) + } + return vs +} diff --git a/routing/http/types/iter/json.go b/routing/http/types/iter/json.go new file mode 100644 index 000000000..428331e28 --- /dev/null +++ b/routing/http/types/iter/json.go @@ -0,0 +1,60 @@ +package iter + +import ( + "encoding/json" + "errors" + "io" +) + +// FromReaderJSON returns an iterator over the given reader that reads whitespace-delimited JSON values. +func FromReaderJSON[T any](r io.Reader) *JSONIter[T] { + return &JSONIter[T]{Decoder: json.NewDecoder(r), Reader: r} +} + +// JSONIter iterates over whitespace-delimited JSON values of a byte stream. +// This closes the reader if it is a closer, to faciliate easy reading of HTTP responses. +type JSONIter[T any] struct { + Decoder *json.Decoder + Reader io.Reader + + done bool + res Result[T] +} + +func (j *JSONIter[T]) Next() bool { + var val T + + if j.done { + return false + } + + err := j.Decoder.Decode(&val) + + j.res.Val, j.res.Err = val, err + + // EOF is not an error, it just marks the end of iteration + if errors.Is(err, io.EOF) { + j.done = true + j.res.Err = nil + return false + } + + // stop iterating on an error + if j.res.Err != nil { + j.done = true + } + + return true +} + +func (j *JSONIter[T]) Val() Result[T] { + return j.res +} + +func (j *JSONIter[T]) Close() error { + j.done = true + if closer, ok := j.Reader.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/routing/http/types/iter/json_test.go b/routing/http/types/iter/json_test.go new file mode 100644 index 000000000..99c3bde07 --- /dev/null +++ b/routing/http/types/iter/json_test.go @@ -0,0 +1,73 @@ +package iter + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONIter(t *testing.T) { + type obj struct { + A int + } + + type expResult struct { + val obj + errContains string + } + + for _, c := range []struct { + jsonStr string + expResults []expResult + }{ + { + jsonStr: "{\"a\":1}\n{\"a\":2}", + expResults: []expResult{ + {val: obj{A: 1}}, + {val: obj{A: 2}}, + }, + }, + { + jsonStr: "", + expResults: nil, + }, + { + jsonStr: "\n", + expResults: nil, + }, + { + jsonStr: "{\"a\":1}{\"a\":2}", + expResults: []expResult{ + {val: obj{A: 1}}, + {val: obj{A: 2}}, + }, + }, + { + jsonStr: "{\"a\":1}{\"a\":asdf}", + expResults: []expResult{ + {val: obj{A: 1}}, + {errContains: "invalid character"}, + }, + }, + } { + t.Run(c.jsonStr, func(t *testing.T) { + reader := bytes.NewReader([]byte(c.jsonStr)) + iter := FromReaderJSON[obj](reader) + results := ReadAll[Result[obj]](iter) + + require.Len(t, results, len(c.expResults)) + for i, res := range results { + expRes := c.expResults[i] + if expRes.errContains != "" { + assert.ErrorContains(t, res.Err, expRes.errContains) + } else { + assert.NoError(t, res.Err) + assert.Equal(t, expRes.val, res.Val) + } + } + }) + } + +} diff --git a/routing/http/types/iter/map.go b/routing/http/types/iter/map.go new file mode 100644 index 000000000..b7e636f3f --- /dev/null +++ b/routing/http/types/iter/map.go @@ -0,0 +1,39 @@ +package iter + +// Map invokes f on each element of iter. +func Map[T any, U any](iter Iter[T], f func(t T) U) *MapIter[T, U] { + return &MapIter[T, U]{iter: iter, f: f} +} + +type MapIter[T any, U any] struct { + iter Iter[T] + f func(T) U + + done bool + val U +} + +func (m *MapIter[T, U]) Next() bool { + if m.done { + return false + } + + ok := m.iter.Next() + m.done = !ok + + if m.done { + return false + } + + m.val = m.f(m.iter.Val()) + + return true +} + +func (m *MapIter[T, U]) Val() U { + return m.val +} + +func (m *MapIter[T, U]) Close() error { + return m.iter.Close() +} diff --git a/routing/http/types/iter/map_test.go b/routing/http/types/iter/map_test.go new file mode 100644 index 000000000..d0605df85 --- /dev/null +++ b/routing/http/types/iter/map_test.go @@ -0,0 +1,41 @@ +package iter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMap(t *testing.T) { + for _, c := range []struct { + input Iter[int] + f func(int) int + expResults []int + }{ + { + input: FromSlice([]int{1, 2, 3}), + f: func(i int) int { return i + 1 }, + expResults: []int{2, 3, 4}, + }, + { + input: FromSlice([]int{}), + f: func(i int) int { return i + 1 }, + expResults: nil, + }, + { + input: FromSlice([]int{1}), + f: func(i int) int { return i + 1 }, + expResults: []int{2}, + }, + } { + t.Run(fmt.Sprintf("%v", c.input), func(t *testing.T) { + iter := Map(c.input, c.f) + var res []int + for iter.Next() { + res = append(res, iter.Val()) + } + assert.Equal(t, c.expResults, res) + }) + } +} diff --git a/routing/http/types/iter/slice.go b/routing/http/types/iter/slice.go new file mode 100644 index 000000000..fc3ecae15 --- /dev/null +++ b/routing/http/types/iter/slice.go @@ -0,0 +1,29 @@ +package iter + +// FromSlice returns an iterator over the given slice. +func FromSlice[T any](s []T) *SliceIter[T] { + return &SliceIter[T]{Slice: s, i: -1} +} + +type SliceIter[T any] struct { + Slice []T + i int + val T +} + +func (s *SliceIter[T]) Next() bool { + s.i++ + if s.i >= len(s.Slice) { + return false + } + s.val = s.Slice[s.i] + return true +} + +func (s *SliceIter[T]) Val() T { + return s.val +} + +func (s *SliceIter[T]) Close() error { + return nil +} diff --git a/routing/http/types/iter/slice_test.go b/routing/http/types/iter/slice_test.go new file mode 100644 index 000000000..a3c8ff619 --- /dev/null +++ b/routing/http/types/iter/slice_test.go @@ -0,0 +1,37 @@ +package iter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSliceIter(t *testing.T) { + for _, c := range []struct { + slice []int + expSlice []int + }{ + { + slice: []int{1, 2, 3}, + expSlice: []int{1, 2, 3}, + }, + { + slice: nil, + expSlice: nil, + }, + { + slice: []int{}, + expSlice: nil, + }, + } { + t.Run(fmt.Sprintf("%+v", c.slice), func(t *testing.T) { + iter := FromSlice(c.slice) + var vals []int + for iter.Next() { + vals = append(vals, iter.Val()) + } + require.Equal(t, c.expSlice, vals) + }) + } +} diff --git a/routing/http/types/json/provider.go b/routing/http/types/json/provider.go new file mode 100644 index 000000000..351197338 --- /dev/null +++ b/routing/http/types/json/provider.go @@ -0,0 +1,116 @@ +package json + +import ( + "encoding/json" + + "github.com/ipfs/boxo/routing/http/types" +) + +// ReadProvidersResponse is the result of a Provide request +type ReadProvidersResponse struct { + Providers []types.ProviderResponse +} + +func (r *ReadProvidersResponse) UnmarshalJSON(b []byte) error { + var tempFPR struct{ Providers []json.RawMessage } + err := json.Unmarshal(b, &tempFPR) + if err != nil { + return err + } + + for _, provBytes := range tempFPR.Providers { + var readProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &readProv) + if err != nil { + return err + } + + switch readProv.Schema { + case types.SchemaBitswap: + var prov types.ReadBitswapProviderRecord + err := json.Unmarshal(readProv.Bytes, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + default: + r.Providers = append(r.Providers, &readProv) + } + + } + return nil +} + +type WriteProvidersRequest struct { + Providers []types.WriteProviderRecord +} + +func (r *WriteProvidersRequest) UnmarshalJSON(b []byte) error { + type wpr struct{ Providers []json.RawMessage } + var tempWPR wpr + err := json.Unmarshal(b, &tempWPR) + if err != nil { + return err + } + + for _, provBytes := range tempWPR.Providers { + var rawProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &rawProv) + if err != nil { + return err + } + + switch rawProv.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecord + err := json.Unmarshal(rawProv.Bytes, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + default: + var prov types.UnknownProviderRecord + err := json.Unmarshal(b, &prov) + if err != nil { + return err + } + r.Providers = append(r.Providers, &prov) + } + } + return nil +} + +// WriteProvidersResponse is the result of a Provide operation +type WriteProvidersResponse struct { + ProvideResults []types.ProviderResponse +} + +func (r *WriteProvidersResponse) UnmarshalJSON(b []byte) error { + var tempWPR struct{ ProvideResults []json.RawMessage } + err := json.Unmarshal(b, &tempWPR) + if err != nil { + return err + } + + for _, provBytes := range tempWPR.ProvideResults { + var rawProv types.UnknownProviderRecord + err := json.Unmarshal(provBytes, &rawProv) + if err != nil { + return err + } + + switch rawProv.Schema { + case types.SchemaBitswap: + var prov types.WriteBitswapProviderRecordResponse + err := json.Unmarshal(rawProv.Bytes, &prov) + if err != nil { + return err + } + r.ProvideResults = append(r.ProvideResults, &prov) + default: + r.ProvideResults = append(r.ProvideResults, &rawProv) + } + } + + return nil +} diff --git a/routing/http/types/ndjson/provider.go b/routing/http/types/ndjson/provider.go new file mode 100644 index 000000000..38e28df9a --- /dev/null +++ b/routing/http/types/ndjson/provider.go @@ -0,0 +1,36 @@ +package ndjson + +import ( + "encoding/json" + "io" + + "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" +) + +// NewReadProvidersResponseIter returns an iterator that reads Read Provider Records from the given reader. +func NewReadProvidersResponseIter(r io.Reader) iter.Iter[iter.Result[types.ProviderResponse]] { + jsonIter := iter.FromReaderJSON[types.UnknownProviderRecord](r) + mapFn := func(upr iter.Result[types.UnknownProviderRecord]) iter.Result[types.ProviderResponse] { + var result iter.Result[types.ProviderResponse] + if upr.Err != nil { + result.Err = upr.Err + return result + } + switch upr.Val.Schema { + case types.SchemaBitswap: + var prov types.ReadBitswapProviderRecord + err := json.Unmarshal(upr.Val.Bytes, &prov) + if err != nil { + result.Err = err + return result + } + result.Val = &prov + default: + result.Val = &upr.Val + } + return result + } + + return iter.Map[iter.Result[types.UnknownProviderRecord]](jsonIter, mapFn) +} diff --git a/routing/http/types/provider.go b/routing/http/types/provider.go index ef9e95ada..6e8e303f7 100644 --- a/routing/http/types/provider.go +++ b/routing/http/types/provider.go @@ -1,9 +1,5 @@ package types -import ( - "encoding/json" -) - // WriteProviderRecord is a type that enforces structs to imlement it to avoid confusion type WriteProviderRecord interface { IsWriteProviderRecord() @@ -14,129 +10,8 @@ type ReadProviderRecord interface { IsReadProviderRecord() } -type WriteProvidersRequest struct { - Providers []WriteProviderRecord -} - -func (r *WriteProvidersRequest) UnmarshalJSON(b []byte) error { - type wpr struct { - Providers []json.RawMessage - } - var tempWPR wpr - err := json.Unmarshal(b, &tempWPR) - if err != nil { - return err - } - - for _, provBytes := range tempWPR.Providers { - var rawProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &rawProv) - if err != nil { - return err - } - - switch rawProv.Schema { - case SchemaBitswap: - var prov WriteBitswapProviderRecord - err := json.Unmarshal(rawProv.Bytes, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - default: - var prov UnknownProviderRecord - err := json.Unmarshal(b, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - } - } - return nil -} - // ProviderResponse is implemented for any ProviderResponse. It needs to have a Protocol field. type ProviderResponse interface { GetProtocol() string GetSchema() string } - -// WriteProvidersResponse is the result of a Provide operation -type WriteProvidersResponse struct { - ProvideResults []ProviderResponse -} - -// rawWriteProvidersResponse is a helper struct to make possible to parse WriteProvidersResponse's -type rawWriteProvidersResponse struct { - ProvideResults []json.RawMessage -} - -func (r *WriteProvidersResponse) UnmarshalJSON(b []byte) error { - var tempWPR rawWriteProvidersResponse - err := json.Unmarshal(b, &tempWPR) - if err != nil { - return err - } - - for _, provBytes := range tempWPR.ProvideResults { - var rawProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &rawProv) - if err != nil { - return err - } - - switch rawProv.Schema { - case SchemaBitswap: - var prov WriteBitswapProviderRecordResponse - err := json.Unmarshal(rawProv.Bytes, &prov) - if err != nil { - return err - } - r.ProvideResults = append(r.ProvideResults, &prov) - default: - r.ProvideResults = append(r.ProvideResults, &rawProv) - } - } - - return nil -} - -// ReadProvidersResponse is the result of a Provide request -type ReadProvidersResponse struct { - Providers []ProviderResponse -} - -// rawReadProvidersResponse is a helper struct to make possible to parse ReadProvidersResponse's -type rawReadProvidersResponse struct { - Providers []json.RawMessage -} - -func (r *ReadProvidersResponse) UnmarshalJSON(b []byte) error { - var tempFPR rawReadProvidersResponse - err := json.Unmarshal(b, &tempFPR) - if err != nil { - return err - } - - for _, provBytes := range tempFPR.Providers { - var readProv UnknownProviderRecord - err := json.Unmarshal(provBytes, &readProv) - if err != nil { - return err - } - - switch readProv.Schema { - case SchemaBitswap: - var prov ReadBitswapProviderRecord - err := json.Unmarshal(readProv.Bytes, &prov) - if err != nil { - return err - } - r.Providers = append(r.Providers, &prov) - default: - r.Providers = append(r.Providers, &readProv) - } - - } - return nil -}