Skip to content

Commit

Permalink
Remove callbacks from fetch::poll
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Jan 17, 2024
1 parent 226f886 commit 6ab4d1a
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 507 deletions.
52 changes: 23 additions & 29 deletions fetch/mesh_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,69 +201,63 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error {
}
}

func (f *Fetch) GetMaliciousIDs(
ctx context.Context,
peers []p2p.Peer,
okCB func([]byte, p2p.Peer),
errCB func(error, p2p.Peer),
) error {
return poll(ctx, f.servers[malProtocol], peers, []byte{}, okCB, errCB)
func (f *Fetch) GetMaliciousIDs(ctx context.Context, peers []p2p.Peer) <-chan Result {
return poll(ctx, f.servers[malProtocol], peers, []byte{})
}

// GetLayerData get layer data from peers.
func (f *Fetch) GetLayerData(
ctx context.Context,
peers []p2p.Peer,
lid types.LayerID,
okCB func([]byte, p2p.Peer),
errCB func(error, p2p.Peer),
) error {
) (<-chan Result, error) {
lidBytes, err := codec.Encode(&lid)
if err != nil {
return err
return nil, err

Check warning on line 216 in fetch/mesh_data.go

View check run for this annotation

Codecov / codecov/patch

fetch/mesh_data.go#L216

Added line #L216 was not covered by tests
}
return poll(ctx, f.servers[lyrDataProtocol], peers, lidBytes, okCB, errCB)
return poll(ctx, f.servers[lyrDataProtocol], peers, lidBytes), nil
}

func (f *Fetch) GetLayerOpinions(
ctx context.Context,
peers []p2p.Peer,
lid types.LayerID,
okCB func([]byte, p2p.Peer),
errCB func(error, p2p.Peer),
) error {
) (<-chan Result, error) {

Check warning on line 225 in fetch/mesh_data.go

View check run for this annotation

Codecov / codecov/patch

fetch/mesh_data.go#L225

Added line #L225 was not covered by tests
req := OpinionRequest{
Layer: lid,
}
reqData, err := codec.Encode(&req)
if err != nil {
return err
return nil, err

Check warning on line 231 in fetch/mesh_data.go

View check run for this annotation

Codecov / codecov/patch

fetch/mesh_data.go#L231

Added line #L231 was not covered by tests
}
return poll(ctx, f.servers[OpnProtocol], peers, reqData, okCB, errCB)
return poll(ctx, f.servers[OpnProtocol], peers, reqData), nil

Check warning on line 233 in fetch/mesh_data.go

View check run for this annotation

Codecov / codecov/patch

fetch/mesh_data.go#L233

Added line #L233 was not covered by tests
}

type Result struct {
Data []byte
Peer p2p.Peer
Err error
}

func poll(
ctx context.Context,
srv requester,
peers []p2p.Peer,
req []byte,
okCB func([]byte, p2p.Peer),
errCB func(error, p2p.Peer),
) error {
var eg errgroup.Group
) <-chan Result {
result := make(chan Result, len(peers))
for _, p := range peers {
peer := p
eg.Go(func() error {
go func() {
data, err := srv.Request(ctx, peer, req)
if err != nil {
errCB(err, peer)
} else {
okCB(data, peer)
result <- Result{
Data: data,
Peer: peer,
Err: err,
}
return nil
})
}()
}
return nil
return result
}

// PeerEpochInfo get the epoch info published in the given epoch from the specified peer.
Expand Down
134 changes: 57 additions & 77 deletions fetch/mesh_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ package fetch
import (
"context"
"errors"
"os"
"sync"
"testing"

mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"

"github.com/spacemeshos/go-spacemesh/activation"
Expand All @@ -31,15 +30,6 @@ const (
txsForProposal = iota
)

const layersPerEpoch = 3

func TestMain(m *testing.M) {
types.SetLayersPerEpoch(layersPerEpoch)

res := m.Run()
os.Exit(res)
}

func (f *testFetch) withMethod(method int) *testFetch {
f.method = method
return f
Expand Down Expand Up @@ -157,7 +147,7 @@ func TestFetch_getHashes(t *testing.T) {
for _, peer := range peers {
f.peers.Add(peer)
}
f.mh.EXPECT().ID().Return(p2p.Peer("self")).AnyTimes()
f.mh.EXPECT().ID().Return("self").AnyTimes()
f.RegisterPeerHashes(peers[0], hashes[:2])
f.RegisterPeerHashes(peers[1], hashes[2:])

Expand Down Expand Up @@ -477,128 +467,118 @@ func TestGetPoetProof(t *testing.T) {
}

func TestFetch_GetMaliciousIDs(t *testing.T) {
peers := []p2p.Peer{"p0", "p1", "p3", "p4"}
errUnknown := errors.New("unknown")
tt := []struct {
name string
errs []error
name string
peers map[p2p.Peer]error
}{
{
name: "all peers returns",
errs: []error{nil, nil, nil, nil},
name: "all peers returns",
peers: map[p2p.Peer]error{"p0": nil, "p1": nil, "p2": nil, "p3": nil},
},
{
name: "some peers errors",
errs: []error{nil, errUnknown, nil, errUnknown},
name: "some peers errors",
peers: map[p2p.Peer]error{"p0": nil, "p1": errUnknown, "p2": nil, "p3": errUnknown},
},
}

for _, tc := range tt {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

require.Equal(t, len(peers), len(tc.errs))
f := createFetch(t)
oks := make(chan struct{}, len(peers))
errs := make(chan struct{}, len(peers))
var wg sync.WaitGroup
wg.Add(len(peers))
okFunc := func([]byte, p2p.Peer) {
oks <- struct{}{}
wg.Done()
}
errFunc := func(error, p2p.Peer) {
errs <- struct{}{}
wg.Done()
}

var expOk, expErr int
for i, p := range peers {
if tc.errs[i] == nil {
for peer, err := range tc.peers {
err := err
if err == nil {
expOk++
} else {
expErr++
}
idx := i
f.mMalS.EXPECT().
Request(gomock.Any(), p, []byte{}).
Request(gomock.Any(), peer, []byte{}).
DoAndReturn(
func(_ context.Context, _ p2p.Peer, _ []byte) ([]byte, error) {
if tc.errs[idx] == nil {
if err == nil {
return generateMaliciousIDs(t), nil
}
return nil, tc.errs[idx]
return nil, err
})
}
require.NoError(t, f.GetMaliciousIDs(context.Background(), peers, okFunc, errFunc))
wg.Wait()
require.Len(t, oks, expOk)
require.Len(t, errs, expErr)
resp := f.GetMaliciousIDs(context.Background(), maps.Keys(tc.peers))
var oks, errs int
for i := 0; i < len(tc.peers); i++ {
r := <-resp
require.ErrorIs(t, r.Err, tc.peers[r.Peer])
if r.Err != nil {
errs += 1
} else {
oks += 1
}

}
require.Equal(t, oks, expOk)
require.Equal(t, errs, expErr)
})
}
}

func TestFetch_GetLayerData(t *testing.T) {
peers := []p2p.Peer{"p0", "p1", "p3", "p4"}
errUnknown := errors.New("unknown")
tt := []struct {
name string
errs []error
name string
peers map[p2p.Peer]error
}{
{
name: "all peers returns",
errs: []error{nil, nil, nil, nil},
name: "all peers returns",
peers: map[p2p.Peer]error{"p0": nil, "p1": nil, "p2": nil, "p3": nil},
},
{
name: "some peers errors",
errs: []error{nil, errUnknown, nil, errUnknown},
name: "some peers errors",
peers: map[p2p.Peer]error{"p0": nil, "p1": errUnknown, "p2": nil, "p3": errUnknown},
},
}

for _, tc := range tt {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

require.Equal(t, len(peers), len(tc.errs))
f := createFetch(t)
oks := make(chan struct{}, len(peers))
errs := make(chan struct{}, len(peers))
var wg sync.WaitGroup
wg.Add(len(peers))
okFunc := func(data []byte, peer p2p.Peer) {
oks <- struct{}{}
wg.Done()
}
errFunc := func(err error, peer p2p.Peer) {
errs <- struct{}{}
wg.Done()
}

var expOk, expErr int
for i, p := range peers {
if tc.errs[i] == nil {
for peer, err := range tc.peers {
err := err
if err == nil {
expOk++
} else {
expErr++
}
idx := i
f.mLyrS.EXPECT().
Request(gomock.Any(), p, gomock.Any()).
Request(gomock.Any(), peer, gomock.Any()).
DoAndReturn(
func(_ context.Context, _ p2p.Peer, _ []byte) ([]byte, error) {
if err := tc.errs[idx]; err != nil {
if err != nil {
return nil, err
}
return generateLayerContent(t), nil
})
}
require.NoError(
t,
f.GetLayerData(context.Background(), peers, types.LayerID(111), okFunc, errFunc),
)
wg.Wait()
require.Len(t, oks, expOk)
require.Len(t, errs, expErr)
resp, err := f.GetLayerData(context.Background(), maps.Keys(tc.peers), types.LayerID(111))
require.NoError(t, err)
var oks, errs int
for i := 0; i < len(tc.peers); i++ {
r := <-resp
require.ErrorIs(t, r.Err, tc.peers[r.Peer])
if r.Err != nil {
errs += 1
} else {
oks += 1
}

}
require.Equal(t, oks, expOk)
require.Equal(t, errs, expErr)
})
}
}
Expand Down Expand Up @@ -635,7 +615,7 @@ func Test_PeerEpochInfo(t *testing.T) {
t.Parallel()

f := createFetch(t)
f.mh.EXPECT().ID().Return(p2p.Peer("self")).AnyTimes()
f.mh.EXPECT().ID().Return("self").AnyTimes()
var expected *EpochData
f.mAtxS.EXPECT().
Request(gomock.Any(), peer, gomock.Any()).
Expand Down
Loading

0 comments on commit 6ab4d1a

Please sign in to comment.