Skip to content

Commit

Permalink
Combine functions TestGetSourceCallbackDataTransfer and TestGetDestCa…
Browse files Browse the repository at this point in the history
…llbackDataTransfer (#7694)

* Combine get dest and source callback tests

* lint
  • Loading branch information
lacsomot authored Dec 21, 2024
1 parent 0a7756f commit 3e7348d
Showing 1 changed file with 101 additions and 117 deletions.
218 changes: 101 additions & 117 deletions modules/apps/callbacks/types/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
transfertypes "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types"
clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v9/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v9/modules/core/05-port/types"
ibcexported "github.com/cosmos/ibc-go/v9/modules/core/exported"
ibctesting "github.com/cosmos/ibc-go/v9/testing"
ibcmock "github.com/cosmos/ibc-go/v9/testing/mock"
Expand Down Expand Up @@ -567,154 +568,134 @@ type bytesProvider interface {
GetBytes() []byte
}

func (s *CallbacksTypesTestSuite) TestGetSourceCallbackDataTransfer() {
func (s *CallbacksTypesTestSuite) TestGetDestSourceCallbackDataTransfer() {
sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String()
receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String()

testCases := []struct {
name string
var (
packetData bytesProvider
expCallbackData types.CallbackData
malleate func()
)

expSrcCallBack := types.CallbackData{
CallbackAddress: sender,
SenderAddress: sender,
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V1,
}

expDstCallBack := types.CallbackData{
CallbackAddress: sender,
SenderAddress: "",
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V1,
}

testCases := []struct {
name string
malleate func()
callbackFn func(
ctx sdk.Context,
packetDataUnmarshaler porttypes.PacketDataUnmarshaler,
packet channeltypes.Packet,
maxGas uint64,
) (types.CallbackData, error)
getSrc bool
}{
{
"success: v1",
transfertypes.FungibleTokenPacketData{
Denom: ibctesting.TestCoin.Denom,
Amount: ibctesting.TestCoin.Amount.String(),
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender),
},
types.CallbackData{
CallbackAddress: sender,
SenderAddress: sender,
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V1,
},
"success: src_callback v1",
func() {
packetData = transfertypes.FungibleTokenPacketData{
Denom: ibctesting.TestCoin.Denom,
Amount: ibctesting.TestCoin.Amount.String(),
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender),
}

expCallbackData = expSrcCallBack

s.path.EndpointA.ChannelConfig.Version = transfertypes.V1
s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName
s.path.EndpointB.ChannelConfig.Version = transfertypes.V1
s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName
},
types.GetSourceCallbackData,
true,
},
{
"success: v2",
transfertypes.FungibleTokenPacketDataV2{
Tokens: transfertypes.Tokens{
{
Denom: transfertypes.NewDenom(ibctesting.TestCoin.Denom),
Amount: ibctesting.TestCoin.Amount.String(),
},
},
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender),
},
types.CallbackData{
CallbackAddress: sender,
SenderAddress: sender,
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V2,
},
"success: src_callback v2",
func() {
packetData = transfertypes.FungibleTokenPacketDataV2{
Tokens: transfertypes.Tokens{
{
Denom: transfertypes.NewDenom(ibctesting.TestCoin.Denom),
Amount: ibctesting.TestCoin.Amount.String(),
},
},
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender),
}

expCallbackData = expSrcCallBack
expCallbackData.ApplicationVersion = transfertypes.V2

s.path.EndpointA.ChannelConfig.Version = transfertypes.V2
s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName
s.path.EndpointB.ChannelConfig.Version = transfertypes.V2
s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName
},
types.GetSourceCallbackData,
true,
},
}

for _, tc := range testCases {
tc := tc
s.Run(tc.name, func() {
s.SetupTest()

tc.malleate()

packetDataBytes := tc.packetData.GetBytes()

transferStack, ok := s.chainA.App.GetIBCKeeper().PortKeeper.Route(transfertypes.ModuleName)
s.Require().True(ok)

packetUnmarshaler, ok := transferStack.(types.CallbacksCompatibleModule)
s.Require().True(ok)

s.path.Setup()

gasMeter := storetypes.NewGasMeter(2_000_000)
ctx := s.chainA.GetContext().WithGasMeter(gasMeter)
packet := channeltypes.NewPacket(packetDataBytes, 0, transfertypes.PortID, s.path.EndpointA.ChannelID, transfertypes.PortID, s.path.EndpointB.ChannelID, clienttypes.ZeroHeight(), 0)
callbackData, err := types.GetSourceCallbackData(ctx, packetUnmarshaler, packet, 1_000_000)
s.Require().NoError(err)
s.Require().Equal(tc.expCallbackData, callbackData)
})
}
}

func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() {
sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String()
receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String()

testCases := []struct {
name string
packetData bytesProvider
expCallbackdata types.CallbackData
malleate func()
}{
{
"success: v1",
transfertypes.FungibleTokenPacketData{
Denom: ibctesting.TestCoin.Denom,
Amount: ibctesting.TestCoin.Amount.String(),
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender),
},
types.CallbackData{
CallbackAddress: sender,
SenderAddress: "",
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V1,
},
"success: dest_callback v1",
func() {
packetData = transfertypes.FungibleTokenPacketData{
Denom: ibctesting.TestCoin.Denom,
Amount: ibctesting.TestCoin.Amount.String(),
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender),
}

expCallbackData = expDstCallBack

s.path.EndpointA.ChannelConfig.Version = transfertypes.V1
s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName
s.path.EndpointB.ChannelConfig.Version = transfertypes.V1
s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName
},
types.GetDestCallbackData,
false,
},
{
"success: v2",
transfertypes.FungibleTokenPacketDataV2{
Tokens: transfertypes.Tokens{
{
Denom: transfertypes.NewDenom(ibctesting.TestCoin.Denom),
Amount: ibctesting.TestCoin.Amount.String(),
},
},
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender),
},
types.CallbackData{
CallbackAddress: sender,
SenderAddress: "",
ExecutionGasLimit: 1_000_000,
CommitGasLimit: 1_000_000,
ApplicationVersion: transfertypes.V2,
},
"success: dest_callback v2",
func() {
packetData = transfertypes.FungibleTokenPacketDataV2{
Tokens: transfertypes.Tokens{
{
Denom: transfertypes.NewDenom(ibctesting.TestCoin.Denom),
Amount: ibctesting.TestCoin.Amount.String(),
},
},
Sender: sender,
Receiver: receiver,
Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender),
}

expCallbackData = expDstCallBack
expCallbackData.ApplicationVersion = transfertypes.V2

s.path.EndpointA.ChannelConfig.Version = transfertypes.V2
s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName
s.path.EndpointB.ChannelConfig.Version = transfertypes.V2
s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName
},
types.GetDestCallbackData,
false,
},
}

Expand All @@ -725,8 +706,6 @@ func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() {

tc.malleate()

packetDataBytes := tc.packetData.GetBytes()

transferStack, ok := s.chainA.App.GetIBCKeeper().PortKeeper.Route(transfertypes.ModuleName)
s.Require().True(ok)

Expand All @@ -737,10 +716,15 @@ func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() {

gasMeter := storetypes.NewGasMeter(2_000_000)
ctx := s.chainA.GetContext().WithGasMeter(gasMeter)
packet := channeltypes.NewPacket(packetDataBytes, 0, transfertypes.PortID, s.path.EndpointB.ChannelID, transfertypes.PortID, s.path.EndpointA.ChannelID, clienttypes.ZeroHeight(), 0)
callbackData, err := types.GetDestCallbackData(ctx, packetUnmarshaler, packet, 1_000_000)
var packet channeltypes.Packet
if tc.getSrc {
packet = channeltypes.NewPacket(packetData.GetBytes(), 0, transfertypes.PortID, s.path.EndpointA.ChannelID, transfertypes.PortID, s.path.EndpointB.ChannelID, clienttypes.ZeroHeight(), 0)
} else {
packet = channeltypes.NewPacket(packetData.GetBytes(), 0, transfertypes.PortID, s.path.EndpointB.ChannelID, transfertypes.PortID, s.path.EndpointA.ChannelID, clienttypes.ZeroHeight(), 0)
}
callbackData, err := tc.callbackFn(ctx, packetUnmarshaler, packet, 1_000_000)
s.Require().NoError(err)
s.Require().Equal(tc.expCallbackdata, callbackData)
s.Require().Equal(expCallbackData, callbackData)
})
}
}
Expand Down

0 comments on commit 3e7348d

Please sign in to comment.