diff --git a/modules/core/04-channel/types/upgrade.go b/modules/core/04-channel/types/upgrade.go index 9789b50472f..44eda872f31 100644 --- a/modules/core/04-channel/types/upgrade.go +++ b/modules/core/04-channel/types/upgrade.go @@ -117,3 +117,9 @@ func (u *UpgradeError) GetErrorReceipt() ErrorReceipt { Message: fmt.Sprintf("ABCI code: %d: %s", code, restoreErrorString), } } + +// IsUpgradeError returns true if err is of type UpgradeError, otherwise false. +func IsUpgradeError(err error) bool { + _, ok := err.(*UpgradeError) + return ok +} diff --git a/modules/core/04-channel/types/upgrade_test.go b/modules/core/04-channel/types/upgrade_test.go index 5dc3babc8d9..87749cc5d49 100644 --- a/modules/core/04-channel/types/upgrade_test.go +++ b/modules/core/04-channel/types/upgrade_test.go @@ -171,3 +171,45 @@ func (suite *TypesTestSuite) TestUpgradeErrorUnwrap() { suite.Require().Equal(types.ErrInvalidChannel, unWrapped, "unwrapped error was not equal to base underlying error") suite.Require().Equal(originalUpgradeError, postUnwrapUpgradeError, "original error was modified when unwrapped") } + +func (suite *TypesTestSuite) TestIsUpgradeError() { + var err error + + testCases := []struct { + msg string + malleate func() + expPass bool + }{ + { + "true", + func() {}, + true, + }, + { + "false with non upgrade error", + func() { + err = errors.New("error") + }, + false, + }, + { + "false with nil error", + func() { + err = nil + }, + false, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.msg, func() { + err = types.NewUpgradeError(1, types.ErrInvalidChannel) + + tc.malleate() + + res := types.IsUpgradeError(err) + suite.Require().Equal(tc.expPass, res) + }) + } +} diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index bfb4ba148b0..94d3f6df43e 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -768,8 +768,8 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.UpgradeTimeout, msg.CounterpartyProposedUpgrade, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) if err != nil { ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed")) - if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok { - k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr) + if channeltypes.IsUpgradeError(err) { + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err) cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId) // NOTE: a FAILURE result is returned to the client and an error receipt is written to state. @@ -823,8 +823,8 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyFlushStatus, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) if err != nil { ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed")) - if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok { - k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr) + if channeltypes.IsUpgradeError(err) { + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err) cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId) // NOTE: a FAILURE result is returned to the client and an error receipt is written to state.