Skip to content

Commit

Permalink
add IsUpgradError func for readability (#4144)
Browse files Browse the repository at this point in the history
* add IsUpgradeError function

* Update modules/core/04-channel/types/upgrade.go

Co-authored-by: Jim Fasarakis-Hilliard <[email protected]>

* gofumpt

* Update modules/core/04-channel/types/upgrade.go

Co-authored-by: Charly <[email protected]>

---------

Co-authored-by: Jim Fasarakis-Hilliard <[email protected]>
Co-authored-by: Charly <[email protected]>
  • Loading branch information
3 people authored Jul 25, 2023
1 parent 4072cbd commit d9e8131
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
6 changes: 6 additions & 0 deletions modules/core/04-channel/types/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,9 @@ func (u *UpgradeError) GetErrorReceipt() ErrorReceipt {
Message: fmt.Sprintf("ABCI code: %d: %s", code, restoreErrorString),
}
}

// IsUpgradeError returns true if err is of type UpgradeError, otherwise false.
func IsUpgradeError(err error) bool {
_, ok := err.(*UpgradeError)
return ok
}
42 changes: 42 additions & 0 deletions modules/core/04-channel/types/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,45 @@ func (suite *TypesTestSuite) TestUpgradeErrorUnwrap() {
suite.Require().Equal(types.ErrInvalidChannel, unWrapped, "unwrapped error was not equal to base underlying error")
suite.Require().Equal(originalUpgradeError, postUnwrapUpgradeError, "original error was modified when unwrapped")
}

func (suite *TypesTestSuite) TestIsUpgradeError() {
var err error

testCases := []struct {
msg string
malleate func()
expPass bool
}{
{
"true",
func() {},
true,
},
{
"false with non upgrade error",
func() {
err = errors.New("error")
},
false,
},
{
"false with nil error",
func() {
err = nil
},
false,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.msg, func() {
err = types.NewUpgradeError(1, types.ErrInvalidChannel)

tc.malleate()

res := types.IsUpgradeError(err)
suite.Require().Equal(tc.expPass, res)
})
}
}
8 changes: 4 additions & 4 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,8 +768,8 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.UpgradeTimeout, msg.CounterpartyProposedUpgrade, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed"))
if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr)
if channeltypes.IsUpgradeError(err) {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)
cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
Expand Down Expand Up @@ -823,8 +823,8 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh
err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyFlushStatus, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed"))
if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr)
if channeltypes.IsUpgradeError(err) {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)
cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
Expand Down

0 comments on commit d9e8131

Please sign in to comment.