diff --git a/x/capability/alias.go b/x/capability/alias.go index 3aa949677e50..7bf21b034e98 100644 --- a/x/capability/alias.go +++ b/x/capability/alias.go @@ -24,6 +24,7 @@ var ( KeyPrefixIndexCapability = types.KeyPrefixIndexCapability ErrCapabilityTaken = types.ErrCapabilityTaken ErrOwnerClaimed = types.ErrOwnerClaimed + ErrCapabilityNotOwned = types.ErrCapabilityNotOwned RegisterCodec = types.RegisterCodec RegisterCapabilityTypeCodec = types.RegisterCapabilityTypeCodec ModuleCdc = types.ModuleCdc diff --git a/x/capability/keeper/keeper.go b/x/capability/keeper/keeper.go index 03bd897a3832..acb33b3912cf 100644 --- a/x/capability/keeper/keeper.go +++ b/x/capability/keeper/keeper.go @@ -202,6 +202,45 @@ func (sk ScopedKeeper) ClaimCapability(ctx sdk.Context, cap types.Capability, na return nil } +// ReleaseCapability allows a scoped module to release a capability which it had +// previously claimed or created. After releasing the capability, if no more +// owners exist, the capability will be globally removed. +func (sk ScopedKeeper) ReleaseCapability(ctx sdk.Context, cap types.Capability) error { + memStore := ctx.KVStore(sk.memKey) + + bz := memStore.Get(types.FwdCapabilityKey(sk.module, cap)) + if len(bz) == 0 { + return sdkerrors.Wrap(types.ErrCapabilityNotOwned, sk.module) + } + + name := string(bz) + + // Remove the forward mapping between the module and capability tuple and the + // capability name in the in-memory store. + memStore.Delete(types.FwdCapabilityKey(sk.module, cap)) + + // Remove the reverse mapping between the module and capability name and the + // capability in the in-memory store. + memStore.Delete(types.RevCapabilityKey(sk.module, name)) + + // remove owner + capOwners := sk.getOwners(ctx, cap) + capOwners.Remove(types.NewOwner(sk.module, name)) + + prefixStore := prefix.NewStore(ctx.KVStore(sk.storeKey), types.KeyPrefixIndexCapability) + indexKey := types.IndexToKey(cap.GetIndex()) + + if len(capOwners.Owners) == 0 { + // remove capability owner set + prefixStore.Delete(indexKey) + } else { + // update capability owner set + prefixStore.Set(indexKey, sk.cdc.MustMarshalBinaryBare(capOwners)) + } + + return nil +} + // GetCapability allows a module to fetch a capability which it previously claimed // by name. The module is not allowed to retrieve capabilities which it does not // own. @@ -223,14 +262,7 @@ func (sk ScopedKeeper) addOwner(ctx sdk.Context, cap types.Capability, name stri prefixStore := prefix.NewStore(ctx.KVStore(sk.storeKey), types.KeyPrefixIndexCapability) indexKey := types.IndexToKey(cap.GetIndex()) - var capOwners *types.CapabilityOwners - - bz := prefixStore.Get(indexKey) - if len(bz) == 0 { - capOwners = types.NewCapabilityOwners() - } else { - sk.cdc.MustUnmarshalBinaryBare(bz, &capOwners) - } + capOwners := sk.getOwners(ctx, cap) if err := capOwners.Set(types.NewOwner(sk.module, name)); err != nil { return err @@ -241,6 +273,20 @@ func (sk ScopedKeeper) addOwner(ctx sdk.Context, cap types.Capability, name stri return nil } +func (sk ScopedKeeper) getOwners(ctx sdk.Context, cap types.Capability) (capOwners *types.CapabilityOwners) { + prefixStore := prefix.NewStore(ctx.KVStore(sk.storeKey), types.KeyPrefixIndexCapability) + indexKey := types.IndexToKey(cap.GetIndex()) + + bz := prefixStore.Get(indexKey) + if len(bz) == 0 { + capOwners = types.NewCapabilityOwners() + } else { + sk.cdc.MustUnmarshalBinaryBare(bz, &capOwners) + } + + return capOwners +} + func logger(ctx sdk.Context) log.Logger { return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName)) } diff --git a/x/capability/keeper/keeper_test.go b/x/capability/keeper/keeper_test.go index dc52d9b96d5e..7c8df90b51bc 100644 --- a/x/capability/keeper/keeper_test.go +++ b/x/capability/keeper/keeper_test.go @@ -115,6 +115,9 @@ func (suite *KeeperTestSuite) TestAuthenticateCapability() { suite.Require().False(sk2.AuthenticateCapability(suite.ctx, cap2, "invalid")) suite.Require().False(sk2.AuthenticateCapability(suite.ctx, cap1, "bond")) + sk2.ReleaseCapability(suite.ctx, cap2) + suite.Require().False(sk2.AuthenticateCapability(suite.ctx, cap2, "bond")) + badCap := types.NewCapabilityKey(100) suite.Require().False(sk1.AuthenticateCapability(suite.ctx, badCap, "transfer")) suite.Require().False(sk2.AuthenticateCapability(suite.ctx, badCap, "bond")) @@ -140,6 +143,33 @@ func (suite *KeeperTestSuite) TestClaimCapability() { suite.Require().Equal(cap, got) } +func (suite *KeeperTestSuite) TestReleaseCapability() { + sk1 := suite.keeper.ScopeToModule(bank.ModuleName) + sk2 := suite.keeper.ScopeToModule(staking.ModuleName) + + cap1, err := sk1.NewCapability(suite.ctx, "transfer") + suite.Require().NoError(err) + suite.Require().NotNil(cap1) + + suite.Require().NoError(sk2.ClaimCapability(suite.ctx, cap1, "transfer")) + + cap2, err := sk2.NewCapability(suite.ctx, "bond") + suite.Require().NoError(err) + suite.Require().NotNil(cap2) + + suite.Require().Error(sk1.ReleaseCapability(suite.ctx, cap2)) + + suite.Require().NoError(sk2.ReleaseCapability(suite.ctx, cap1)) + got, ok := sk2.GetCapability(suite.ctx, "transfer") + suite.Require().False(ok) + suite.Require().Nil(got) + + suite.Require().NoError(sk1.ReleaseCapability(suite.ctx, cap1)) + got, ok = sk1.GetCapability(suite.ctx, "transfer") + suite.Require().False(ok) + suite.Require().Nil(got) +} + func TestKeeperTestSuite(t *testing.T) { suite.Run(t, new(KeeperTestSuite)) } diff --git a/x/capability/types/errors.go b/x/capability/types/errors.go index fbf718e9fecc..bc3354789ca1 100644 --- a/x/capability/types/errors.go +++ b/x/capability/types/errors.go @@ -8,6 +8,7 @@ import ( // x/capability module sentinel errors var ( - ErrCapabilityTaken = sdkerrors.Register(ModuleName, 2, "capability name already taken") - ErrOwnerClaimed = sdkerrors.Register(ModuleName, 3, "given owner already claimed capability") + ErrCapabilityTaken = sdkerrors.Register(ModuleName, 2, "capability name already taken") + ErrOwnerClaimed = sdkerrors.Register(ModuleName, 3, "given owner already claimed capability") + ErrCapabilityNotOwned = sdkerrors.Register(ModuleName, 4, "capability not owned by module") ) diff --git a/x/capability/types/types.go b/x/capability/types/types.go index afa5b79b4e93..484abbff1dd0 100644 --- a/x/capability/types/types.go +++ b/x/capability/types/types.go @@ -77,9 +77,8 @@ func NewCapabilityOwners() *CapabilityOwners { // already exists, an error will be returned. Set runs in O(log n) average time // and O(n) in the worst case. func (co *CapabilityOwners) Set(owner Owner) error { - // find smallest index s.t. co.Owners[i] >= owner in O(log n) time - i := sort.Search(len(co.Owners), func(i int) bool { return co.Owners[i].Key() >= owner.Key() }) - if i < len(co.Owners) && co.Owners[i].Key() == owner.Key() { + i, ok := co.Get(owner) + if ok { // owner already exists at co.Owners[i] return sdkerrors.Wrapf(ErrOwnerClaimed, owner.String()) } @@ -91,3 +90,31 @@ func (co *CapabilityOwners) Set(owner Owner) error { return nil } + +// Remove removes a provided owner from the CapabilityOwners if it exists. If the +// owner does not exist, Remove is considered a no-op. +func (co *CapabilityOwners) Remove(owner Owner) { + if len(co.Owners) == 0 { + return + } + + i, ok := co.Get(owner) + if ok { + // owner exists at co.Owners[i] + co.Owners = append(co.Owners[:i], co.Owners[i+1:]...) + } +} + +// Get returns (i, true) of the provided owner in the CapabilityOwners if the +// owner exists, where i indicates the owner's index in the set. Otherwise +// (i, false) where i indicates where in the set the owner should be added. +func (co *CapabilityOwners) Get(owner Owner) (int, bool) { + // find smallest index s.t. co.Owners[i] >= owner in O(log n) time + i := sort.Search(len(co.Owners), func(i int) bool { return co.Owners[i].Key() >= owner.Key() }) + if i < len(co.Owners) && co.Owners[i].Key() == owner.Key() { + // owner exists at co.Owners[i] + return i, true + } + + return i, false +} diff --git a/x/capability/types/types_test.go b/x/capability/types/types_test.go index 0c1ba6e3d3bc..86a2a28f1618 100644 --- a/x/capability/types/types_test.go +++ b/x/capability/types/types_test.go @@ -23,7 +23,7 @@ func TestOwner(t *testing.T) { require.Equal(t, "module: bank\nname: send\n", o.String()) } -func TestCapabilityOwners(t *testing.T) { +func TestCapabilityOwners_Set(t *testing.T) { co := types.NewCapabilityOwners() owners := make([]types.Owner, 1024) @@ -47,3 +47,23 @@ func TestCapabilityOwners(t *testing.T) { require.Error(t, co.Set(owner)) } } + +func TestCapabilityOwners_Remove(t *testing.T) { + co := types.NewCapabilityOwners() + + co.Remove(types.NewOwner("bank", "send-0")) + require.Len(t, co.Owners, 0) + + for i := 0; i < 5; i++ { + require.NoError(t, co.Set(types.NewOwner("bank", fmt.Sprintf("send-%d", i)))) + } + + require.Len(t, co.Owners, 5) + + for i := 0; i < 5; i++ { + co.Remove(types.NewOwner("bank", fmt.Sprintf("send-%d", i))) + require.Len(t, co.Owners, 5-(i+1)) + } + + require.Len(t, co.Owners, 0) +}