Skip to content

Commit

Permalink
Use connection id instead of channel id in IsMiddlewareEnabled (cosmo…
Browse files Browse the repository at this point in the history
  • Loading branch information
chatton authored Sep 20, 2022
1 parent 888c4a0 commit 6b7d67f
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 34 deletions.
32 changes: 26 additions & 6 deletions modules/apps/27-interchain-accounts/controller/ibc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type IBCMiddleware struct {
keeper keeper.Keeper
}

// IBCMiddleware creates a new IBCMiddleware given the associated keeper and underlying application
// NewIBCMiddleware creates a new IBCMiddleware given the associated keeper and underlying application
func NewIBCMiddleware(app porttypes.IBCModule, k keeper.Keeper) IBCMiddleware {
return IBCMiddleware{
app: app,
Expand Down Expand Up @@ -63,7 +63,7 @@ func (im IBCMiddleware) OnChanOpenInit(
// call underlying app's OnChanOpenInit callback with the passed in version
// the version returned is discarded as the ica-auth module does not have permission to edit the version string.
// ics27 will always return the version string containing the Metadata struct which is created during the `RegisterInterchainAccount` call.
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, channelID) {
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionHops[0]) {
if _, err := im.app.OnChanOpenInit(ctx, order, connectionHops, portID, channelID, nil, counterparty, version); err != nil {
return "", err
}
Expand Down Expand Up @@ -107,8 +107,13 @@ func (im IBCMiddleware) OnChanOpenAck(
return err
}

connectionID, err := im.keeper.GetConnectionID(ctx, portID, channelID)
if err != nil {
return err
}

// call underlying app's OnChanOpenAck callback with the counterparty app version.
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, channelID) {
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionID) {
return im.app.OnChanOpenAck(ctx, portID, channelID, counterpartyChannelID, counterpartyVersion)
}

Expand Down Expand Up @@ -144,7 +149,12 @@ func (im IBCMiddleware) OnChanCloseConfirm(
return err
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, channelID) {
connectionID, err := im.keeper.GetConnectionID(ctx, portID, channelID)
if err != nil {
return err
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionID) {
return im.app.OnChanCloseConfirm(ctx, portID, channelID)
}

Expand Down Expand Up @@ -174,8 +184,13 @@ func (im IBCMiddleware) OnAcknowledgementPacket(
return types.ErrControllerSubModuleDisabled
}

connectionID, err := im.keeper.GetConnectionID(ctx, packet.GetSourcePort(), packet.GetSourceChannel())
if err != nil {
return err
}

// call underlying app's OnAcknowledgementPacket callback.
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, packet.GetSourcePort(), connectionID) {
return im.app.OnAcknowledgementPacket(ctx, packet, acknowledgement, relayer)
}

Expand All @@ -196,7 +211,12 @@ func (im IBCMiddleware) OnTimeoutPacket(
return err
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
connectionID, err := im.keeper.GetConnectionID(ctx, packet.GetSourcePort(), packet.GetSourceChannel())
if err != nil {
return err
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, packet.GetSourcePort(), connectionID) {
return im.app.OnTimeoutPacket(ctx, packet, relayer)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenInit() {
},
{
"middleware disabled", func() {
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ConnectionID)

suite.chainA.GetSimApp().ICAAuthModule.IBCApp.OnChanOpenInit = func(ctx sdk.Context, order channeltypes.Order, connectionHops []string,
portID, channelID string, chanCap *capabilitytypes.Capability,
Expand Down Expand Up @@ -213,7 +213,7 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenInit() {
path.EndpointA.ChannelConfig.PortID = portID
path.EndpointA.ChannelID = ibctesting.FirstChannelID

suite.chainA.GetSimApp().ICAControllerKeeper.SetMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.chainA.GetSimApp().ICAControllerKeeper.SetMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ConnectionID)

// default values
counterparty := channeltypes.NewCounterparty(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
Expand Down Expand Up @@ -346,7 +346,7 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenAck() {
},
{
"middleware disabled", func() {
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ConnectionID)

suite.chainA.GetSimApp().ICAAuthModule.IBCApp.OnChanOpenAck = func(
ctx sdk.Context, portID, channelID string, counterpartyChannelID string, counterpartyVersion string,
Expand Down Expand Up @@ -606,7 +606,7 @@ func (suite *InterchainAccountsTestSuite) TestOnAcknowledgementPacket() {
},
{
"middleware disabled", func() {
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ConnectionID)

suite.chainA.GetSimApp().ICAAuthModule.IBCApp.OnAcknowledgementPacket = func(
ctx sdk.Context, packet channeltypes.Packet, acknowledgement []byte, relayer sdk.AccAddress,
Expand Down Expand Up @@ -699,7 +699,7 @@ func (suite *InterchainAccountsTestSuite) TestOnTimeoutPacket() {
},
{
"middleware disabled", func() {
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.chainA.GetSimApp().ICAControllerKeeper.DeleteMiddlewareEnabled(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ConnectionID)

suite.chainA.GetSimApp().ICAAuthModule.IBCApp.OnTimeoutPacket = func(
ctx sdk.Context, packet channeltypes.Packet, relayer sdk.AccAddress,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ func (k Keeper) RegisterInterchainAccount(ctx sdk.Context, connectionID, owner,
return err
}

channelID, err := k.registerInterchainAccount(ctx, connectionID, portID, version)
_, err = k.registerInterchainAccount(ctx, connectionID, portID, version)
if err != nil {
return err
}

k.SetMiddlewareEnabled(ctx, portID, channelID)
k.SetMiddlewareEnabled(ctx, portID, connectionID)

return nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, state genesistypes.ControllerGe
keeper.SetActiveChannelID(ctx, ch.ConnectionId, ch.PortId, ch.ChannelId)

if ch.IsMiddlewareEnabled {
keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ChannelId)
keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ConnectionId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (suite *KeeperTestSuite) TestInitGenesis() {
suite.Require().True(found)
suite.Require().Equal(ibctesting.FirstChannelID, channelID)

isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(suite.chainA.GetContext(), TestPortID, ibctesting.FirstChannelID)
isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(suite.chainA.GetContext(), TestPortID, ibctesting.FirstConnectionID)
suite.Require().True(isMiddlewareEnabled)

accountAdrr, found := suite.chainA.GetSimApp().ICAControllerKeeper.GetInterchainAccountAddress(suite.chainA.GetContext(), ibctesting.FirstConnectionID, TestPortID)
Expand Down
38 changes: 26 additions & 12 deletions modules/apps/27-interchain-accounts/controller/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/cosmos/cosmos-sdk/codec"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"
paramtypes "github.com/cosmos/cosmos-sdk/x/params/types"
"github.com/tendermint/tendermint/libs/log"
Expand Down Expand Up @@ -61,6 +62,15 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {
return ctx.Logger().With("module", fmt.Sprintf("x/%s-%s", host.ModuleName, icatypes.ModuleName))
}

// GetConnectionID returns the connection id for the given port and channelIDs.
func (k Keeper) GetConnectionID(ctx sdk.Context, portID, channelID string) (string, error) {
channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID)
if !found {
return "", sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}
return channel.ConnectionHops[0], nil
}

// GetAllPorts returns all ports to which the interchain accounts controller module is bound. Used in ExportGenesis
func (k Keeper) GetAllPorts(ctx sdk.Context) []string {
store := ctx.KVStore(k.storeKey)
Expand Down Expand Up @@ -144,11 +154,15 @@ func (k Keeper) GetAllActiveChannels(ctx sdk.Context) []genesistypes.ActiveChann
for ; iterator.Valid(); iterator.Next() {
keySplit := strings.Split(string(iterator.Key()), "/")

portID := keySplit[1]
connectionID := keySplit[2]
channelID := string(iterator.Value())

ch := genesistypes.ActiveChannel{
ConnectionId: keySplit[2],
PortId: keySplit[1],
ChannelId: string(iterator.Value()),
IsMiddlewareEnabled: k.IsMiddlewareEnabled(ctx, keySplit[1], string(iterator.Value())),
ConnectionId: connectionID,
PortId: portID,
ChannelId: channelID,
IsMiddlewareEnabled: k.IsMiddlewareEnabled(ctx, portID, connectionID),
}

activeChannels = append(activeChannels, ch)
Expand Down Expand Up @@ -208,20 +222,20 @@ func (k Keeper) SetInterchainAccountAddress(ctx sdk.Context, connectionID, portI
store.Set(icatypes.KeyOwnerAccount(portID, connectionID), []byte(address))
}

// IsMiddlewareEnabled returns true if the underlying application callbacks are enabled for given port and channel identifier pair, otherwise false
func (k Keeper) IsMiddlewareEnabled(ctx sdk.Context, portID, channelID string) bool {
// IsMiddlewareEnabled returns true if the underlying application callbacks are enabled for given port and connection identifier pair, otherwise false
func (k Keeper) IsMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(icatypes.KeyIsMiddlewareEnabled(portID, channelID))
return store.Has(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))
}

// SetMiddlewareEnabled stores a flag to indicate that the underlying application callbacks should be enabled for the given port and channel identifier pair
func (k Keeper) SetMiddlewareEnabled(ctx sdk.Context, portID, channelID string) {
// SetMiddlewareEnabled stores a flag to indicate that the underlying application callbacks should be enabled for the given port and connection identifier pair
func (k Keeper) SetMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, channelID), []byte{byte(1)})
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), []byte{byte(1)})
}

// DeleteMiddlewareEnabled deletes the middleware enabled flag stored in state
func (k Keeper) DeleteMiddlewareEnabled(ctx sdk.Context, portID, channelID string) {
func (k Keeper) DeleteMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Delete(icatypes.KeyIsMiddlewareEnabled(portID, channelID))
store.Delete(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,18 @@ func (m Migrator) AssertChannelCapabilityMigrations(ctx sdk.Context) error {
if m.keeper != nil {
for _, ch := range m.keeper.GetAllActiveChannels(ctx) {
name := host.ChannelCapabilityPath(ch.PortId, ch.ChannelId)
cap, found := m.keeper.scopedKeeper.GetCapability(ctx, name)
capacity, found := m.keeper.scopedKeeper.GetCapability(ctx, name)
if !found {
return sdkerrors.Wrapf(capabilitytypes.ErrCapabilityNotFound, "failed to find capability: %s", name)
}

isAuthenticated := m.keeper.scopedKeeper.AuthenticateCapability(ctx, cap, name)
isAuthenticated := m.keeper.scopedKeeper.AuthenticateCapability(ctx, capacity, name)
if !isAuthenticated {
return sdkerrors.Wrapf(capabilitytypes.ErrCapabilityNotOwned, "expected capability owner: %s", types.SubModuleName)
}

m.keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ChannelId)
m.keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ConnectionId)
}
}

return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (suite *KeeperTestSuite) TestAssertChannelCapabilityMigrations() {
isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(
suite.chainA.GetContext(),
path.EndpointA.ChannelConfig.PortID,
path.EndpointA.ChannelID,
path.EndpointA.ConnectionID,
)

suite.Require().True(isMiddlewareEnabled)
Expand Down
4 changes: 2 additions & 2 deletions modules/apps/27-interchain-accounts/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ func KeyPort(portID string) []byte {
}

// KeyIsMiddlewareEnabled creates and returns a new key used for signaling legacy API callback routing via ibc middleware
func KeyIsMiddlewareEnabled(portID, channelID string) []byte {
return []byte(fmt.Sprintf("%s/%s/%s", IsMiddlewareEnabledPrefix, portID, channelID))
func KeyIsMiddlewareEnabled(portID, connectionID string) []byte {
return []byte(fmt.Sprintf("%s/%s/%s", IsMiddlewareEnabledPrefix, portID, connectionID))
}

0 comments on commit 6b7d67f

Please sign in to comment.