From 4a24cee3645334f5d258b35ac9c9cb6d86f6ca12 Mon Sep 17 00:00:00 2001 From: LinZexiao <55120714+LinZexiao@users.noreply.github.com> Date: Wed, 26 Oct 2022 16:51:15 +0800 Subject: [PATCH] feat: divide unit test for models layer (#225) --- models/badger/cid_info_test.go | 100 +++-- models/badger/fund_test.go | 55 ++- models/badger/paych_test.go | 282 +++++++++----- models/badger/retrieval_ask_test.go | 60 +-- models/badger/retrieval_deal_test.go | 126 ++++--- models/badger/storage_ask_test.go | 65 ++-- models/badger/storage_deal_test.go | 526 ++++++++++++++++++--------- models/cid_info_test.go | 22 +- models/fund_test.go | 16 +- models/mysql/cid_info_test.go | 31 +- models/mysql/fund_test.go | 55 +-- models/mysql/paych_test.go | 86 +++-- models/mysql/retrieval_ask_test.go | 29 +- models/mysql/retrieval_deal_test.go | 53 +-- models/mysql/storage_ask_test.go | 29 +- models/mysql/storage_deal_test.go | 217 ++++++----- models/mysql/testing.go | 6 - models/paych_test.go | 24 +- models/retrieval_ask_test.go | 22 +- models/storage_ask_test.go | 21 +- models/storage_deal_test.go | 32 +- 21 files changed, 1171 insertions(+), 686 deletions(-) diff --git a/models/badger/cid_info_test.go b/models/badger/cid_info_test.go index e519b1de..32f02ca0 100644 --- a/models/badger/cid_info_test.go +++ b/models/badger/cid_info_test.go @@ -5,54 +5,88 @@ import ( "testing" "github.com/filecoin-project/go-fil-markets/piecestore" + "github.com/filecoin-project/venus-market/v2/models/repo" "github.com/filecoin-project/venus/venus-shared/testutil" "github.com/ipfs/go-cid" "github.com/stretchr/testify/assert" ) -func TestCidInfo(t *testing.T) { - ctx := context.Background() +func TestAddPieceBlockLocations(t *testing.T) { + ctx, r, cidInfoCases := prepareCidInfoTest(t) + + pieceCid2cidInfo := make(map[cid.Cid][]piecestore.CIDInfo) + for _, info := range cidInfoCases { + for _, location := range info.PieceBlockLocations { + if _, ok := pieceCid2cidInfo[location.PieceCID]; !ok { + pieceCid2cidInfo[location.PieceCID] = make([]piecestore.CIDInfo, 0) + } + pieceCid2cidInfo[location.PieceCID] = append(pieceCid2cidInfo[location.PieceCID], info) + } + } + + for pieceCid, cidInfo := range pieceCid2cidInfo { + playloadCid2location := make(map[cid.Cid]piecestore.BlockLocation) + for _, info := range cidInfo { + for _, location := range info.PieceBlockLocations { + playloadCid2location[info.CID] = location.BlockLocation + } + } + err := r.AddPieceBlockLocations(ctx, pieceCid, playloadCid2location) + assert.NoError(t, err) + } +} + +func TestGetCIDInfo(t *testing.T) { + ctx, r, cidInfoCases := prepareCidInfoTest(t) + + inSertCidInfo(ctx, t, r, cidInfoCases[0]) + res, err := r.GetCIDInfo(ctx, cidInfoCases[0].CID) + assert.NoError(t, err) + assert.Equal(t, cidInfoCases[0], res) +} + +func TestListCidInfoKeys(t *testing.T) { + ctx, r, cidInfoCases := prepareCidInfoTest(t) + + inSertCidInfo(ctx, t, r, cidInfoCases...) + + cidInfos, err := r.ListCidInfoKeys(ctx) + assert.NoError(t, err) + assert.Equal(t, len(cidInfoCases), len(cidInfos)) + for _, info := range cidInfoCases { + assert.Contains(t, cidInfos, info.CID) + } +} + +func prepareCidInfoTest(t *testing.T) (context.Context, repo.ICidInfoRepo, []piecestore.CIDInfo) { repo := setup(t) r := repo.CidInfoRepo() cidInfoCases := make([]piecestore.CIDInfo, 10) testutil.Provide(t, &cidInfoCases) - t.Run("AddPieceBlockLocations", func(t *testing.T) { - pieceCid2cidInfo := make(map[cid.Cid][]piecestore.CIDInfo) - for _, info := range cidInfoCases { - for _, location := range info.PieceBlockLocations { - if _, ok := pieceCid2cidInfo[location.PieceCID]; !ok { - pieceCid2cidInfo[location.PieceCID] = make([]piecestore.CIDInfo, 0) - } - pieceCid2cidInfo[location.PieceCID] = append(pieceCid2cidInfo[location.PieceCID], info) + return context.Background(), r, cidInfoCases +} + +func inSertCidInfo(ctx context.Context, t *testing.T, r repo.ICidInfoRepo, cidInfoCases ...piecestore.CIDInfo) { + pieceCid2cidInfo := make(map[cid.Cid][]piecestore.CIDInfo) + for _, info := range cidInfoCases { + for _, location := range info.PieceBlockLocations { + if _, ok := pieceCid2cidInfo[location.PieceCID]; !ok { + pieceCid2cidInfo[location.PieceCID] = make([]piecestore.CIDInfo, 0) } + pieceCid2cidInfo[location.PieceCID] = append(pieceCid2cidInfo[location.PieceCID], info) } + } - for pieceCid, cidInfo := range pieceCid2cidInfo { - playloadCid2location := make(map[cid.Cid]piecestore.BlockLocation) - for _, info := range cidInfo { - for _, location := range info.PieceBlockLocations { - playloadCid2location[info.CID] = location.BlockLocation - } + for pieceCid, cidInfo := range pieceCid2cidInfo { + playloadCid2location := make(map[cid.Cid]piecestore.BlockLocation) + for _, info := range cidInfo { + for _, location := range info.PieceBlockLocations { + playloadCid2location[info.CID] = location.BlockLocation } - err := r.AddPieceBlockLocations(ctx, pieceCid, playloadCid2location) - assert.NoError(t, err) } - }) - - t.Run("GetCIDInfo", func(t *testing.T) { - res, err := r.GetCIDInfo(ctx, cidInfoCases[0].CID) + err := r.AddPieceBlockLocations(ctx, pieceCid, playloadCid2location) assert.NoError(t, err) - assert.Equal(t, cidInfoCases[0], res) - }) - - t.Run("ListCidInfoKeys", func(t *testing.T) { - cidInfos, err := r.ListCidInfoKeys(ctx) - assert.NoError(t, err) - assert.Equal(t, len(cidInfoCases), len(cidInfos)) - for _, info := range cidInfoCases { - assert.Contains(t, cidInfos, info.CID) - } - }) + } } diff --git a/models/badger/fund_test.go b/models/badger/fund_test.go index 8f0202bc..42f90690 100644 --- a/models/badger/fund_test.go +++ b/models/badger/fund_test.go @@ -4,34 +4,43 @@ import ( "context" "testing" + "github.com/filecoin-project/venus-market/v2/models/repo" "github.com/filecoin-project/venus/venus-shared/testutil" types "github.com/filecoin-project/venus/venus-shared/types/market" "github.com/stretchr/testify/assert" ) -func TestFund(t *testing.T) { - ctx := context.Background() - repo := setup(t) - r := repo.FundRepo() +func TestSaveFundedAddressState(t *testing.T) { + ctx, r, fundedAddressStateCases := prepareFundTest(t) - fundedAddressStateCases := make([]types.FundedAddressState, 10) - testutil.Provide(t, &fundedAddressStateCases) + for _, state := range fundedAddressStateCases { + err := r.SaveFundedAddressState(ctx, &state) + assert.NoError(t, err) + } +} - t.Run("SaveFundedAddressState", func(t *testing.T) { - for _, state := range fundedAddressStateCases { - err := r.SaveFundedAddressState(ctx, &state) - assert.NoError(t, err) - } - }) +func TestGetFundedAddressState(t *testing.T) { + ctx, r, fundedAddressStateCases := prepareFundTest(t) - t.Run("GetFundedAddressState", func(t *testing.T) { - res, err := r.GetFundedAddressState(ctx, fundedAddressStateCases[0].Addr) + for _, state := range fundedAddressStateCases { + err := r.SaveFundedAddressState(ctx, &state) assert.NoError(t, err) - fundedAddressStateCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, fundedAddressStateCases[0], *res) - }) + } + + res, err := r.GetFundedAddressState(ctx, fundedAddressStateCases[0].Addr) + assert.NoError(t, err) + fundedAddressStateCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, fundedAddressStateCases[0], *res) +} + +func TestListFundedAddressState(t *testing.T) { + ctx, r, fundedAddressStateCases := prepareFundTest(t) + + for _, state := range fundedAddressStateCases { + err := r.SaveFundedAddressState(ctx, &state) + assert.NoError(t, err) + } - // refresh the UpdatedAt field of test cases for i := 0; i < len(fundedAddressStateCases); i++ { res, err := r.GetFundedAddressState(ctx, fundedAddressStateCases[i].Addr) assert.NoError(t, err) @@ -48,3 +57,13 @@ func TestFund(t *testing.T) { } }) } + +func prepareFundTest(t *testing.T) (context.Context, repo.FundRepo, []types.FundedAddressState) { + ctx := context.Background() + repo := setup(t) + r := repo.FundRepo() + + fundedAddressStateCases := make([]types.FundedAddressState, 10) + testutil.Provide(t, &fundedAddressStateCases) + return ctx, r, fundedAddressStateCases +} diff --git a/models/badger/paych_test.go b/models/badger/paych_test.go index d9c62088..542e5cad 100644 --- a/models/badger/paych_test.go +++ b/models/badger/paych_test.go @@ -14,51 +14,50 @@ import ( "github.com/stretchr/testify/assert" ) -func TestPaych(t *testing.T) { - ctx := context.Background() - repo := setup(t) - r := repo.PaychChannelInfoRepo() +func TestSaveChannel(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - channelInfoCases := make([]types.ChannelInfo, 10) - testutil.Provide(t, &channelInfoCases) - channelInfoCases[0].Direction = types.DirOutbound + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) + assert.NoError(t, err) + } +} - t.Run("SaveChannel", func(t *testing.T) { - for _, info := range channelInfoCases { - err := r.SaveChannel(ctx, &info) - assert.NoError(t, err) - } - }) +func TestGetChannelByAddress(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - t.Run("GetChannelByAddress", func(t *testing.T) { - res, err := r.GetChannelByAddress(ctx, *channelInfoCases[0].Channel) + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) assert.NoError(t, err) - channelInfoCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, channelInfoCases[0], *res) - }) + } + + res, err := r.GetChannelByAddress(ctx, *channelInfoCases[0].Channel) + assert.NoError(t, err) + channelInfoCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, channelInfoCases[0], *res) +} + +func TestGetChannelByChannelID(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - t.Run("GetChannelByChannelID", func(t *testing.T) { - res, err := r.GetChannelByChannelID(ctx, channelInfoCases[0].ChannelID) + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) assert.NoError(t, err) - channelInfoCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, channelInfoCases[0], *res) - }) - - t.Run("WithPendingAddFunds", func(t *testing.T) { - expect := make([]types.ChannelInfo, 0) - for _, info := range channelInfoCases { - if info.Direction == types.DirOutbound && (info.CreateMsg != nil || info.AddFundsMsg != nil) { - expect = append(expect, info) - } - } + } + + res, err := r.GetChannelByChannelID(ctx, channelInfoCases[0].ChannelID) + assert.NoError(t, err) + channelInfoCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, channelInfoCases[0], *res) +} + +func TestWithPendingAddFunds(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - res, err := r.WithPendingAddFunds(ctx) + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) assert.NoError(t, err) - assert.Equal(t, len(expect), len(res)) - for i := 0; i < len(res); i++ { - assert.Contains(t, expect, *res[i]) - } - }) + } // refresh the UpdatedAt field of test cases for i := 0; i < len(channelInfoCases); i++ { @@ -67,79 +66,180 @@ func TestPaych(t *testing.T) { channelInfoCases[i].UpdatedAt = res.UpdatedAt } - t.Run("ListChannel", func(t *testing.T) { - res, err := r.ListChannel(ctx) - assert.NoError(t, err) - assert.Equal(t, len(channelInfoCases), len(res)) - addrs := make([]address.Address, 0) - for _, info := range channelInfoCases { - addrs = append(addrs, *info.Channel) - } - for i := 0; i < len(res); i++ { - assert.Contains(t, addrs, res[i]) - } - }) - - t.Run("CreateChannel and GetChannelByMessageCid", func(t *testing.T) { - var paramsCase struct { - From address.Address - To address.Address - CreateMsg cid.Cid - Amt big.Int + expect := make([]types.ChannelInfo, 0) + for _, info := range channelInfoCases { + if info.Direction == types.DirOutbound && (info.CreateMsg != nil || info.AddFundsMsg != nil) { + expect = append(expect, info) } + } + + res, err := r.WithPendingAddFunds(ctx) + assert.NoError(t, err) + assert.Equal(t, len(expect), len(res)) + for i := 0; i < len(res); i++ { + assert.Contains(t, expect, *res[i]) + } +} - testutil.Provide(t, ¶msCase) +func TestListChannel(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - _, err := r.CreateChannel(ctx, paramsCase.From, paramsCase.To, paramsCase.CreateMsg, paramsCase.Amt) + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) assert.NoError(t, err) + } - _, err = r.GetChannelByMessageCid(ctx, paramsCase.CreateMsg) + // refresh the UpdatedAt field of test cases + for i := 0; i < len(channelInfoCases); i++ { + res, err := r.GetChannelByAddress(ctx, *channelInfoCases[i].Channel) assert.NoError(t, err) - }) + channelInfoCases[i].UpdatedAt = res.UpdatedAt + } + + res, err := r.ListChannel(ctx) + assert.NoError(t, err) + assert.Equal(t, len(channelInfoCases), len(res)) + addrs := make([]address.Address, 0) + for _, info := range channelInfoCases { + addrs = append(addrs, *info.Channel) + } + for i := 0; i < len(res); i++ { + assert.Contains(t, addrs, res[i]) + } +} + +func TestCreateChannel(t *testing.T) { + ctx, r, _ := preparePaychTest(t) - t.Run("OutboundActiveByFromTo", func(t *testing.T) { - res, err := r.OutboundActiveByFromTo(ctx, channelInfoCases[0].From(), channelInfoCases[0].To()) + var paramsCase struct { + From address.Address + To address.Address + CreateMsg cid.Cid + Amt big.Int + } + testutil.Provide(t, ¶msCase) + + _, err := r.CreateChannel(ctx, paramsCase.From, paramsCase.To, paramsCase.CreateMsg, paramsCase.Amt) + assert.NoError(t, err) +} + +func TestGetChannelByMessageCid(t *testing.T) { + ctx, r, _ := preparePaychTest(t) + + var paramsCase struct { + From address.Address + To address.Address + CreateMsg cid.Cid + Amt big.Int + } + testutil.Provide(t, ¶msCase) + + _, err := r.CreateChannel(ctx, paramsCase.From, paramsCase.To, paramsCase.CreateMsg, paramsCase.Amt) + assert.NoError(t, err) + + _, err = r.GetChannelByMessageCid(ctx, paramsCase.CreateMsg) + assert.NoError(t, err) +} + +func TestOutboundActiveByFromTo(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) + + channelInfoCases[0].Direction = types.DirOutbound + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) assert.NoError(t, err) - assert.Equal(t, channelInfoCases[0], *res) - }) + } - t.Run("RemoveChannel", func(t *testing.T) { - err := r.RemoveChannel(ctx, channelInfoCases[0].ChannelID) + // refresh the UpdatedAt field of test cases + for i := 0; i < len(channelInfoCases); i++ { + res, err := r.GetChannelByAddress(ctx, *channelInfoCases[i].Channel) assert.NoError(t, err) - _, err = r.GetChannelByAddress(ctx, *channelInfoCases[0].Channel) - assert.True(t, errors.Is(err, mrepo.ErrNotFound)) - }) + channelInfoCases[i].UpdatedAt = res.UpdatedAt + } + + res, err := r.OutboundActiveByFromTo(ctx, channelInfoCases[0].From(), channelInfoCases[0].To()) + assert.NoError(t, err) + assert.Equal(t, channelInfoCases[0], *res) } -func TestMessage(t *testing.T) { - ctx := context.Background() - repo := setup(t) - r := repo.PaychMsgInfoRepo() +func TestRemoveChannel(t *testing.T) { + ctx, r, channelInfoCases := preparePaychTest(t) - messageInfoCases := make([]types.MsgInfo, 10) - testutil.Provide(t, &messageInfoCases) + channelInfoCases[0].Direction = types.DirOutbound + for _, info := range channelInfoCases { + err := r.SaveChannel(ctx, &info) + assert.NoError(t, err) + } - t.Run("SaveMessage", func(t *testing.T) { - for _, info := range messageInfoCases { - err := r.SaveMessage(ctx, &info) - assert.NoError(t, err) - } - }) + // refresh the UpdatedAt field of test cases + for i := 0; i < len(channelInfoCases); i++ { + res, err := r.GetChannelByAddress(ctx, *channelInfoCases[i].Channel) + assert.NoError(t, err) + channelInfoCases[i].UpdatedAt = res.UpdatedAt + } + + err := r.RemoveChannel(ctx, channelInfoCases[0].ChannelID) + assert.NoError(t, err) + _, err = r.GetChannelByAddress(ctx, *channelInfoCases[0].Channel) + assert.True(t, errors.Is(err, mrepo.ErrNotFound)) +} + +func TestSaveMessage(t *testing.T) { + ctx, r, messageInfoCases := preparePaychMsgTest(t) - t.Run("GetMessage", func(t *testing.T) { - res, err := r.GetMessage(ctx, messageInfoCases[0].MsgCid) + for _, info := range messageInfoCases { + err := r.SaveMessage(ctx, &info) assert.NoError(t, err) - messageInfoCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, messageInfoCases[0], *res) - }) + } +} - t.Run("SaveMessageResult", func(t *testing.T) { - err := r.SaveMessageResult(ctx, messageInfoCases[0].MsgCid, errors.New("test error")) +func TestGetMessage(t *testing.T) { + ctx, r, messageInfoCases := preparePaychMsgTest(t) + + for _, info := range messageInfoCases { + err := r.SaveMessage(ctx, &info) assert.NoError(t, err) + } + + res, err := r.GetMessage(ctx, messageInfoCases[0].MsgCid) + assert.NoError(t, err) + messageInfoCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, messageInfoCases[0], *res) +} + +func TestSaveMessageResult(t *testing.T) { + ctx, r, messageInfoCases := preparePaychMsgTest(t) - res, err := r.GetMessage(ctx, messageInfoCases[0].MsgCid) + for _, info := range messageInfoCases { + err := r.SaveMessage(ctx, &info) assert.NoError(t, err) + } + + err := r.SaveMessageResult(ctx, messageInfoCases[0].MsgCid, errors.New("test error")) + assert.NoError(t, err) + + res, err := r.GetMessage(ctx, messageInfoCases[0].MsgCid) + assert.NoError(t, err) + + assert.Equal(t, "test error", res.Err) +} + +func preparePaychTest(t *testing.T) (context.Context, mrepo.PaychChannelInfoRepo, []types.ChannelInfo) { + ctx := context.Background() + repo := setup(t).PaychChannelInfoRepo() + channelInfoCases := make([]types.ChannelInfo, 10) + testutil.Provide(t, &channelInfoCases) + channelInfoCases[0].Direction = types.DirOutbound + return ctx, repo, channelInfoCases +} + +func preparePaychMsgTest(t *testing.T) (context.Context, mrepo.PaychMsgInfoRepo, []types.MsgInfo) { + ctx := context.Background() + repo := setup(t) + r := repo.PaychMsgInfoRepo() + + messageInfoCases := make([]types.MsgInfo, 10) + testutil.Provide(t, &messageInfoCases) - assert.Equal(t, "test error", res.Err) - }) + return ctx, r, messageInfoCases } diff --git a/models/badger/retrieval_ask_test.go b/models/badger/retrieval_ask_test.go index 8cd52510..0ac71f86 100644 --- a/models/badger/retrieval_ask_test.go +++ b/models/badger/retrieval_ask_test.go @@ -4,12 +4,13 @@ import ( "context" "testing" + "github.com/filecoin-project/venus-market/v2/models/repo" "github.com/filecoin-project/venus/venus-shared/testutil" types "github.com/filecoin-project/venus/venus-shared/types/market" "github.com/stretchr/testify/assert" ) -func TestRetrievalAsk(t *testing.T) { +func prepareRetrievalAskTest(t *testing.T) (context.Context, repo.IRetrievalAskRepo, []types.RetrievalAsk) { ctx := context.Background() repo := setup(t) r := repo.RetrievalAskRepo() @@ -17,34 +18,51 @@ func TestRetrievalAsk(t *testing.T) { askCases := make([]types.RetrievalAsk, 10) testutil.Provide(t, &askCases) - t.Run("SetAsk", func(t *testing.T) { - for _, ask := range askCases { - err := r.SetAsk(ctx, &ask) - assert.NoError(t, err) - } - }) + return ctx, r, askCases +} + +func TestSetAsk(t *testing.T) { + ctx, r, askCases := prepareRetrievalAskTest(t) - t.Run("GetAsk", func(t *testing.T) { - res, err := r.GetAsk(ctx, askCases[0].Miner) + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) assert.NoError(t, err) - askCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, askCases[0], *res) - }) + } +} - // refresh UpdatedAt field +func TestGetAsk(t *testing.T) { + ctx, r, askCases := prepareRetrievalAskTest(t) + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) + assert.NoError(t, err) + } + + res, err := r.GetAsk(ctx, askCases[0].Miner) + assert.NoError(t, err) + askCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, askCases[0], *res) +} + +func TestListAsk(t *testing.T) { + ctx, r, askCases := prepareRetrievalAskTest(t) + + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) + assert.NoError(t, err) + } + + // refresh UpdatedAt field for i := 0; i < len(askCases); i++ { res, err := r.GetAsk(ctx, askCases[i].Miner) assert.NoError(t, err) askCases[i].UpdatedAt = res.UpdatedAt } - t.Run("ListAsk", func(t *testing.T) { - res, err := r.ListAsk(ctx) - assert.NoError(t, err) - assert.Equal(t, len(askCases), len(res)) - for _, ask := range res { - assert.Contains(t, askCases, *ask) - } - }) + res, err := r.ListAsk(ctx) + assert.NoError(t, err) + assert.Equal(t, len(askCases), len(res)) + for _, ask := range res { + assert.Contains(t, askCases, *ask) + } } diff --git a/models/badger/retrieval_deal_test.go b/models/badger/retrieval_deal_test.go index 2f665534..d322990b 100644 --- a/models/badger/retrieval_deal_test.go +++ b/models/badger/retrieval_deal_test.go @@ -6,6 +6,7 @@ import ( "github.com/filecoin-project/go-address" "github.com/filecoin-project/go-fil-markets/retrievalmarket" + "github.com/filecoin-project/venus-market/v2/models/repo" types "github.com/filecoin-project/venus/venus-shared/types/market" cbg "github.com/whyrusleeping/cbor-gen" @@ -21,47 +22,79 @@ func init() { }) } -func TestRetrievalDeal(t *testing.T) { +func prepareRetrievalDealTest(t *testing.T) (context.Context, repo.IRetrievalDealRepo, []types.ProviderDealState) { ctx := context.Background() repo := setup(t) r := repo.RetrievalDealRepo() dealCases := make([]types.ProviderDealState, 10) testutil.Provide(t, &dealCases) + return ctx, r, dealCases +} - t.Run("SaveDeal", func(t *testing.T) { - for _, deal := range dealCases { - err := r.SaveDeal(ctx, &deal) - assert.NoError(t, err) - } - }) +func TestSaveDeal(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) - t.Run("GetDeal", func(t *testing.T) { - res, err := r.GetDeal(ctx, dealCases[0].Receiver, dealCases[0].ID) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - dealCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, dealCases[0], *res) - }) + } +} - t.Run("GetDealByTransferId", func(t *testing.T) { - res, err := r.GetDealByTransferId(ctx, *dealCases[0].ChannelID) +func TestGetDeal(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - dealCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, dealCases[0], *res) - }) + } + + res, err := r.GetDeal(ctx, dealCases[0].Receiver, dealCases[0].ID) + assert.NoError(t, err) + dealCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, dealCases[0], *res) +} - t.Run("HasDeal", func(t *testing.T) { - dealCase_not_exist := types.ProviderDealState{} - testutil.Provide(t, &dealCase_not_exist) - res, err := r.HasDeal(ctx, dealCase_not_exist.Receiver, dealCase_not_exist.ID) +func TestGetDealByTransferId(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - assert.False(t, res) + } + + res, err := r.GetDealByTransferId(ctx, *dealCases[0].ChannelID) + assert.NoError(t, err) + dealCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, dealCases[0], *res) +} + +func TestHasDeal(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) - res, err = r.HasDeal(ctx, dealCases[0].Receiver, dealCases[0].ID) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - assert.True(t, res) - }) + } + + dealCase_not_exist := types.ProviderDealState{} + testutil.Provide(t, &dealCase_not_exist) + res, err := r.HasDeal(ctx, dealCase_not_exist.Receiver, dealCase_not_exist.ID) + assert.NoError(t, err) + assert.False(t, res) + + res, err = r.HasDeal(ctx, dealCases[0].Receiver, dealCases[0].ID) + assert.NoError(t, err) + assert.True(t, res) +} +func TestListDeals(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } // refresh UpdatedAt for i := 0; i < len(dealCases); i++ { res, err := r.GetDeal(ctx, dealCases[i].Receiver, dealCases[i].ID) @@ -69,22 +102,33 @@ func TestRetrievalDeal(t *testing.T) { dealCases[i].UpdatedAt = res.UpdatedAt } - t.Run("ListDeals", func(t *testing.T) { - res, err := r.ListDeals(ctx, 1, 10) - assert.NoError(t, err) - assert.Equal(t, len(dealCases), len(res)) - for _, res := range res { - assert.Contains(t, dealCases, *res) - } - }) + res, err := r.ListDeals(ctx, 1, 10) + assert.NoError(t, err) + assert.Equal(t, len(dealCases), len(res)) + for _, res := range res { + assert.Contains(t, dealCases, *res) + } +} - t.Run("GroupRetrievalDealNumberByStatus", func(t *testing.T) { - expect := map[retrievalmarket.DealStatus]int64{} - for _, deal := range dealCases { - expect[deal.Status]++ - } - res, err := r.GroupRetrievalDealNumberByStatus(ctx, address.Undef) +func TestGroupRetrievalDealNumberByStatus(t *testing.T) { + ctx, r, dealCases := prepareRetrievalDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - assert.Equal(t, expect, res) - }) + } + // refresh UpdatedAt + for i := 0; i < len(dealCases); i++ { + res, err := r.GetDeal(ctx, dealCases[i].Receiver, dealCases[i].ID) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + } + + expect := map[retrievalmarket.DealStatus]int64{} + for _, deal := range dealCases { + expect[deal.Status]++ + } + res, err := r.GroupRetrievalDealNumberByStatus(ctx, address.Undef) + assert.NoError(t, err) + assert.Equal(t, expect, res) } diff --git a/models/badger/storage_ask_test.go b/models/badger/storage_ask_test.go index fc73fb4c..cbaf76d5 100644 --- a/models/badger/storage_ask_test.go +++ b/models/badger/storage_ask_test.go @@ -4,34 +4,53 @@ import ( "context" "testing" + "github.com/filecoin-project/venus-market/v2/models/repo" "github.com/filecoin-project/venus/venus-shared/testutil" types "github.com/filecoin-project/venus/venus-shared/types/market" "github.com/stretchr/testify/assert" ) -func TestStorageAsk(t *testing.T) { - ctx := context.Background() +func prepareStorageAskTest(t *testing.T) (ctx context.Context, r repo.IStorageAskRepo, askCases []types.SignedStorageAsk) { + ctx = context.Background() repo := setup(t) - r := repo.StorageAskRepo() + r = repo.StorageAskRepo() - askCases := make([]types.SignedStorageAsk, 10) + askCases = make([]types.SignedStorageAsk, 10) testutil.Provide(t, &askCases) + return ctx, r, askCases +} - t.Run("SetAsk", func(t *testing.T) { - for _, ask := range askCases { - err := r.SetAsk(ctx, &ask) - assert.NoError(t, err) - } - }) +func TestSetStorageAsk(t *testing.T) { + ctx, r, askCases := prepareStorageAskTest(t) - t.Run("GetAsk", func(t *testing.T) { - res, err := r.GetAsk(ctx, askCases[0].Ask.Miner) + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) assert.NoError(t, err) - askCases[0].UpdatedAt = res.UpdatedAt - assert.Equal(t, askCases[0], *res) - }) + } +} + +func TestGetStorageAsk(t *testing.T) { + ctx, r, askCases := prepareStorageAskTest(t) + + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) + assert.NoError(t, err) + } - // refresh UpdatedAt field + res, err := r.GetAsk(ctx, askCases[0].Ask.Miner) + assert.NoError(t, err) + askCases[0].UpdatedAt = res.UpdatedAt + assert.Equal(t, askCases[0], *res) +} + +// refresh UpdatedAt field +func TestListStorageAsk(t *testing.T) { + ctx, r, askCases := prepareStorageAskTest(t) + + for _, ask := range askCases { + err := r.SetAsk(ctx, &ask) + assert.NoError(t, err) + } for i := 0; i < len(askCases); i++ { res, err := r.GetAsk(ctx, askCases[i].Ask.Miner) @@ -39,12 +58,10 @@ func TestStorageAsk(t *testing.T) { askCases[i].UpdatedAt = res.UpdatedAt } - t.Run("ListAsk", func(t *testing.T) { - res, err := r.ListAsk(ctx) - assert.NoError(t, err) - assert.Equal(t, len(askCases), len(res)) - for _, ask := range res { - assert.Contains(t, askCases, *ask) - } - }) + res, err := r.ListAsk(ctx) + assert.NoError(t, err) + assert.Equal(t, len(askCases), len(res)) + for _, ask := range res { + assert.Contains(t, askCases, *ask) + } } diff --git a/models/badger/storage_deal_test.go b/models/badger/storage_deal_test.go index 87e65578..5ba6416a 100644 --- a/models/badger/storage_deal_test.go +++ b/models/badger/storage_deal_test.go @@ -11,7 +11,7 @@ import ( "github.com/ipfs/go-cid" "github.com/filecoin-project/go-address" - + "github.com/filecoin-project/venus-market/v2/models/repo" "github.com/stretchr/testify/assert" m_types "github.com/filecoin-project/venus/venus-shared/types/market" @@ -27,29 +27,48 @@ func init() { }) } -func TestStorageDeal(t *testing.T) { +func prepareStorageDealTest(t *testing.T) (context.Context, repo.StorageDealRepo, []m_types.MinerDeal) { ctx := context.Background() repo := setup(t) r := repo.StorageDealRepo() dealCases := make([]m_types.MinerDeal, 10) testutil.Provide(t, &dealCases) - dealCases[0].PieceStatus = m_types.Assigned - t.Run("SaveDeal", func(t *testing.T) { - for _, deal := range dealCases { - err := r.SaveDeal(ctx, &deal) - assert.NoError(t, err) - } - }) + return ctx, r, dealCases +} + +func TestSaveStorageDeal(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - t.Run("GetDeal", func(t *testing.T) { - res, err := r.GetDeal(ctx, dealCases[0].ProposalCid) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - dealCases[0].UpdatedAt = res.UpdatedAt - dealCases[0].CreationTime = res.CreationTime - assert.Equal(t, dealCases[0], *res) - }) + } +} + +func TestGetStorageDeal(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + res, err := r.GetDeal(ctx, dealCases[0].ProposalCid) + assert.NoError(t, err) + dealCases[0].UpdatedAt = res.UpdatedAt + dealCases[0].CreationTime = res.CreationTime + assert.Equal(t, dealCases[0], *res) +} + +func TestListStorageDeal(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } // refresh UpdatedAt and CreationTime for i := range dealCases { @@ -59,204 +78,381 @@ func TestStorageDeal(t *testing.T) { dealCases[i].CreationTime = res.CreationTime } - t.Run("ListDeal", func(t *testing.T) { - res, err := r.ListDeal(ctx) + res, err := r.ListDeal(ctx) + assert.NoError(t, err) + assert.Equal(t, len(dealCases), len(res)) + for _, deal := range res { + assert.Contains(t, dealCases, *deal) + } +} + +func TestGetStorageDealByDealID(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - assert.Equal(t, len(dealCases), len(res)) - for _, deal := range res { - assert.Contains(t, dealCases, *deal) - } - }) + } - t.Run("GetDealByDealID", func(t *testing.T) { - res, err := r.GetDealByDealID(ctx, dealCases[0].Proposal.Provider, dealCases[0].DealID) + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) assert.NoError(t, err) - assert.Equal(t, dealCases[0], *res) - }) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } - t.Run("GetDeals", func(t *testing.T) { - res, err := r.GetDeals(ctx, dealCases[0].Proposal.Provider, 0, 10) + res, err := r.GetDealByDealID(ctx, dealCases[0].Proposal.Provider, dealCases[0].DealID) + assert.NoError(t, err) + assert.Equal(t, dealCases[0], *res) +} + +func TestGetStorageDeals(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + res, err := r.GetDeals(ctx, dealCases[0].Proposal.Provider, 0, 10) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) +} + +func TestGetStorageDealsByPieceCidAndStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + res, err := r.GetDealsByPieceCidAndStatus(ctx, dealCases[0].Proposal.PieceCID, dealCases[0].State) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) +} + +func TestGetStorageDealsByPieceStatusAndDealStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + t.Run("With DealStatus", func(t *testing.T) { + res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].PieceStatus, dealCases[0].State) assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, dealCases[0], *res[0]) }) - t.Run("GetDealsByPieceCidAndStatus", func(t *testing.T) { - res, err := r.GetDealsByPieceCidAndStatus(ctx, dealCases[0].Proposal.PieceCID, dealCases[0].State) + t.Run("Without DealStatus", func(t *testing.T) { + res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].PieceStatus) assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, dealCases[0], *res[0]) }) - t.Run("GetDealsByPieceStatusAndDealStatus", func(t *testing.T) { - t.Run("With DealStatus", func(t *testing.T) { - res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].PieceStatus, dealCases[0].State) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - - t.Run("Without DealStatus", func(t *testing.T) { - res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].PieceStatus) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - - t.Run("Without Provider", func(t *testing.T) { - res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, address.Undef, dealCases[0].PieceStatus, dealCases[0].State) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - - t.Run("Will Return None", func(t *testing.T) { - res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, address.Undef, dealCases[0].PieceStatus, 0) - assert.NoError(t, err) - assert.Equal(t, 0, len(res)) - }) + t.Run("Without Provider", func(t *testing.T) { + res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, address.Undef, dealCases[0].PieceStatus, dealCases[0].State) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) }) - t.Run("GetDealsByDataCidAndDealStatus", func(t *testing.T) { - t.Run("With Provider", func(t *testing.T) { - res, err := r.GetDealsByDataCidAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].Ref.Root, []m_types.PieceStatus{dealCases[0].PieceStatus}) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - - t.Run("Without Provider", func(t *testing.T) { - res, err := r.GetDealsByDataCidAndDealStatus(ctx, address.Undef, dealCases[0].Ref.Root, []m_types.PieceStatus{dealCases[0].PieceStatus}) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) + t.Run("Will Return None", func(t *testing.T) { + res, err := r.GetDealsByPieceStatusAndDealStatus(ctx, address.Undef, dealCases[0].PieceStatus, 0) + assert.NoError(t, err) + assert.Equal(t, 0, len(res)) }) +} - t.Run("GetDealByAddrAndStatus", func(t *testing.T) { - t.Run("With Provider", func(t *testing.T) { - res, err := r.GetDealByAddrAndStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].State) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - - t.Run("Without Provider", func(t *testing.T) { - res, err := r.GetDealByAddrAndStatus(ctx, address.Undef, dealCases[0].State) - assert.NoError(t, err) - assert.Equal(t, 1, len(res)) - assert.Equal(t, dealCases[0], *res[0]) - }) - }) +func TestGetStorageDealsByDataCidAndDealStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - t.Run("ListDealByAddr", func(t *testing.T) { - res, err := r.ListDealByAddr(ctx, dealCases[0].Proposal.Provider) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + t.Run("With Provider", func(t *testing.T) { + res, err := r.GetDealsByDataCidAndDealStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].Ref.Root, []m_types.PieceStatus{dealCases[0].PieceStatus}) assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, dealCases[0], *res[0]) }) - t.Run("GetPieceInfo", func(t *testing.T) { - res, err := r.GetPieceInfo(ctx, dealCases[0].Proposal.PieceCID) + t.Run("Without Provider", func(t *testing.T) { + res, err := r.GetDealsByDataCidAndDealStatus(ctx, address.Undef, dealCases[0].Ref.Root, []m_types.PieceStatus{dealCases[0].PieceStatus}) assert.NoError(t, err) - expect := piecestore.PieceInfo{ - PieceCID: dealCases[0].Proposal.PieceCID, - Deals: nil, - } - expect.Deals = append(expect.Deals, piecestore.DealInfo{ - DealID: dealCases[0].DealID, - SectorID: dealCases[0].SectorNumber, - Offset: dealCases[0].Offset, - Length: dealCases[0].Proposal.PieceSize, - }) - assert.Equal(t, expect, *res) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) }) +} - t.Run("ListPieceInfoKeys", func(t *testing.T) { - res, err := r.ListPieceInfoKeys(ctx) +func TestGetStorageDealByAddrAndStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - assert.Equal(t, len(dealCases), len(res)) - exp := make([]cid.Cid, 0, len(dealCases)) - for _, deal := range dealCases { - exp = append(exp, deal.Proposal.PieceCID) - } - for _, id := range res { - assert.Contains(t, exp, id) - } + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + t.Run("With Provider", func(t *testing.T) { + res, err := r.GetDealByAddrAndStatus(ctx, dealCases[0].Proposal.Provider, dealCases[0].State) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) }) - t.Run("GetPieceSize", func(t *testing.T) { - PLSize, PSize, err := r.GetPieceSize(ctx, dealCases[0].Proposal.PieceCID) + t.Run("Without Provider", func(t *testing.T) { + res, err := r.GetDealByAddrAndStatus(ctx, address.Undef, dealCases[0].State) assert.NoError(t, err) - assert.Equal(t, dealCases[0].Proposal.PieceSize, PSize) - assert.Equal(t, dealCases[0].PayloadSize, PLSize) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) }) +} + +func TestListStorageDealByAddr(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - t.Run("UpdateDealStatus", func(t *testing.T) { - err := r.UpdateDealStatus(ctx, dealCases[0].ProposalCid, storagemarket.StorageDealActive, m_types.Proving) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) assert.NoError(t, err) - res, err := r.GetDeal(ctx, dealCases[0].ProposalCid) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) assert.NoError(t, err) - assert.Equal(t, storagemarket.StorageDealActive, res.State) - assert.Equal(t, m_types.Proving, res.PieceStatus) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + res, err := r.ListDealByAddr(ctx, dealCases[0].Proposal.Provider) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, dealCases[0], *res[0]) +} + +func TestGetStoragePieceInfo(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + res, err := r.GetPieceInfo(ctx, dealCases[0].Proposal.PieceCID) + assert.NoError(t, err) + expect := piecestore.PieceInfo{ + PieceCID: dealCases[0].Proposal.PieceCID, + Deals: nil, + } + expect.Deals = append(expect.Deals, piecestore.DealInfo{ + DealID: dealCases[0].DealID, + SectorID: dealCases[0].SectorNumber, + Offset: dealCases[0].Offset, + Length: dealCases[0].Proposal.PieceSize, }) + assert.Equal(t, expect, *res) +} - t.Run("GroupStorageDealNumberByStatus", func(t *testing.T) { - t.Run("correct", func(t *testing.T) { - repo := setup(t) - r := repo.StorageDealRepo() +func TestListPieceInfoKeys(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - deals := make([]m_types.MinerDeal, 100) - testutil.Provide(t, &deals) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } - var addrs []address.Address - addrGetter := address.NewForTestGetter() - for i := 0; i < 3; i++ { - addrs = append(addrs, addrGetter()) - } + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } - for index := range deals { - deals[index].ClientDealProposal.Proposal.Provider = addrs[rand.Intn(len(addrs))] - deals[index].State = storagemarket.StorageDealStatus(rand.Intn(int(storagemarket.StorageDealReserveProviderFunds))) - } + res, err := r.ListPieceInfoKeys(ctx) + assert.NoError(t, err) + assert.Equal(t, len(dealCases), len(res)) + exp := make([]cid.Cid, 0, len(dealCases)) + for _, deal := range dealCases { + exp = append(exp, deal.Proposal.PieceCID) + } + for _, id := range res { + assert.Contains(t, exp, id) + } +} - for _, deal := range deals { - err := r.SaveDeal(ctx, &deal) - assert.Nil(t, err) - } +func TestGetPieceSize(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - result := map[storagemarket.StorageDealStatus]int64{} - for _, deal := range deals { - if deal.Proposal.Provider != addrs[0] { - continue - } - result[deal.State]++ - } - result2, err := r.GroupStorageDealNumberByStatus(ctx, addrs[0]) - assert.Nil(t, err) - assert.Equal(t, result, result2) - }) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } - t.Run("undefined address", func(t *testing.T) { - repo := setup(t) - r := repo.StorageDealRepo() + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } - deals := make([]m_types.MinerDeal, 10) - testutil.Provide(t, &deals) + PLSize, PSize, err := r.GetPieceSize(ctx, dealCases[0].Proposal.PieceCID) + assert.NoError(t, err) + assert.Equal(t, dealCases[0].Proposal.PieceSize, PSize) + assert.Equal(t, dealCases[0].PayloadSize, PLSize) +} - result := map[storagemarket.StorageDealStatus]int64{} - for _, deal := range deals { - result[deal.State]++ - } +func TestUpdateStorageDealStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) - for _, deal := range deals { - err := r.SaveDeal(ctx, &deal) - assert.Nil(t, err) + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + err := r.UpdateDealStatus(ctx, dealCases[0].ProposalCid, storagemarket.StorageDealActive, m_types.Proving) + assert.NoError(t, err) + res, err := r.GetDeal(ctx, dealCases[0].ProposalCid) + assert.NoError(t, err) + assert.Equal(t, storagemarket.StorageDealActive, res.State) + assert.Equal(t, m_types.Proving, res.PieceStatus) +} + +func TestGroupStorageDealNumberByStatus(t *testing.T) { + ctx, r, dealCases := prepareStorageDealTest(t) + + for _, deal := range dealCases { + err := r.SaveDeal(ctx, &deal) + assert.NoError(t, err) + } + + // refresh UpdatedAt and CreationTime + for i := range dealCases { + res, err := r.GetDeal(ctx, dealCases[i].ProposalCid) + assert.NoError(t, err) + dealCases[i].UpdatedAt = res.UpdatedAt + dealCases[i].CreationTime = res.CreationTime + } + + t.Run("correct", func(t *testing.T) { + repo := setup(t) + r := repo.StorageDealRepo() + + deals := make([]m_types.MinerDeal, 100) + testutil.Provide(t, &deals) + + var addrs []address.Address + addrGetter := address.NewForTestGetter() + for i := 0; i < 3; i++ { + addrs = append(addrs, addrGetter()) + } + + for index := range deals { + deals[index].ClientDealProposal.Proposal.Provider = addrs[rand.Intn(len(addrs))] + deals[index].State = storagemarket.StorageDealStatus(rand.Intn(int(storagemarket.StorageDealReserveProviderFunds))) + } + + for _, deal := range deals { + err := r.SaveDeal(ctx, &deal) + assert.Nil(t, err) + } + + result := map[storagemarket.StorageDealStatus]int64{} + for _, deal := range deals { + if deal.Proposal.Provider != addrs[0] { + continue } + result[deal.State]++ + } + result2, err := r.GroupStorageDealNumberByStatus(ctx, addrs[0]) + assert.Nil(t, err) + assert.Equal(t, result, result2) + }) + + t.Run("undefined address", func(t *testing.T) { + repo := setup(t) + r := repo.StorageDealRepo() - result2, err := r.GroupStorageDealNumberByStatus(ctx, address.Undef) + deals := make([]m_types.MinerDeal, 10) + testutil.Provide(t, &deals) + + result := map[storagemarket.StorageDealStatus]int64{} + for _, deal := range deals { + result[deal.State]++ + } + + for _, deal := range deals { + err := r.SaveDeal(ctx, &deal) assert.Nil(t, err) - assert.Equal(t, result, result2) - }) + } + + result2, err := r.GroupStorageDealNumberByStatus(ctx, address.Undef) + assert.Nil(t, err) + assert.Equal(t, result, result2) }) } diff --git a/models/cid_info_test.go b/models/cid_info_test.go index 1ce050ac..60806796 100644 --- a/models/cid_info_test.go +++ b/models/cid_info_test.go @@ -13,18 +13,16 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCIDInfo(t *testing.T) { - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - doTestCidinfo(t, badger.NewBadgerCidInfoRepo(db)) - }) - - t.Run("mysql", func(t *testing.T) { - repo := MysqlDB(t) - cidInfoRepo := repo.CidInfoRepo() - defer func() { require.NoError(t, repo.Close()) }() - doTestCidinfo(t, cidInfoRepo) - }) +func TestCIDInfoBadger(t *testing.T) { + db := BadgerDB(t) + doTestCidinfo(t, badger.NewBadgerCidInfoRepo(db)) +} + +func TestCIDInfoMysql(t *testing.T) { + repo := MysqlDB(t) + cidInfoRepo := repo.CidInfoRepo() + defer func() { require.NoError(t, repo.Close()) }() + doTestCidinfo(t, cidInfoRepo) } func doTestCidinfo(t *testing.T, repo repo.ICidInfoRepo) { diff --git a/models/fund_test.go b/models/fund_test.go index 668ca64a..c632a636 100644 --- a/models/fund_test.go +++ b/models/fund_test.go @@ -11,15 +11,13 @@ import ( "github.com/stretchr/testify/assert" ) -func TestFund(t *testing.T) { - t.Run("mysql", func(t *testing.T) { - testFund(t, MysqlDB(t).FundRepo()) - }) - - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - testFund(t, badger.NewFundRepo(db)) - }) +func TestFundMysql(t *testing.T) { + testFund(t, MysqlDB(t).FundRepo()) +} + +func TestFundBadger(t *testing.T) { + db := BadgerDB(t) + testFund(t, badger.NewFundRepo(db)) } func testFund(t *testing.T, fundRepo repo.FundRepo) { diff --git a/models/mysql/cid_info_test.go b/models/mysql/cid_info_test.go index 84bff67f..1dfb1735 100644 --- a/models/mysql/cid_info_test.go +++ b/models/mysql/cid_info_test.go @@ -14,22 +14,19 @@ import ( "github.com/stretchr/testify/assert" ) -var cidInfoCases []cidInfo - -func TestCidInfo(t *testing.T) { +func prepareCIDInfoTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []cidInfo, func()) { r, mock, sqlDB := setup(t) - - cidInfoCases = make([]cidInfo, 10) + cidInfoCases := make([]cidInfo, 10) testutil.Provide(t, &cidInfoCases) - - t.Run("mysql test GetCIDInfo", wrapper(testGetCIDInfo, r, mock)) - t.Run("mysql test ListCidInfoKeys", wrapper(testListCidInfoKeys, r, mock)) - t.Run("mysql test AddPieceBlockLocations", wrapper(testAddPieceBlockLocations, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, cidInfoCases, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testGetCIDInfo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetCIDInfo(t *testing.T) { + r, mock, cidInfoCases, done := prepareCIDInfoTest(t) + defer done() + cidInfoCase := cidInfoCases[0] pCidinfo := piecestore.CIDInfo{ @@ -58,7 +55,10 @@ func testGetCIDInfo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, pCidinfo, res) } -func testListCidInfoKeys(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListCidInfoKeys(t *testing.T) { + r, mock, cidInfoCases, done := prepareCIDInfoTest(t) + defer done() + db, err := getMysqlDryrunDB() assert.NoError(t, err) @@ -74,7 +74,10 @@ func testListCidInfoKeys(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, []cid.Cid{cidInfoCases[0].PayloadCid.cid(), cidInfoCases[1].PayloadCid.cid()}, res) } -func testAddPieceBlockLocations(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestAddPieceBlockLocations(t *testing.T) { + r, mock, _, done := prepareCIDInfoTest(t) + defer done() + cid1, err := getTestCid() assert.NoError(t, err) cid2, err := getTestCid() diff --git a/models/mysql/fund_test.go b/models/mysql/fund_test.go index 97eeea2c..dffb3fbf 100644 --- a/models/mysql/fund_test.go +++ b/models/mysql/fund_test.go @@ -14,34 +14,33 @@ import ( "github.com/stretchr/testify/assert" ) -var fundedAddressStatesCase = []*market_types.FundedAddressState{ - { - Addr: address.TestAddress, - AmtReserved: abi.NewTokenAmount(100), - MsgCid: nil, - TimeStamp: market_types.TimeStamp{CreatedAt: uint64(time.Now().Unix()), UpdatedAt: uint64(time.Now().Unix())}, - }, - { - Addr: address.TestAddress2, - AmtReserved: abi.NewTokenAmount(100), - MsgCid: nil, - TimeStamp: market_types.TimeStamp{CreatedAt: uint64(time.Now().Unix()), UpdatedAt: uint64(time.Now().Unix())}, - }, -} - var fundedAddressStateColumns = []string{"addr", "amt_reserved", "msg_cid", "created_at", "updated_at"} -func TestFundAddrState(t *testing.T) { +var prepareFundAddrStateTest = func(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []*market_types.FundedAddressState, func()) { r, mock, sqlDB := setup(t) - - t.Run("mysql test SaveFundedAddressState", wrapper(testSaveFundedAddressState, r, mock)) - t.Run("mysql test GetFundedAddressState", wrapper(testGetFundedAddressState, r, mock)) - t.Run("mysql test ListFundedAddressState", wrapper(testListFundedAddressState, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + fundedAddressStatesCase := []*market_types.FundedAddressState{ + { + Addr: address.TestAddress, + AmtReserved: abi.NewTokenAmount(100), + MsgCid: nil, + TimeStamp: market_types.TimeStamp{CreatedAt: uint64(time.Now().Unix()), UpdatedAt: uint64(time.Now().Unix())}, + }, + { + Addr: address.TestAddress2, + AmtReserved: abi.NewTokenAmount(100), + MsgCid: nil, + TimeStamp: market_types.TimeStamp{CreatedAt: uint64(time.Now().Unix()), UpdatedAt: uint64(time.Now().Unix())}, + }, + } + return r, mock, fundedAddressStatesCase, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testSaveFundedAddressState(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSaveFundedAddressState(t *testing.T) { + r, mock, fundedAddressStatesCase, done := prepareFundAddrStateTest(t) + defer done() + ctx := context.Background() fas := fromFundedAddressState(fundedAddressStatesCase[0]) @@ -52,7 +51,10 @@ func testSaveFundedAddressState(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) assert.NoError(t, err) } -func testGetFundedAddressState(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetFundedAddressState(t *testing.T) { + r, mock, fundedAddressStatesCase, done := prepareFundAddrStateTest(t) + defer done() + fas := fromFundedAddressState(fundedAddressStatesCase[0]) mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `funded_address_state` WHERE addr = ? LIMIT 1")).WithArgs(fas.Addr).WillReturnRows(sqlmock.NewRows(fundedAddressStateColumns).AddRow([]byte(fas.Addr.String()), fas.AmtReserved, []byte(fas.MsgCid.String()), fas.CreatedAt, fas.UpdatedAt)) res, err := r.FundRepo().GetFundedAddressState(context.Background(), fundedAddressStatesCase[0].Addr) @@ -60,7 +62,10 @@ func testGetFundedAddressState(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) assert.Equal(t, fundedAddressStatesCase[0], res) } -func testListFundedAddressState(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListFundedAddressState(t *testing.T) { + r, mock, fundedAddressStatesCase, done := prepareFundAddrStateTest(t) + defer done() + rows := sqlmock.NewRows(fundedAddressStateColumns) for _, fas := range fundedAddressStatesCase { fas_ := fromFundedAddressState(fas) diff --git a/models/mysql/paych_test.go b/models/mysql/paych_test.go index 6551525f..c132f7eb 100644 --- a/models/mysql/paych_test.go +++ b/models/mysql/paych_test.go @@ -18,46 +18,39 @@ import ( "gorm.io/gorm/clause" ) -var ( - channelInfosCases []*types.ChannelInfo - msgInfosCase []*types.MsgInfo -) - -func TestChannelInfo(t *testing.T) { +func prepareChannelInfoTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []*types.ChannelInfo, func()) { r, mock, sqlDB := setup(t) - channelInfosCases = make([]*types.ChannelInfo, 10) + channelInfosCases := make([]*types.ChannelInfo, 10) testutil.Provide(t, &channelInfosCases) for _, ch := range channelInfosCases { ch.PendingAvailableAmount = big.NewInt(0) ch.AvailableAmount = big.NewInt(0) } - msgInfosCase = make([]*types.MsgInfo, 10) - testutil.Provide(t, &msgInfosCase) + // msgInfosCase := make([]*types.MsgInfo, 10) + // testutil.Provide(t, &msgInfosCase) - t.Run("mysql test SaveChannel", wrapper(testSaveChannel, r, mock)) - t.Run("mysql test GetChannelByAddress", wrapper(testGetChannelByAddress, r, mock)) - t.Run("mysql test GetChannelByChannelID", wrapper(testGetChannelByChannelID, r, mock)) - t.Run("mysql test OutboundActiveByFromTo", wrapper(testOutboundActiveByFromTo, r, mock)) - t.Run("mysql test WithPendingAddFunds", wrapper(testWithPendingAddFunds, r, mock)) - t.Run("mysql test ListChannel", wrapper(testListChannel, r, mock)) - t.Run("mysql test RemoveChannel", wrapper(testRemoveChannel, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, channelInfosCases, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func TestMegInfo(t *testing.T) { +func prepareMegInfoTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []*types.MsgInfo, func()) { r, mock, sqlDB := setup(t) - t.Run("mysql test GetMessage", wrapper(testGetMessage, r, mock)) - t.Run("mysql test SaveMessage", wrapper(testSaveMessage, r, mock)) - t.Run("mysql test SaveMessageResult", wrapper(testSaveMessageResult, r, mock)) + msgInfosCase := make([]*types.MsgInfo, 10) + testutil.Provide(t, &msgInfosCase) - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, msgInfosCase, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testSaveChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSaveChannel(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + channelInfo := channelInfosCases[0] dbChannelInfo := fromChannelInfo(channelInfo) @@ -78,7 +71,10 @@ func testSaveChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testGetChannelByAddress(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetChannelByAddress(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + channelInfoCase := channelInfosCases[0] dbChannelInfo := fromChannelInfo(channelInfoCase) @@ -99,7 +95,10 @@ func testGetChannelByAddress(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, channelInfoCase, res) } -func testGetChannelByChannelID(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetChannelByChannelID(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + channelInfoCase := channelInfosCases[0] dbChannelInfo := fromChannelInfo(channelInfoCase) @@ -120,7 +119,10 @@ func testGetChannelByChannelID(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) assert.Equal(t, channelInfoCase, res) } -func testOutboundActiveByFromTo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestOutboundActiveByFromTo(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + channelInfoCase := channelInfosCases[0] channelInfoCase.Direction = types.DirOutbound channelInfoCase.Settling = false @@ -145,7 +147,10 @@ func testOutboundActiveByFromTo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) assert.Equal(t, channelInfoCase, res) } -func testWithPendingAddFunds(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestWithPendingAddFunds(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + dbChannelInfos := make([]*channelInfo, len(channelInfosCases)) for i, channelInfo := range channelInfosCases { tempChannelInfo := channelInfo @@ -170,7 +175,10 @@ func testWithPendingAddFunds(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, channelInfosCases, res) } -func testListChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListChannel(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + dbChannelInfos := make([]*channelInfo, len(channelInfosCases)) for i, channelInfo := range channelInfosCases { dbChannelInfos[i] = fromChannelInfo(channelInfo) @@ -198,7 +206,10 @@ func testListChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, addrs, res) } -func testRemoveChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestRemoveChannel(t *testing.T) { + r, mock, channelInfosCases, done := prepareChannelInfoTest(t) + defer done() + channelInfo := channelInfosCases[0] dbChannelInfo := fromChannelInfo(channelInfo) @@ -215,7 +226,10 @@ func testRemoveChannel(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testGetMessage(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetMessage(t *testing.T) { + r, mock, msgInfosCase, done := prepareMegInfoTest(t) + defer done() + msgInfo := msgInfosCase[0] dbMsgInfo := fromMsgInfo(msgInfo) @@ -228,7 +242,10 @@ func testGetMessage(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, msgInfo, res) } -func testSaveMessage(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSaveMessage(t *testing.T) { + r, mock, msgInfosCase, done := prepareMegInfoTest(t) + defer done() + msgInfo := msgInfosCase[0] dbMsgInfo := fromMsgInfo(msgInfo) @@ -240,7 +257,10 @@ func testSaveMessage(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testSaveMessageResult(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSaveMessageResult(t *testing.T) { + r, mock, msgInfosCase, done := prepareMegInfoTest(t) + defer done() + msgInfo := msgInfosCase[0] dbMsgInfo := fromMsgInfo(msgInfo) diff --git a/models/mysql/retrieval_ask_test.go b/models/mysql/retrieval_ask_test.go index ef54fa0f..b2700e1b 100644 --- a/models/mysql/retrieval_ask_test.go +++ b/models/mysql/retrieval_ask_test.go @@ -12,14 +12,12 @@ import ( "github.com/stretchr/testify/assert" ) -var retrievalAskCase *market_types.RetrievalAsk - -func TestRetrievalAsk(t *testing.T) { +func prepareRetrievalAskTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, *market_types.RetrievalAsk, func()) { r, mock, sqlDB := setup(t) addr := getTestAddress() - retrievalAskCase = &market_types.RetrievalAsk{ + retrievalAskCase := &market_types.RetrievalAsk{ Miner: addr, PricePerByte: types.NewInt(1), PaymentInterval: 2, @@ -27,14 +25,15 @@ func TestRetrievalAsk(t *testing.T) { UnsealPrice: types.NewInt(4), } - t.Run("mysql test GetAsk", wrapper(testRetrievalGetAsk, r, mock)) - t.Run("mysql test SetAsk", wrapper(testSetRetrievalAsk, r, mock)) - t.Run("mysql test ListAsk", wrapper(testListRetrievalAsk, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, retrievalAskCase, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testRetrievalGetAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestRetrievalGetAsk(t *testing.T) { + r, mock, retrievalAskCase, done := prepareRetrievalAskTest(t) + defer done() + ctx := context.Background() rows := mock.NewRows([]string{"price_per_byte", "payment_interval", "payment_interval_increase", "unseal_price"}) @@ -45,7 +44,10 @@ func testRetrievalGetAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, retrievalAskCase, result) } -func testSetRetrievalAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSetRetrievalAsk(t *testing.T) { + r, mock, retrievalAskCase, done := prepareRetrievalAskTest(t) + defer done() + ctx := context.Background() mock.ExpectBegin() @@ -56,7 +58,10 @@ func testSetRetrievalAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Nil(t, err) } -func testListRetrievalAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListRetrievalAsk(t *testing.T) { + r, mock, retrievalAskCase, done := prepareRetrievalAskTest(t) + defer done() + ctx := context.Background() rows := mock.NewRows([]string{"address", "price_per_byte", "unseal_price", "payment_interval", "payment_interval_increase", "created_at", "updated_at"}) diff --git a/models/mysql/retrieval_deal_test.go b/models/mysql/retrieval_deal_test.go index 85ca3389..34a5e084 100644 --- a/models/mysql/retrieval_deal_test.go +++ b/models/mysql/retrieval_deal_test.go @@ -17,16 +17,11 @@ import ( "gorm.io/gorm/clause" ) -var ( - dbRetrievalDealCase *retrievalDeal - RetrievaldealStateCase *types.ProviderDealState -) - -func TestRetrievalDealRepo(t *testing.T) { +func prepareRetrievalDealRepoTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, *retrievalDeal, *types.ProviderDealState, func()) { peerId, err := getTestPeerId() assert.NoError(t, err) - dbRetrievalDealCase = &retrievalDeal{ + dbRetrievalDealCase := &retrievalDeal{ DealProposal: DealProposal{ ID: 1, PricePerByte: mtypes.NewInt(1), @@ -44,24 +39,22 @@ func TestRetrievalDealRepo(t *testing.T) { TimeStampOrm: TimeStampOrm{CreatedAt: uint64(time.Now().Unix()), UpdatedAt: uint64(time.Now().Unix())}, } - RetrievaldealStateCase, err = toProviderDealState(dbRetrievalDealCase) + RetrievaldealStateCase, err := toProviderDealState(dbRetrievalDealCase) assert.NoError(t, err) RetrievaldealStateCase.ChannelID = &datatransfer.ChannelID{ ID: datatransfer.TransferID(dbRetrievalDealCase.ChannelID.ID), } r, mock, sqlDB := setup(t) - t.Run("mysql test SaveDeal", wrapper(testSaveRetrievalDeal, r, mock)) - t.Run("mysql test GetDeal", wrapper(testRetrievalGetDeal, r, mock)) - t.Run("mysql test GetDealByTransferId", wrapper(testGetRetrievalDealByTransferId, r, mock)) - t.Run("mysql test HasDeal", wrapper(testHasRetrievalDeal, r, mock)) - t.Run("mysql test ListDeals", wrapper(testListRetrievalDeals, r, mock)) - t.Run("mysql test GroupRetrievalDealNumberByStatus", wrapper(testGroupRetrievalDealNumberByStatus, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + + return r, mock, dbRetrievalDealCase, RetrievaldealStateCase, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testSaveRetrievalDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSaveRetrievalDeal(t *testing.T) { + r, mock, _, RetrievaldealStateCase, close := prepareRetrievalDealRepoTest(t) + defer close() ctx := context.Background() dbDeal := fromProviderDealState(RetrievaldealStateCase) @@ -80,7 +73,10 @@ func testSaveRetrievalDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Nil(t, err) } -func testRetrievalGetDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestRetrievalGetDeal(t *testing.T) { + r, mock, dbRetrievalDealCase, _, close := prepareRetrievalDealRepoTest(t) + defer close() + ctx := context.Background() peerid, err := peer.Decode(dbRetrievalDealCase.Receiver) @@ -98,7 +94,10 @@ func testRetrievalGetDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, res, dealState) } -func testGetRetrievalDealByTransferId(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetRetrievalDealByTransferId(t *testing.T) { + r, mock, dbRetrievalDealCase, _, close := prepareRetrievalDealRepoTest(t) + defer close() + ctx := context.Background() rows, err := getFullRows(dbRetrievalDealCase) @@ -115,8 +114,12 @@ func testGetRetrievalDealByTransferId(t *testing.T, r repo.Repo, mock sqlmock.Sq assert.Equal(t, res, dealState) } -func testHasRetrievalDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestHasRetrievalDeal(t *testing.T) { + r, mock, _, _, close := prepareRetrievalDealRepoTest(t) + defer close() + ctx := context.Background() + did := retrievalmarket.DealID(1) peerId, err := getTestPeerId() assert.Nil(t, err) @@ -132,7 +135,10 @@ func testHasRetrievalDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.True(t, has) } -func testListRetrievalDeals(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListRetrievalDeals(t *testing.T) { + r, mock, dbRetrievalDealCase, _, close := prepareRetrievalDealRepoTest(t) + defer close() + ctx := context.Background() rows := mock.NewRows([]string{"cdp_proposal_id", "cdp_payload_cid", "cdp_selector", "cdp_piece_cid", "cdp_price_perbyte", "cdp_payment_interval", "cdp_payment_interval_increase", "cdp_unseal_price", "store_id", "ci_initiator", "ci_responder", "ci_channel_id", "sel_proposal_cid", "status", "receiver", "total_sent", "funds_received", "message", "current_interval", "legacy_protocol", "created_at", "updated_at"}) @@ -153,7 +159,10 @@ func testListRetrievalDeals(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, res2[0], dealState) } -func testGroupRetrievalDealNumberByStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGroupRetrievalDealNumberByStatus(t *testing.T) { + r, mock, _, _, close := prepareRetrievalDealRepoTest(t) + defer close() + ctx := context.Background() expectResult := map[retrievalmarket.DealStatus]int64{ retrievalmarket.DealStatusAccepted: 1, diff --git a/models/mysql/storage_ask_test.go b/models/mysql/storage_ask_test.go index 3dfa58dc..227d92eb 100644 --- a/models/mysql/storage_ask_test.go +++ b/models/mysql/storage_ask_test.go @@ -14,13 +14,11 @@ import ( "gorm.io/gorm/clause" ) -var storageAskCases []types.SignedStorageAsk - -func TestStorageAsk(t *testing.T) { +func prepareStorageAskTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []types.SignedStorageAsk, func()) { addr1 := getTestAddress() addr2 := getTestAddress() - storageAskCases = []types.SignedStorageAsk{ + storageAskCases := []types.SignedStorageAsk{ { Ask: &storagemarket.StorageAsk{ Miner: addr1, @@ -39,14 +37,15 @@ func TestStorageAsk(t *testing.T) { r, mock, sqlDB := setup(t) - t.Run("mysql test GetAsk", wrapper(testGetStorageAsk, r, mock)) - t.Run("mysql test SetAsk", wrapper(testSetStorageAsk, r, mock)) - t.Run("mysql test ListAsk", wrapper(testListStorageAsk, r, mock)) - - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, storageAskCases, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } -func testGetStorageAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetStorageAsk(t *testing.T) { + r, mock, storageAskCases, done := prepareStorageAskTest(t) + defer done() + ask := storageAskCases[0] dbAsk := fromStorageAsk(&ask) @@ -66,7 +65,10 @@ func testGetStorageAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, ask, *ask2) } -func testSetStorageAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestSetStorageAsk(t *testing.T) { + r, mock, storageAskCases, done := prepareStorageAskTest(t) + defer done() + db, err := getMysqlDryrunDB() assert.NoError(t, err) @@ -90,7 +92,10 @@ func testSetStorageAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testListStorageAsk(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListStorageAsk(t *testing.T) { + r, mock, storageAskCases, done := prepareStorageAskTest(t) + defer done() + db, err := getMysqlDryrunDB() assert.NoError(t, err) diff --git a/models/mysql/storage_deal_test.go b/models/mysql/storage_deal_test.go index d86ddcc3..71bd652f 100644 --- a/models/mysql/storage_deal_test.go +++ b/models/mysql/storage_deal_test.go @@ -21,72 +21,9 @@ import ( "github.com/filecoin-project/go-address" ) -var ( - dbStorageDealCases []*storageDeal - storageDealCases []*types.MinerDeal -) - -func testSaveDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { - cid1, err := getTestCid() - assert.NoError(t, err) - cid2, err := getTestCid() - assert.NoError(t, err) - - peer1, err := getTestPeerId() - assert.NoError(t, err) - peer2, err := getTestPeerId() - assert.NoError(t, err) - - temp := []*types.MinerDeal{ - { - ProposalCid: cid1, - Miner: peer1, - Client: peer1, - ClientDealProposal: market.ClientDealProposal{ - Proposal: market.DealProposal{ - Provider: getTestAddress(), - PieceCID: cid1, - }, - }, - State: storagemarket.StorageDealActive, - TimeStamp: types.TimeStamp{ - CreatedAt: uint64(time.Now().Unix()), - UpdatedAt: uint64(time.Now().Unix()), - }, - Ref: &storagemarket.DataRef{ - Root: cid1, - }, - }, - { - ProposalCid: cid2, - Miner: peer2, - Client: peer2, - ClientDealProposal: market.ClientDealProposal{ - Proposal: market.DealProposal{ - Provider: getTestAddress(), - PieceCID: cid2, - }, - }, - State: storagemarket.StorageDealActive, - TimeStamp: types.TimeStamp{ - CreatedAt: uint64(time.Now().Unix()), - UpdatedAt: uint64(time.Now().Unix()), - }, - Ref: &storagemarket.DataRef{ - Root: cid2, - }, - }, - } - - storageDealCases = make([]*types.MinerDeal, 0) - dbStorageDealCases = make([]*storageDeal, 0) - for _, v := range temp { - dbDeal := fromStorageDeal(v) - deal, err := toStorageDeal(dbDeal) - assert.NoError(t, err) - dbStorageDealCases = append(dbStorageDealCases, dbDeal) - storageDealCases = append(storageDealCases, deal) - } +func TestSaveDeal(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() db, err := getMysqlDryrunDB() assert.NoError(t, err) @@ -106,7 +43,10 @@ func testSaveDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testGetDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetDeal(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + db, err := getMysqlDryrunDB() assert.NoError(t, err) @@ -130,7 +70,10 @@ func testGetDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, storageDealCases[0], res[0]) } -func testGetDealsByPieceCidAndStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetDealsByPieceCidAndStatus(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -153,7 +96,10 @@ func testGetDealsByPieceCidAndStatus(t *testing.T, r repo.Repo, mock sqlmock.Sql assert.Equal(t, deal, res[0]) } -func testGetDealsByDataCidAndDealStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetDealsByDataCidAndDealStatus(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -175,7 +121,10 @@ func testGetDealsByDataCidAndDealStatus(t *testing.T, r repo.Repo, mock sqlmock. assert.Equal(t, deal, res[0]) } -func testGetDealByAddrAndStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetDealByAddrAndStatus(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -197,7 +146,10 @@ func testGetDealByAddrAndStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) assert.Equal(t, deal, res[0]) } -func testGetGetDealByDealID(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetGetDealByDealID(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -218,7 +170,10 @@ func testGetGetDealByDealID(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, deal, res) } -func testGetDealsByPieceStatusAndDealStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetDealsByPieceStatusAndDealStatus(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -258,7 +213,10 @@ func testGetDealsByPieceStatusAndDealStatus(t *testing.T, r repo.Repo, mock sqlm }) } -func testUpdateDealStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestUpdateDealStatus(t *testing.T) { + r, mock, _, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] targetDealStatus := storagemarket.StorageDealAwaitingPreCommit @@ -280,7 +238,10 @@ func testUpdateDealStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.NoError(t, err) } -func testListDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListDeal(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + db, err := getMysqlDryrunDB() assert.NoError(t, err) @@ -300,7 +261,10 @@ func testListDeal(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { } } -func testListDealByAddr(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListDealByAddr(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] db, err := getMysqlDryrunDB() @@ -321,7 +285,10 @@ func testListDealByAddr(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, deal, res[0]) } -func testGetPieceInfo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetPieceInfo(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -354,7 +321,10 @@ func testGetPieceInfo(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, pInfo, res) } -func testListPieceInfoKeys(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestListPieceInfoKeys(t *testing.T) { + r, mock, dbStorageDealCases, _, done := prepareStorageDealRepoTest(t) + defer done() + dbDeal := dbStorageDealCases[0] cids, err := cid.Decode(dbDeal.PieceCID.String()) assert.NoError(t, err) @@ -378,7 +348,10 @@ func testListPieceInfoKeys(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, cids, res[0]) } -func testGetPieceSize(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func TestGetPieceSize(t *testing.T) { + r, mock, dbStorageDealCases, storageDealCases, done := prepareStorageDealRepoTest(t) + defer done() + deal := storageDealCases[0] dbDeal := dbStorageDealCases[0] @@ -400,7 +373,10 @@ func testGetPieceSize(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { assert.Equal(t, abi.PaddedPieceSize(dbDeal.PieceSize), paddedPieceSize) } -func test_storageDealRepo_GroupDealsByStatus(t *testing.T, r repo.Repo, mock sqlmock.Sqlmock) { +func Test_storageDealRepo_GroupDealsByStatus(t *testing.T) { + r, mock, _, _, done := prepareStorageDealRepoTest(t) + defer done() + ctx := context.Background() t.Run("correct", func(t *testing.T) { @@ -435,25 +411,74 @@ func test_storageDealRepo_GroupDealsByStatus(t *testing.T, r repo.Repo, mock sql }) } -func TestStorageDealRepo(t *testing.T) { - r, mock, sqlDB := setup(t) +func prepareStorageDealRepoTest(t *testing.T) (repo.Repo, sqlmock.Sqlmock, []*storageDeal, []*types.MinerDeal, func()) { + var dbStorageDealCases []*storageDeal + var storageDealCases []*types.MinerDeal + + cid1, err := getTestCid() + assert.NoError(t, err) + cid2, err := getTestCid() + assert.NoError(t, err) + + peer1, err := getTestPeerId() + assert.NoError(t, err) + peer2, err := getTestPeerId() + assert.NoError(t, err) - t.Run("mysql test SaveDeal", wrapper(testSaveDeal, r, mock)) - t.Run("mysql test GetDeal", wrapper(testGetDeal, r, mock)) - t.Run("mysql test GetDealsByPieceCidAndStatus", wrapper(testGetDealsByPieceCidAndStatus, r, mock)) - t.Run("mysql test GetDealsByDataCidAndDealStatus", wrapper(testGetDealsByDataCidAndDealStatus, r, mock)) + temp := []*types.MinerDeal{ + { + ProposalCid: cid1, + Miner: peer1, + Client: peer1, + ClientDealProposal: market.ClientDealProposal{ + Proposal: market.DealProposal{ + Provider: getTestAddress(), + PieceCID: cid1, + }, + }, + State: storagemarket.StorageDealActive, + TimeStamp: types.TimeStamp{ + CreatedAt: uint64(time.Now().Unix()), + UpdatedAt: uint64(time.Now().Unix()), + }, + Ref: &storagemarket.DataRef{ + Root: cid1, + }, + }, + { + ProposalCid: cid2, + Miner: peer2, + Client: peer2, + ClientDealProposal: market.ClientDealProposal{ + Proposal: market.DealProposal{ + Provider: getTestAddress(), + PieceCID: cid2, + }, + }, + State: storagemarket.StorageDealActive, + TimeStamp: types.TimeStamp{ + CreatedAt: uint64(time.Now().Unix()), + UpdatedAt: uint64(time.Now().Unix()), + }, + Ref: &storagemarket.DataRef{ + Root: cid2, + }, + }, + } - t.Run("mysql test GetDealByAddrAndStatus", wrapper(testGetDealByAddrAndStatus, r, mock)) - t.Run("mysql test GetDealByDealID", wrapper(testGetGetDealByDealID, r, mock)) - t.Run("mysql test GetDealsByPieceStatus", wrapper(testGetDealsByPieceStatusAndDealStatus, r, mock)) + storageDealCases = make([]*types.MinerDeal, 0) + dbStorageDealCases = make([]*storageDeal, 0) + for _, v := range temp { + dbDeal := fromStorageDeal(v) + deal, err := toStorageDeal(dbDeal) + assert.NoError(t, err) + dbStorageDealCases = append(dbStorageDealCases, dbDeal) + storageDealCases = append(storageDealCases, deal) + } - t.Run("mysql test UpdateDealStatus", wrapper(testUpdateDealStatus, r, mock)) - t.Run("mysql test ListDeal", wrapper(testListDeal, r, mock)) - t.Run("mysql test ListDealByAddr", wrapper(testListDealByAddr, r, mock)) - t.Run("mysql test GetPieceInfo", wrapper(testGetPieceInfo, r, mock)) - t.Run("mysql test ListPieceInfoKeys", wrapper(testListPieceInfoKeys, r, mock)) - t.Run("mysql test GetPieceSize", wrapper(testGetPieceSize, r, mock)) - t.Run("mysql test GroupStorageDealNumberByStatus", wrapper(test_storageDealRepo_GroupDealsByStatus, r, mock)) + r, mock, sqlDB := setup(t) - assert.NoError(t, closeDB(mock, sqlDB)) + return r, mock, dbStorageDealCases, storageDealCases, func() { + assert.NoError(t, closeDB(mock, sqlDB)) + } } diff --git a/models/mysql/testing.go b/models/mysql/testing.go index 342a83ba..3115355b 100644 --- a/models/mysql/testing.go +++ b/models/mysql/testing.go @@ -41,12 +41,6 @@ func setup(t *testing.T) (repo.Repo, sqlmock.Sqlmock, *sql.DB) { return MysqlRepo{DB: gormDB}, mock, sqlDB } -func wrapper(f func(*testing.T, repo.Repo, sqlmock.Sqlmock), repo repo.Repo, mock sqlmock.Sqlmock) func(t *testing.T) { - return func(t *testing.T) { - f(t, repo, mock) - } -} - func closeDB(mock sqlmock.Sqlmock, sqlDB *sql.DB) error { mock.ExpectClose() return sqlDB.Close() diff --git a/models/paych_test.go b/models/paych_test.go index c5c022b1..bd110cf1 100644 --- a/models/paych_test.go +++ b/models/paych_test.go @@ -15,19 +15,17 @@ import ( "github.com/stretchr/testify/assert" ) -func TestPaych(t *testing.T) { - t.Run("mysql", func(t *testing.T) { - testChannelInfo(t, MysqlDB(t).PaychChannelInfoRepo(), MysqlDB(t).PaychMsgInfoRepo()) - testMsgInfo(t, MysqlDB(t).PaychMsgInfoRepo()) - }) - - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - msgPaych := badger.NewPayMsgRepo(badger.NewPayChanMsgDs(db)) - ps := badger.NewPaychRepo(badger.NewPayChanInfoDs(db), msgPaych) - testChannelInfo(t, ps, msgPaych) - testMsgInfo(t, msgPaych) - }) +func TestPaychMysql(t *testing.T) { + testChannelInfo(t, MysqlDB(t).PaychChannelInfoRepo(), MysqlDB(t).PaychMsgInfoRepo()) + testMsgInfo(t, MysqlDB(t).PaychMsgInfoRepo()) +} + +func TestPaychBadger(t *testing.T) { + db := BadgerDB(t) + msgPaych := badger.NewPayMsgRepo(badger.NewPayChanMsgDs(db)) + ps := badger.NewPaychRepo(badger.NewPayChanInfoDs(db), msgPaych) + testChannelInfo(t, ps, msgPaych) + testMsgInfo(t, msgPaych) } func testChannelInfo(t *testing.T, channelRepo repo.PaychChannelInfoRepo, msgRepo repo.PaychMsgInfoRepo) { diff --git a/models/retrieval_ask_test.go b/models/retrieval_ask_test.go index 3585d4d2..d80d2bf2 100644 --- a/models/retrieval_ask_test.go +++ b/models/retrieval_ask_test.go @@ -14,18 +14,16 @@ import ( ) // go test -v ./models -test.run TestRetrievalAsk -mysql='root:ko2005@tcp(127.0.0.1:3306)/storage_market?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s' -func TestRetrievalAsk(t *testing.T) { - t.Run("mysql", func(t *testing.T) { - repo := MysqlDB(t) - retrievalAskRepo := repo.RetrievalAskRepo() - defer func() { require.NoError(t, repo.Close()) }() - testRetrievalAsk(t, retrievalAskRepo) - }) - - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - testRetrievalAsk(t, badger.NewRetrievalAskRepo(db)) - }) +func TestRetrievalAskMysql(t *testing.T) { + repo := MysqlDB(t) + retrievalAskRepo := repo.RetrievalAskRepo() + defer func() { require.NoError(t, repo.Close()) }() + testRetrievalAsk(t, retrievalAskRepo) +} + +func TestRetrievalAskBadger(t *testing.T) { + db := BadgerDB(t) + testRetrievalAsk(t, badger.NewRetrievalAskRepo(db)) } func testRetrievalAsk(t *testing.T, rtAskRepo repo.IRetrievalAskRepo) { diff --git a/models/storage_ask_test.go b/models/storage_ask_test.go index 65af050a..8724c403 100644 --- a/models/storage_ask_test.go +++ b/models/storage_ask_test.go @@ -16,17 +16,16 @@ import ( "github.com/stretchr/testify/require" ) -func TestStorageAsk(t *testing.T) { - t.Run("mysql", func(t *testing.T) { - repo := MysqlDB(t) - askRepo := repo.StorageAskRepo() - defer func() { require.NoError(t, repo.Close()) }() - testStorageAsk(t, askRepo) - }) - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - testStorageAsk(t, badger.NewStorageAskRepo(db)) - }) +func TestStorageAskMysql(t *testing.T) { + repo := MysqlDB(t) + askRepo := repo.StorageAskRepo() + defer func() { require.NoError(t, repo.Close()) }() + testStorageAsk(t, askRepo) +} + +func TestStorageAskBadger(t *testing.T) { + db := BadgerDB(t) + testStorageAsk(t, badger.NewStorageAskRepo(db)) } func testStorageAsk(t *testing.T, askRepo repo.IStorageAskRepo) { diff --git a/models/storage_deal_test.go b/models/storage_deal_test.go index ac9d3ba3..b96e6171 100644 --- a/models/storage_deal_test.go +++ b/models/storage_deal_test.go @@ -24,22 +24,22 @@ import ( typegen "github.com/whyrusleeping/cbor-gen" ) -func TestStorageDeal(t *testing.T) { - t.Run("MinerDealMarshal", testCborMarshal) - - t.Run("mysql", func(t *testing.T) { - repo := MysqlDB(t) - dealRepo := repo.StorageDealRepo() - defer func() { - _ = repo.Close() - }() - testStorageDeal(t, dealRepo) - }) - - t.Run("badger", func(t *testing.T) { - db := BadgerDB(t) - testStorageDeal(t, badger.NewStorageDealRepo(db)) - }) +func TestStorageDealMinerDealMarshal(t *testing.T) { + testCborMarshal(t) +} + +func TestStorageMysql(t *testing.T) { + repo := MysqlDB(t) + dealRepo := repo.StorageDealRepo() + defer func() { + _ = repo.Close() + }() + testStorageDeal(t, dealRepo) +} + +func TestStorageDealBadger(t *testing.T) { + db := BadgerDB(t) + testStorageDeal(t, badger.NewStorageDealRepo(db)) } func getTestMinerDeal(t *testing.T) *types.MinerDeal {