diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 3d50a312c66..3d3c8152c54 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -178,14 +178,9 @@ func (k Keeper) ChanUpgradeTry( return upgrade, nil } -// WriteUpgradeTryChannel writes a channel which has successfully passed the UpgradeTry handshake step. +// WriteUpgradeTryChannel writes the channel end and upgrade to state after successfully passing the UpgradeTry handshake step. // An event is emitted for the handshake step. -func (k Keeper) WriteUpgradeTryChannel( - ctx sdk.Context, - portID, channelID string, - proposedUpgrade types.Upgrade, - flushStatus types.FlushStatus, -) { +func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string, upgrade types.Upgrade, upgradeVersion string) (types.Channel, types.Upgrade) { defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-try") channel, found := k.GetChannel(ctx, portID, channelID) @@ -195,13 +190,22 @@ func (k Keeper) WriteUpgradeTryChannel( previousState := channel.State channel.State = types.TRYUPGRADE - channel.FlushStatus = flushStatus + // TODO: determine flush status + // channel.FlushStatus = flushStatus + + upgrade.Fields.Version = upgradeVersion k.SetChannel(ctx, portID, channelID, channel) - k.SetUpgrade(ctx, portID, channelID, proposedUpgrade) + k.SetUpgrade(ctx, portID, channelID, upgrade) k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.TRYUPGRADE.String()) - emitChannelUpgradeTryEvent(ctx, portID, channelID, channel, proposedUpgrade) + emitChannelUpgradeTryEvent(ctx, portID, channelID, channel, upgrade) + + return channel, upgrade +} + +func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err error) error { + return nil } // startFlushUpgradeHandshake will verify the counterparty proposed upgrade and the current channel state. diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index e25f753e889..33df77fb5b9 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -751,7 +751,56 @@ func (k Keeper) ChannelUpgradeInit(goCtx context.Context, msg *channeltypes.MsgC // ChannelUpgradeTry defines a rpc handler method for MsgChannelUpgradeTry. func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgChannelUpgradeTry) (*channeltypes.MsgChannelUpgradeTryResponse, error) { - return nil, nil + ctx := sdk.UnwrapSDKContext(goCtx) + + module, _, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId) + if err != nil { + ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrap(err, "could not retrieve module from port-id")) + return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") + } + + cbs, ok := k.Router.GetRoute(module) + if !ok { + ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) + } + + upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.UpgradeTimeout, msg.CounterpartyProposedUpgrade, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) + if err != nil { + if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok { + if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr); err != nil { + return nil, err + } + + // NOTE: a FAILURE result is returned to the client and an error receipt is written to state. + // This signals to the relayer to begin the cancel upgrade handshake subprotocol. + return &channeltypes.MsgChannelUpgradeTryResponse{Result: channeltypes.FAILURE}, nil + } + + // NOTE: an error is returned to baseapp and transaction state is not committed. + return nil, err + } + + cacheCtx, writeFn := ctx.CacheContext() + upgradeVersion, err := cbs.OnChanUpgradeTry(cacheCtx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version) + if err != nil { + if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, err); err != nil { + return nil, err + } + + return &channeltypes.MsgChannelUpgradeTryResponse{Result: channeltypes.FAILURE}, nil + } + + writeFn() + + channel, upgrade := k.ChannelKeeper.WriteUpgradeTryChannel(ctx, msg.PortId, msg.ChannelId, upgrade, upgradeVersion) + + return &channeltypes.MsgChannelUpgradeTryResponse{ + Result: channeltypes.SUCCESS, + ChannelId: msg.ChannelId, + Upgrade: upgrade, + UpgradeSequence: channel.UpgradeSequence, + }, nil } // ChannelUpgradeAck defines a rpc handler method for MsgChannelUpgradeAck. diff --git a/modules/core/keeper/msg_server_test.go b/modules/core/keeper/msg_server_test.go index b46dbcff4c8..3dd53ad2d7d 100644 --- a/modules/core/keeper/msg_server_test.go +++ b/modules/core/keeper/msg_server_test.go @@ -1,7 +1,10 @@ package keeper_test import ( + "fmt" + sdk "github.com/cosmos/cosmos-sdk/types" + capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types" upgradetypes "github.com/cosmos/cosmos-sdk/x/upgrade/types" clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" @@ -774,3 +777,150 @@ func (suite *KeeperTestSuite) TestUpgradeClient() { } } } + +func (suite *KeeperTestSuite) TestChannelUpgradeTry() { + var ( + path *ibctesting.Path + msg *channeltypes.MsgChannelUpgradeTry + ) + + cases := []struct { + name string + malleate func() + expResult func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) + }{ + { + "success", + func() {}, + func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) { + suite.Require().NoError(err) + suite.Require().NotNil(res) + suite.Require().Equal(channeltypes.SUCCESS, res.Result) + + channel := path.EndpointB.GetChannel() + suite.Require().Equal(channeltypes.TRYUPGRADE, channel.State) + suite.Require().Equal(uint64(1), channel.UpgradeSequence) + }, + }, + { + "module capability not found", + func() { + msg.PortId = "invalid-port" + msg.ChannelId = "invalid-channel" + }, + func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) { + suite.Require().Error(err) + suite.Require().Nil(res) + + suite.Require().ErrorIs(err, capabilitytypes.ErrCapabilityNotFound) + }, + }, + { + "elapsed upgrade timeout returns error", + func() { + msg.UpgradeTimeout = channeltypes.NewTimeout(clienttypes.NewHeight(1, 10), 0) + suite.coordinator.CommitNBlocks(suite.chainB, 100) + }, + func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) { + suite.Require().Error(err) + suite.Require().Nil(res) + suite.Require().ErrorIs(err, channeltypes.ErrInvalidUpgrade) + + errorReceipt, found := suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().Empty(errorReceipt) + suite.Require().False(found) + }, + }, + { + "unsynchronized upgrade sequence writes upgrade error receipt", + func() { + channel := path.EndpointB.GetChannel() + channel.UpgradeSequence = 100 + + path.EndpointB.SetChannel(channel) + }, + func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) { + suite.Require().NoError(err) + + suite.Require().NotNil(res) + suite.Require().Equal(channeltypes.FAILURE, res.Result) + + // TODO: assert error receipt exists for the upgrade sequence when RestoreChannel / AbortUpgrade is called + // errorReceipt, found := suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + // suite.Require().True(found) + }, + }, + { + "application callback error writes upgrade error receipt", + func() { + suite.chainB.GetSimApp().IBCMockModule.IBCApp.OnChanUpgradeTry = func( + ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string, + ) (string, error) { + // set arbitrary value in store to mock application state changes + store := ctx.KVStore(suite.chainB.GetSimApp().GetKey(exported.ModuleName)) + store.Set([]byte("foo"), []byte("bar")) + return "", fmt.Errorf("mock app callback failed") + } + }, + func(res *channeltypes.MsgChannelUpgradeTryResponse, err error) { + suite.Require().NoError(err) + + suite.Require().NotNil(res) + suite.Require().Equal(channeltypes.FAILURE, res.Result) + + // TODO: assert error receipt exists for the upgrade sequence when RestoreChannel / AbortUpgrade is called + // errorReceipt, found := suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + // suite.Require().True(found) + + // assert application state changes are not committed + store := suite.chainB.GetContext().KVStore(suite.chainB.GetSimApp().GetKey(exported.ModuleName)) + suite.Require().False(store.Has([]byte("foo"))) + }, + }, + } + + for _, tc := range cases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + // configure the channel upgrade version on testing endpoints + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.UpdateClient() + suite.Require().NoError(err) + + counterpartySequence := path.EndpointA.GetChannel().UpgradeSequence + counterpartyUpgrade, found := suite.chainA.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + + proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + + msg = &channeltypes.MsgChannelUpgradeTry{ + PortId: path.EndpointB.ChannelConfig.PortID, + ChannelId: path.EndpointB.ChannelID, + ProposedUpgradeConnectionHops: []string{ibctesting.FirstConnectionID}, + UpgradeTimeout: channeltypes.NewTimeout(path.EndpointA.Chain.GetTimeoutHeight(), 0), + CounterpartyUpgradeSequence: counterpartySequence, + CounterpartyProposedUpgrade: counterpartyUpgrade, + ProofChannel: proofChannel, + ProofUpgrade: proofUpgrade, + ProofHeight: proofHeight, + Signer: suite.chainB.SenderAccount.GetAddress().String(), + } + + tc.malleate() + + res, err := suite.chainB.GetSimApp().GetIBCKeeper().ChannelUpgradeTry(suite.chainB.GetContext(), msg) + + tc.expResult(res, err) + }) + } +}