diff --git a/PENDING.md b/PENDING.md index 832f25e8f42d..640c4987f3e9 100644 --- a/PENDING.md +++ b/PENDING.md @@ -30,6 +30,7 @@ BREAKING CHANGES * SDK * [core] \#1807 Switch from use of rational to decimal * [types] \#1901 Validator interface's GetOwner() renamed to GetOperator() + * [x/slashing] [#2122](https://github.com/cosmos/cosmos-sdk/pull/2122) - Implement slashing period * [types] \#2119 Parsed error messages and ABCI log errors to make them more human readable. * [simulation] Rename TestAndRunTx to Operation [#2153](https://github.com/cosmos/cosmos-sdk/pull/2153) diff --git a/cmd/gaia/app/app.go b/cmd/gaia/app/app.go index 4ce6b2806de7..f1ca2a7b6630 100644 --- a/cmd/gaia/app/app.go +++ b/cmd/gaia/app/app.go @@ -93,9 +93,10 @@ func NewGaiaApp(logger log.Logger, db dbm.DB, traceStore io.Writer, baseAppOptio app.ibcMapper = ibc.NewMapper(app.cdc, app.keyIBC, app.RegisterCodespace(ibc.DefaultCodespace)) app.paramsKeeper = params.NewKeeper(app.cdc, app.keyParams) app.stakeKeeper = stake.NewKeeper(app.cdc, app.keyStake, app.coinKeeper, app.RegisterCodespace(stake.DefaultCodespace)) + app.slashingKeeper = slashing.NewKeeper(app.cdc, app.keySlashing, app.stakeKeeper, app.paramsKeeper.Getter(), app.RegisterCodespace(slashing.DefaultCodespace)) + app.stakeKeeper = app.stakeKeeper.WithValidatorHooks(app.slashingKeeper.ValidatorHooks()) app.govKeeper = gov.NewKeeper(app.cdc, app.keyGov, app.paramsKeeper.Setter(), app.coinKeeper, app.stakeKeeper, app.RegisterCodespace(gov.DefaultCodespace)) app.feeCollectionKeeper = auth.NewFeeCollectionKeeper(app.cdc, app.keyFeeCollection) - app.slashingKeeper = slashing.NewKeeper(app.cdc, app.keySlashing, app.stakeKeeper, app.paramsKeeper.Getter(), app.RegisterCodespace(slashing.DefaultCodespace)) // register message routes app.Router(). diff --git a/server/export_test.go b/server/export_test.go index 358f72cf60fe..488c55bbf654 100644 --- a/server/export_test.go +++ b/server/export_test.go @@ -1,16 +1,16 @@ package server import ( - "testing" - "github.com/stretchr/testify/require" + "bytes" + "github.com/cosmos/cosmos-sdk/server/mock" "github.com/cosmos/cosmos-sdk/wire" - "github.com/tendermint/tendermint/libs/log" + "github.com/stretchr/testify/require" tcmd "github.com/tendermint/tendermint/cmd/tendermint/commands" - "os" - "bytes" + "github.com/tendermint/tendermint/libs/log" "io" - "github.com/cosmos/cosmos-sdk/server/mock" - ) + "os" + "testing" +) func TestEmptyState(t *testing.T) { defer setupViper(t)() diff --git a/server/mock/app.go b/server/mock/app.go index eb2dfc3cc300..3c6ad3ec2798 100644 --- a/server/mock/app.go +++ b/server/mock/app.go @@ -129,7 +129,7 @@ func AppGenStateEmpty(_ *wire.Codec, _ []json.RawMessage) (appState json.RawMess // Return a validator, not much else func AppGenTx(_ *wire.Codec, pk crypto.PubKey, genTxConfig gc.GenTx) ( - appGenTx, cliPrint json.RawMessage, validator tmtypes.GenesisValidator, err error) { + appGenTx, cliPrint json.RawMessage, validator tmtypes.GenesisValidator, err error) { validator = tmtypes.GenesisValidator{ PubKey: pk, diff --git a/server/tm_cmds.go b/server/tm_cmds.go index b6daf0775320..bf208a5becdf 100644 --- a/server/tm_cmds.go +++ b/server/tm_cmds.go @@ -7,11 +7,11 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/cosmos/cosmos-sdk/client" sdk "github.com/cosmos/cosmos-sdk/types" tcmd "github.com/tendermint/tendermint/cmd/tendermint/commands" "github.com/tendermint/tendermint/p2p" pvm "github.com/tendermint/tendermint/privval" - "github.com/cosmos/cosmos-sdk/client" ) // ShowNodeIDCmd - ported from Tendermint, dump node ID to stdout diff --git a/types/decimal.go b/types/decimal.go index baf2d9573d39..8e7db1340b0e 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -415,6 +415,14 @@ func MinDec(d1, d2 Dec) Dec { return d2 } +// maximum decimal between two +func MaxDec(d1, d2 Dec) Dec { + if d1.LT(d2) { + return d2 + } + return d1 +} + // intended to be used with require/assert: require.True(DecEq(...)) func DecEq(t *testing.T, exp, got Dec) (*testing.T, bool, string, Dec, Dec) { return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp, got diff --git a/types/stake.go b/types/stake.go index f41125177555..24274f3dbd76 100644 --- a/types/stake.go +++ b/types/stake.go @@ -95,3 +95,13 @@ type DelegationSet interface { IterateDelegations(ctx Context, delegator AccAddress, fn func(index int64, delegation Delegation) (stop bool)) } + +// validator event hooks +// These can be utilized to communicate between a staking keeper +// and another keeper which must take particular actions when +// validators are bonded and unbonded. The second keeper must implement +// this interface, which then the staking keeper can call. +type ValidatorHooks interface { + OnValidatorBonded(ctx Context, address ConsAddress) // Must be called when a validator is bonded + OnValidatorBeginUnbonding(ctx Context, address ConsAddress) // Must be called when a validator begins unbonding +} diff --git a/x/mock/simulation/random_simulate_blocks.go b/x/mock/simulation/random_simulate_blocks.go index 995013ef8f21..a99b3a4c97d5 100644 --- a/x/mock/simulation/random_simulate_blocks.go +++ b/x/mock/simulation/random_simulate_blocks.go @@ -67,9 +67,10 @@ func SimulateFromSeed( header := abci.Header{Height: 0, Time: timestamp} opCount := 0 - request := abci.RequestBeginBlock{Header: header} - var pastTimes []time.Time + var pastSigningValidators [][]abci.SigningValidator + + request := RandomRequestBeginBlock(t, r, validators, livenessTransitionMatrix, evidenceFraction, pastTimes, pastSigningValidators, event, header, log) // These are operations which have been queued by previous operations operationQueue := make(map[int][]Operation) @@ -77,6 +78,7 @@ func SimulateFromSeed( // Log the header time for future lookup pastTimes = append(pastTimes, header.Time) + pastSigningValidators = append(pastSigningValidators, request.LastCommitInfo.Validators) // Run the BeginBlock handler app.BeginBlock(request) @@ -131,7 +133,7 @@ func SimulateFromSeed( } // Generate a random RequestBeginBlock with the current validator set for the next block - request = RandomRequestBeginBlock(t, r, validators, livenessTransitionMatrix, evidenceFraction, pastTimes, event, header, log) + request = RandomRequestBeginBlock(t, r, validators, livenessTransitionMatrix, evidenceFraction, pastTimes, pastSigningValidators, event, header, log) // Update the validator set validators = updateValidators(t, r, validators, res.ValidatorUpdates, event) @@ -187,13 +189,12 @@ func getKeys(validators map[string]mockValidator) []string { // RandomRequestBeginBlock generates a list of signing validators according to the provided list of validators, signing fraction, and evidence fraction func RandomRequestBeginBlock(t *testing.T, r *rand.Rand, validators map[string]mockValidator, livenessTransitions TransitionMatrix, evidenceFraction float64, - pastTimes []time.Time, event func(string), header abci.Header, log string) abci.RequestBeginBlock { + pastTimes []time.Time, pastSigningValidators [][]abci.SigningValidator, event func(string), header abci.Header, log string) abci.RequestBeginBlock { if len(validators) == 0 { return abci.RequestBeginBlock{Header: header} } signingValidators := make([]abci.SigningValidator, len(validators)) i := 0 - for _, key := range getKeys(validators) { mVal := validators[key] mVal.livenessState = livenessTransitions.NextState(r, mVal.livenessState) @@ -220,26 +221,31 @@ func RandomRequestBeginBlock(t *testing.T, r *rand.Rand, validators map[string]m i++ } evidence := make([]abci.Evidence, 0) - for r.Float64() < evidenceFraction { - height := header.Height - time := header.Time - if r.Float64() < pastEvidenceFraction { - height = int64(r.Intn(int(header.Height))) - time = pastTimes[height] - } - validator := signingValidators[r.Intn(len(signingValidators))].Validator - var currentTotalVotingPower int64 - for _, mVal := range validators { - currentTotalVotingPower += mVal.val.Power + // Anything but the first block + if len(pastTimes) > 0 { + for r.Float64() < evidenceFraction { + height := header.Height + time := header.Time + vals := signingValidators + if r.Float64() < pastEvidenceFraction { + height = int64(r.Intn(int(header.Height))) + time = pastTimes[height] + vals = pastSigningValidators[height] + } + validator := vals[r.Intn(len(vals))].Validator + var totalVotingPower int64 + for _, val := range vals { + totalVotingPower += val.Validator.Power + } + evidence = append(evidence, abci.Evidence{ + Type: tmtypes.ABCIEvidenceTypeDuplicateVote, + Validator: validator, + Height: height, + Time: time, + TotalVotingPower: totalVotingPower, + }) + event("beginblock/evidence") } - evidence = append(evidence, abci.Evidence{ - Type: tmtypes.ABCIEvidenceTypeDuplicateVote, - Validator: validator, - Height: height, - Time: time, - TotalVotingPower: currentTotalVotingPower, - }) - event("beginblock/evidence") } return abci.RequestBeginBlock{ Header: header, diff --git a/x/slashing/app_test.go b/x/slashing/app_test.go index f9ec0833fa15..0c3270139dce 100644 --- a/x/slashing/app_test.go +++ b/x/slashing/app_test.go @@ -79,7 +79,7 @@ func checkValidator(t *testing.T, mapp *mock.App, keeper stake.Keeper, } func checkValidatorSigningInfo(t *testing.T, mapp *mock.App, keeper Keeper, - addr sdk.ValAddress, expFound bool) ValidatorSigningInfo { + addr sdk.ConsAddress, expFound bool) ValidatorSigningInfo { ctxCheck := mapp.BaseApp.NewContext(true, abci.Header{}) signingInfo, found := keeper.getValidatorSigningInfo(ctxCheck, addr) require.Equal(t, expFound, found) @@ -113,7 +113,7 @@ func TestSlashingMsgs(t *testing.T) { unjailMsg := MsgUnjail{ValidatorAddr: sdk.ValAddress(validator.PubKey.Address())} // no signing info yet - checkValidatorSigningInfo(t, mapp, keeper, sdk.ValAddress(addr1), false) + checkValidatorSigningInfo(t, mapp, keeper, sdk.ConsAddress(addr1), false) // unjail should fail with unknown validator res := mock.SignCheckDeliver(t, mapp.BaseApp, []sdk.Msg{unjailMsg}, []int64{0}, []int64{1}, false, priv1) diff --git a/x/slashing/client/cli/query.go b/x/slashing/client/cli/query.go index 9f6d834dda06..87d0ad41d358 100644 --- a/x/slashing/client/cli/query.go +++ b/x/slashing/client/cli/query.go @@ -25,7 +25,7 @@ func GetCmdQuerySigningInfo(storeName string, cdc *wire.Codec) *cobra.Command { return err } - key := slashing.GetValidatorSigningInfoKey(sdk.ValAddress(pk.Address())) + key := slashing.GetValidatorSigningInfoKey(sdk.ConsAddress(pk.Address())) cliCtx := context.NewCLIContext().WithCodec(cdc) res, err := cliCtx.QueryStore(key, storeName) diff --git a/x/slashing/client/rest/query.go b/x/slashing/client/rest/query.go index 291679375b06..78c4a2d2f859 100644 --- a/x/slashing/client/rest/query.go +++ b/x/slashing/client/rest/query.go @@ -30,7 +30,7 @@ func signingInfoHandlerFn(cliCtx context.CLIContext, storeName string, cdc *wire return } - key := slashing.GetValidatorSigningInfoKey(sdk.ValAddress(pk.Address())) + key := slashing.GetValidatorSigningInfoKey(sdk.ConsAddress(pk.Address())) res, err := cliCtx.QueryStore(key, storeName) if err != nil { diff --git a/x/slashing/handler.go b/x/slashing/handler.go index d79ea73c2d3e..0531b714d1ad 100644 --- a/x/slashing/handler.go +++ b/x/slashing/handler.go @@ -30,7 +30,7 @@ func handleMsgUnjail(ctx sdk.Context, msg MsgUnjail, k Keeper) sdk.Result { return ErrValidatorNotJailed(k.codespace).Result() } - addr := sdk.ValAddress(validator.GetPubKey().Address()) + addr := sdk.ConsAddress(validator.GetPubKey().Address()) // Signing info must exist info, found := k.getValidatorSigningInfo(ctx, addr) diff --git a/x/slashing/handler_test.go b/x/slashing/handler_test.go index 8e3b719f4986..ea1a9ad58a61 100644 --- a/x/slashing/handler_test.go +++ b/x/slashing/handler_test.go @@ -19,7 +19,7 @@ func TestCannotUnjailUnlessJailed(t *testing.T) { got := stake.NewHandler(sk)(ctx, msg) require.True(t, got.IsOK()) stake.EndBlocker(ctx, sk) - require.Equal(t, ck.GetCoins(ctx, addr), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) require.True(t, sdk.NewDecFromInt(amt).Equal(sk.Validator(ctx, sdk.ValAddress(addr)).GetPower())) // assert non-jailed validator can't be unjailed diff --git a/x/slashing/hooks.go b/x/slashing/hooks.go new file mode 100644 index 000000000000..f5f3cc48c3d7 --- /dev/null +++ b/x/slashing/hooks.go @@ -0,0 +1,46 @@ +package slashing + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// Create a new slashing period when a validator is bonded +func (k Keeper) onValidatorBonded(ctx sdk.Context, address sdk.ConsAddress) { + slashingPeriod := ValidatorSlashingPeriod{ + ValidatorAddr: address, + StartHeight: ctx.BlockHeight(), + EndHeight: 0, + SlashedSoFar: sdk.ZeroDec(), + } + k.addOrUpdateValidatorSlashingPeriod(ctx, slashingPeriod) +} + +// Mark the slashing period as having ended when a validator begins unbonding +func (k Keeper) onValidatorBeginUnbonding(ctx sdk.Context, address sdk.ConsAddress) { + slashingPeriod := k.getValidatorSlashingPeriodForHeight(ctx, address, ctx.BlockHeight()) + slashingPeriod.EndHeight = ctx.BlockHeight() + k.addOrUpdateValidatorSlashingPeriod(ctx, slashingPeriod) +} + +// Wrapper struct for sdk.ValidatorHooks +type ValidatorHooks struct { + k Keeper +} + +// Assert implementation +var _ sdk.ValidatorHooks = ValidatorHooks{} + +// Return a sdk.ValidatorHooks interface over the wrapper struct +func (k Keeper) ValidatorHooks() sdk.ValidatorHooks { + return ValidatorHooks{k} +} + +// Implements sdk.ValidatorHooks +func (v ValidatorHooks) OnValidatorBonded(ctx sdk.Context, address sdk.ConsAddress) { + v.k.onValidatorBonded(ctx, address) +} + +// Implements sdk.ValidatorHooks +func (v ValidatorHooks) OnValidatorBeginUnbonding(ctx sdk.Context, address sdk.ConsAddress) { + v.k.onValidatorBeginUnbonding(ctx, address) +} diff --git a/x/slashing/hooks_test.go b/x/slashing/hooks_test.go new file mode 100644 index 000000000000..0731fd8f26fb --- /dev/null +++ b/x/slashing/hooks_test.go @@ -0,0 +1,26 @@ +package slashing + +import ( + "testing" + + "github.com/stretchr/testify/require" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +func TestHookOnValidatorBonded(t *testing.T) { + ctx, _, _, _, keeper := createTestInput(t) + addr := sdk.ConsAddress(addrs[0]) + keeper.onValidatorBonded(ctx, addr) + period := keeper.getValidatorSlashingPeriodForHeight(ctx, addr, ctx.BlockHeight()) + require.Equal(t, ValidatorSlashingPeriod{addr, ctx.BlockHeight(), 0, sdk.ZeroDec()}, period) +} + +func TestHookOnValidatorBeginUnbonding(t *testing.T) { + ctx, _, _, _, keeper := createTestInput(t) + addr := sdk.ConsAddress(addrs[0]) + keeper.onValidatorBonded(ctx, addr) + keeper.onValidatorBeginUnbonding(ctx, addr) + period := keeper.getValidatorSlashingPeriodForHeight(ctx, addr, ctx.BlockHeight()) + require.Equal(t, ValidatorSlashingPeriod{addr, ctx.BlockHeight(), ctx.BlockHeight(), sdk.ZeroDec()}, period) +} diff --git a/x/slashing/keeper.go b/x/slashing/keeper.go index 6d8e47cbe29f..272516585765 100644 --- a/x/slashing/keeper.go +++ b/x/slashing/keeper.go @@ -40,7 +40,7 @@ func (k Keeper) handleDoubleSign(ctx sdk.Context, addr crypto.Address, infractio logger := ctx.Logger().With("module", "x/slashing") time := ctx.BlockHeader().Time age := time.Sub(timestamp) - address := sdk.ValAddress(addr) + address := sdk.ConsAddress(addr) pubkey, err := k.getPubkey(ctx, addr) if err != nil { panic(fmt.Sprintf("Validator address %v not found", addr)) @@ -56,8 +56,14 @@ func (k Keeper) handleDoubleSign(ctx sdk.Context, addr crypto.Address, infractio // Double sign confirmed logger.Info(fmt.Sprintf("Confirmed double sign from %s at height %d, age of %d less than max age of %d", pubkey.Address(), infractionHeight, age, maxEvidenceAge)) + // Cap the amount slashed to the penalty for the worst infraction + // within the slashing period when this infraction was committed + fraction := k.SlashFractionDoubleSign(ctx) + revisedFraction := k.capBySlashingPeriod(ctx, address, fraction, infractionHeight) + logger.Info(fmt.Sprintf("Fraction slashed capped by slashing period from %v to %v", fraction, revisedFraction)) + // Slash validator - k.validatorSet.Slash(ctx, pubkey, infractionHeight, power, k.SlashFractionDoubleSign(ctx)) + k.validatorSet.Slash(ctx, pubkey, infractionHeight, power, revisedFraction) // Jail validator k.validatorSet.Jail(ctx, pubkey) @@ -76,7 +82,7 @@ func (k Keeper) handleDoubleSign(ctx sdk.Context, addr crypto.Address, infractio func (k Keeper) handleValidatorSignature(ctx sdk.Context, addr crypto.Address, power int64, signed bool) { logger := ctx.Logger().With("module", "x/slashing") height := ctx.BlockHeight() - address := sdk.ValAddress(addr) + address := sdk.ConsAddress(addr) pubkey, err := k.getPubkey(ctx, addr) if err != nil { panic(fmt.Sprintf("Validator address %v not found", addr)) @@ -169,7 +175,3 @@ func (k Keeper) deleteAddrPubkeyRelation(ctx sdk.Context, addr crypto.Address) { store := ctx.KVStore(k.storeKey) store.Delete(getAddrPubkeyRelationKey(addr)) } - -func getAddrPubkeyRelationKey(address []byte) []byte { - return append([]byte{0x03}, address...) -} diff --git a/x/slashing/keeper_test.go b/x/slashing/keeper_test.go index 3bdb043a825f..af15bc2b288c 100644 --- a/x/slashing/keeper_test.go +++ b/x/slashing/keeper_test.go @@ -24,13 +24,14 @@ func TestHandleDoubleSign(t *testing.T) { // initial setup ctx, ck, sk, _, keeper := createTestInput(t) + sk = sk.WithValidatorHooks(keeper.ValidatorHooks()) amtInt := int64(100) addr, val, amt := addrs[0], pks[0], sdk.NewInt(amtInt) got := stake.NewHandler(sk)(ctx, newTestMsgCreateValidator(sdk.ValAddress(addr), val, amt)) require.True(t, got.IsOK()) validatorUpdates := stake.EndBlocker(ctx, sk) keeper.AddValidators(ctx, validatorUpdates) - require.Equal(t, ck.GetCoins(ctx, addr), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) require.True(t, sdk.NewDecFromInt(amt).Equal(sk.Validator(ctx, sdk.ValAddress(addr)).GetPower())) // handle a signature to set signing info @@ -58,12 +59,68 @@ func TestHandleDoubleSign(t *testing.T) { ) } +// Test that the amount a validator is slashed for multiple double signs +// is correctly capped by the slashing period in which they were committed +func TestSlashingPeriodCap(t *testing.T) { + + // initial setup + ctx, ck, sk, _, keeper := createTestInput(t) + sk = sk.WithValidatorHooks(keeper.ValidatorHooks()) + amtInt := int64(100) + addr, val, amt := addrs[0], pks[0], sdk.NewInt(amtInt) + got := stake.NewHandler(sk)(ctx, newTestMsgCreateValidator(addr, val, amt)) + require.True(t, got.IsOK()) + validatorUpdates := stake.EndBlocker(ctx, sk) + keeper.AddValidators(ctx, validatorUpdates) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) + require.True(t, sdk.NewDecFromInt(amt).Equal(sk.Validator(ctx, addr).GetPower())) + + // handle a signature to set signing info + keeper.handleValidatorSignature(ctx, val.Address(), amtInt, true) + + // double sign less than max age + keeper.handleDoubleSign(ctx, val.Address(), 0, time.Unix(0, 0), amtInt) + + // should be jailed + require.True(t, sk.Validator(ctx, addr).GetJailed()) + // update block height + ctx = ctx.WithBlockHeight(int64(1)) + // unjail to measure power + sk.Unjail(ctx, val) + // power should be reduced + expectedPower := sdk.NewDecFromInt(amt).Mul(sdk.NewDec(19).Quo(sdk.NewDec(20))) + require.Equal(t, expectedPower, sk.Validator(ctx, addr).GetPower()) + + // double sign again, same slashing period + keeper.handleDoubleSign(ctx, val.Address(), 0, time.Unix(0, 0), amtInt) + // should be jailed + require.True(t, sk.Validator(ctx, addr).GetJailed()) + // update block height + ctx = ctx.WithBlockHeight(int64(2)) + // unjail to measure power + sk.Unjail(ctx, val) + // power should be equal, no more should have been slashed + expectedPower = sdk.NewDecFromInt(amt).Mul(sdk.NewDec(19).Quo(sdk.NewDec(20))) + require.Equal(t, expectedPower, sk.Validator(ctx, addr).GetPower()) + + // double sign again, new slashing period + keeper.handleDoubleSign(ctx, val.Address(), 2, time.Unix(0, 0), amtInt) + // should be jailed + require.True(t, sk.Validator(ctx, addr).GetJailed()) + // unjail to measure power + sk.Unjail(ctx, val) + // power should be reduced + expectedPower = sdk.NewDecFromInt(amt).Mul(sdk.NewDec(18).Quo(sdk.NewDec(20))) + require.Equal(t, expectedPower, sk.Validator(ctx, addr).GetPower()) +} + // Test a validator through uptime, downtime, revocation, // unrevocation, starting height reset, and revocation again func TestHandleAbsentValidator(t *testing.T) { // initial setup ctx, ck, sk, _, keeper := createTestInput(t) + sk = sk.WithValidatorHooks(keeper.ValidatorHooks()) amtInt := int64(100) addr, val, amt := addrs[0], pks[0], sdk.NewInt(amtInt) sh := stake.NewHandler(sk) @@ -72,9 +129,9 @@ func TestHandleAbsentValidator(t *testing.T) { require.True(t, got.IsOK()) validatorUpdates := stake.EndBlocker(ctx, sk) keeper.AddValidators(ctx, validatorUpdates) - require.Equal(t, ck.GetCoins(ctx, addr), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) require.True(t, sdk.NewDecFromInt(amt).Equal(sk.Validator(ctx, sdk.ValAddress(addr)).GetPower())) - info, found := keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found := keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.False(t, found) require.Equal(t, int64(0), info.StartHeight) require.Equal(t, int64(0), info.IndexOffset) @@ -89,7 +146,7 @@ func TestHandleAbsentValidator(t *testing.T) { ctx = ctx.WithBlockHeight(height) keeper.handleValidatorSignature(ctx, val.Address(), amtInt, true) } - info, found = keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found = keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.True(t, found) require.Equal(t, int64(0), info.StartHeight) require.Equal(t, keeper.SignedBlocksWindow(ctx), info.SignedBlocksCounter) @@ -99,7 +156,7 @@ func TestHandleAbsentValidator(t *testing.T) { ctx = ctx.WithBlockHeight(height) keeper.handleValidatorSignature(ctx, val.Address(), amtInt, false) } - info, found = keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found = keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.True(t, found) require.Equal(t, int64(0), info.StartHeight) require.Equal(t, keeper.SignedBlocksWindow(ctx)-keeper.MinSignedPerWindow(ctx), info.SignedBlocksCounter) @@ -113,7 +170,7 @@ func TestHandleAbsentValidator(t *testing.T) { // 501st block missed ctx = ctx.WithBlockHeight(height) keeper.handleValidatorSignature(ctx, val.Address(), amtInt, false) - info, found = keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found = keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.True(t, found) require.Equal(t, int64(0), info.StartHeight) require.Equal(t, keeper.SignedBlocksWindow(ctx)-keeper.MinSignedPerWindow(ctx)-1, info.SignedBlocksCounter) @@ -141,7 +198,7 @@ func TestHandleAbsentValidator(t *testing.T) { require.Equal(t, int64(amtInt)-slashAmt, pool.BondedTokens.RoundInt64()) // validator start height should have been changed - info, found = keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found = keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.True(t, found) require.Equal(t, height, info.StartHeight) require.Equal(t, keeper.SignedBlocksWindow(ctx)-keeper.MinSignedPerWindow(ctx)-1, info.SignedBlocksCounter) @@ -182,7 +239,7 @@ func TestHandleNewValidator(t *testing.T) { require.True(t, got.IsOK()) validatorUpdates := stake.EndBlocker(ctx, sk) keeper.AddValidators(ctx, validatorUpdates) - require.Equal(t, ck.GetCoins(ctx, addr), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.SubRaw(amt)}}) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.SubRaw(amt)}}) require.Equal(t, sdk.NewDec(amt), sk.Validator(ctx, sdk.ValAddress(addr)).GetPower()) // 1000 first blocks not a validator @@ -193,7 +250,7 @@ func TestHandleNewValidator(t *testing.T) { ctx = ctx.WithBlockHeight(keeper.SignedBlocksWindow(ctx) + 2) keeper.handleValidatorSignature(ctx, val.Address(), 100, false) - info, found := keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(val.Address())) + info, found := keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address())) require.True(t, found) require.Equal(t, int64(keeper.SignedBlocksWindow(ctx)+1), info.StartHeight) require.Equal(t, int64(2), info.IndexOffset) diff --git a/x/slashing/keys.go b/x/slashing/keys.go new file mode 100644 index 000000000000..2af9e069a13c --- /dev/null +++ b/x/slashing/keys.go @@ -0,0 +1,43 @@ +package slashing + +import ( + "encoding/binary" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// key prefix bytes +var ( + ValidatorSigningInfoKey = []byte{0x01} // Prefix for signing info + ValidatorSigningBitArrayKey = []byte{0x02} // Prefix for signature bit array + ValidatorSlashingPeriodKey = []byte{0x03} // Prefix for slashing period + AddrPubkeyRelationKey = []byte{0x04} // Prefix for address-pubkey relation +) + +// stored by *Tendermint* address (not owner address) +func GetValidatorSigningInfoKey(v sdk.ConsAddress) []byte { + return append(ValidatorSigningInfoKey, v.Bytes()...) +} + +// stored by *Tendermint* address (not owner address) +func GetValidatorSigningBitArrayKey(v sdk.ConsAddress, i int64) []byte { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(i)) + return append(ValidatorSigningBitArrayKey, append(v.Bytes(), b...)...) +} + +// stored by *Tendermint* address (not owner address) +func GetValidatorSlashingPeriodPrefix(v sdk.ConsAddress) []byte { + return append(ValidatorSlashingPeriodKey, v.Bytes()...) +} + +// stored by *Tendermint* address (not owner address) followed by start height +func GetValidatorSlashingPeriodKey(v sdk.ConsAddress, startHeight int64) []byte { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(startHeight)) + return append(GetValidatorSlashingPeriodPrefix(v), b...) +} + +func getAddrPubkeyRelationKey(address []byte) []byte { + return append(AddrPubkeyRelationKey, address...) +} diff --git a/x/slashing/signing_info.go b/x/slashing/signing_info.go index 25a83e833d31..e76fea53f2d1 100644 --- a/x/slashing/signing_info.go +++ b/x/slashing/signing_info.go @@ -1,7 +1,6 @@ package slashing import ( - "encoding/binary" "fmt" "time" @@ -9,7 +8,7 @@ import ( ) // Stored by *validator* address (not owner address) -func (k Keeper) getValidatorSigningInfo(ctx sdk.Context, address sdk.ValAddress) (info ValidatorSigningInfo, found bool) { +func (k Keeper) getValidatorSigningInfo(ctx sdk.Context, address sdk.ConsAddress) (info ValidatorSigningInfo, found bool) { store := ctx.KVStore(k.storeKey) bz := store.Get(GetValidatorSigningInfoKey(address)) if bz == nil { @@ -22,14 +21,14 @@ func (k Keeper) getValidatorSigningInfo(ctx sdk.Context, address sdk.ValAddress) } // Stored by *validator* address (not owner address) -func (k Keeper) setValidatorSigningInfo(ctx sdk.Context, address sdk.ValAddress, info ValidatorSigningInfo) { +func (k Keeper) setValidatorSigningInfo(ctx sdk.Context, address sdk.ConsAddress, info ValidatorSigningInfo) { store := ctx.KVStore(k.storeKey) bz := k.cdc.MustMarshalBinary(info) store.Set(GetValidatorSigningInfoKey(address), bz) } // Stored by *validator* address (not owner address) -func (k Keeper) getValidatorSigningBitArray(ctx sdk.Context, address sdk.ValAddress, index int64) (signed bool) { +func (k Keeper) getValidatorSigningBitArray(ctx sdk.Context, address sdk.ConsAddress, index int64) (signed bool) { store := ctx.KVStore(k.storeKey) bz := store.Get(GetValidatorSigningBitArrayKey(address, index)) if bz == nil { @@ -42,7 +41,7 @@ func (k Keeper) getValidatorSigningBitArray(ctx sdk.Context, address sdk.ValAddr } // Stored by *validator* address (not owner address) -func (k Keeper) setValidatorSigningBitArray(ctx sdk.Context, address sdk.ValAddress, index int64, signed bool) { +func (k Keeper) setValidatorSigningBitArray(ctx sdk.Context, address sdk.ConsAddress, index int64, signed bool) { store := ctx.KVStore(k.storeKey) bz := k.cdc.MustMarshalBinary(signed) store.Set(GetValidatorSigningBitArrayKey(address, index), bz) @@ -71,15 +70,3 @@ func (i ValidatorSigningInfo) HumanReadableString() string { return fmt.Sprintf("Start height: %d, index offset: %d, jailed until: %v, signed blocks counter: %d", i.StartHeight, i.IndexOffset, i.JailedUntil, i.SignedBlocksCounter) } - -// Stored by *validator* address (not owner address) -func GetValidatorSigningInfoKey(v sdk.ValAddress) []byte { - return append([]byte{0x01}, v.Bytes()...) -} - -// Stored by *validator* address (not owner address) -func GetValidatorSigningBitArrayKey(v sdk.ValAddress, i int64) []byte { - b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) - return append([]byte{0x02}, append(v.Bytes(), b...)...) -} diff --git a/x/slashing/signing_info_test.go b/x/slashing/signing_info_test.go index f92c43581b5a..7aff0da95f1b 100644 --- a/x/slashing/signing_info_test.go +++ b/x/slashing/signing_info_test.go @@ -11,7 +11,7 @@ import ( func TestGetSetValidatorSigningInfo(t *testing.T) { ctx, _, _, _, keeper := createTestInput(t) - info, found := keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(addrs[0])) + info, found := keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(addrs[0])) require.False(t, found) newInfo := ValidatorSigningInfo{ StartHeight: int64(4), @@ -19,8 +19,8 @@ func TestGetSetValidatorSigningInfo(t *testing.T) { JailedUntil: time.Unix(2, 0), SignedBlocksCounter: int64(10), } - keeper.setValidatorSigningInfo(ctx, sdk.ValAddress(addrs[0]), newInfo) - info, found = keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(addrs[0])) + keeper.setValidatorSigningInfo(ctx, sdk.ConsAddress(addrs[0]), newInfo) + info, found = keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(addrs[0])) require.True(t, found) require.Equal(t, info.StartHeight, int64(4)) require.Equal(t, info.IndexOffset, int64(3)) @@ -30,9 +30,9 @@ func TestGetSetValidatorSigningInfo(t *testing.T) { func TestGetSetValidatorSigningBitArray(t *testing.T) { ctx, _, _, _, keeper := createTestInput(t) - signed := keeper.getValidatorSigningBitArray(ctx, sdk.ValAddress(addrs[0]), 0) + signed := keeper.getValidatorSigningBitArray(ctx, sdk.ConsAddress(addrs[0]), 0) require.False(t, signed) // treat empty key as unsigned - keeper.setValidatorSigningBitArray(ctx, sdk.ValAddress(addrs[0]), 0, true) - signed = keeper.getValidatorSigningBitArray(ctx, sdk.ValAddress(addrs[0]), 0) + keeper.setValidatorSigningBitArray(ctx, sdk.ConsAddress(addrs[0]), 0, true) + signed = keeper.getValidatorSigningBitArray(ctx, sdk.ConsAddress(addrs[0]), 0) require.True(t, signed) // now should be signed } diff --git a/x/slashing/slashing_period.go b/x/slashing/slashing_period.go new file mode 100644 index 000000000000..61d25071eb80 --- /dev/null +++ b/x/slashing/slashing_period.go @@ -0,0 +1,107 @@ +package slashing + +import ( + "encoding/binary" + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// Cap an infraction's slash amount by the slashing period in which it was committed +func (k Keeper) capBySlashingPeriod(ctx sdk.Context, address sdk.ConsAddress, fraction sdk.Dec, infractionHeight int64) (revisedFraction sdk.Dec) { + + // Fetch the newest slashing period starting before this infraction was committed + slashingPeriod := k.getValidatorSlashingPeriodForHeight(ctx, address, infractionHeight) + + // Sanity check + if slashingPeriod.EndHeight > 0 && slashingPeriod.EndHeight < infractionHeight { + panic(fmt.Sprintf("slashing period ended before infraction: infraction height %d, slashing period ended at %d", infractionHeight, slashingPeriod.EndHeight)) + } + + // Calculate the updated total slash amount + // This is capped at the slashing fraction for the worst infraction within this slashing period + totalToSlash := sdk.MaxDec(slashingPeriod.SlashedSoFar, fraction) + + // Calculate the remainder which we now must slash + revisedFraction = totalToSlash.Sub(slashingPeriod.SlashedSoFar) + + // Update the slashing period struct + slashingPeriod.SlashedSoFar = totalToSlash + k.addOrUpdateValidatorSlashingPeriod(ctx, slashingPeriod) + + return +} + +// Stored by validator Tendermint address (not owner address) +// This function retrieves the most recent slashing period starting +// before a particular height - so the slashing period that was "in effect" +// at the time of an infraction committed at that height. +func (k Keeper) getValidatorSlashingPeriodForHeight(ctx sdk.Context, address sdk.ConsAddress, height int64) (slashingPeriod ValidatorSlashingPeriod) { + store := ctx.KVStore(k.storeKey) + // Get the most recent slashing period at or before the infraction height + start := GetValidatorSlashingPeriodPrefix(address) + end := sdk.PrefixEndBytes(GetValidatorSlashingPeriodKey(address, height)) + iterator := store.ReverseIterator(start, end) + if !iterator.Valid() { + panic("expected to find slashing period, but none was found") + } + slashingPeriod = k.unmarshalSlashingPeriodKeyValue(iterator.Key(), iterator.Value()) + return +} + +// Stored by validator Tendermint address (not owner address) +// This function sets a validator slashing period for a particular validator, +// start height, end height, and current slashed-so-far total, or updates +// an existing slashing period for the same validator and start height. +func (k Keeper) addOrUpdateValidatorSlashingPeriod(ctx sdk.Context, slashingPeriod ValidatorSlashingPeriod) { + slashingPeriodValue := ValidatorSlashingPeriodValue{ + EndHeight: slashingPeriod.EndHeight, + SlashedSoFar: slashingPeriod.SlashedSoFar, + } + store := ctx.KVStore(k.storeKey) + bz := k.cdc.MustMarshalBinary(slashingPeriodValue) + store.Set(GetValidatorSlashingPeriodKey(slashingPeriod.ValidatorAddr, slashingPeriod.StartHeight), bz) +} + +// Unmarshal key/value into a ValidatorSlashingPeriod +func (k Keeper) unmarshalSlashingPeriodKeyValue(key []byte, value []byte) ValidatorSlashingPeriod { + var slashingPeriodValue ValidatorSlashingPeriodValue + k.cdc.MustUnmarshalBinary(value, &slashingPeriodValue) + address := sdk.ConsAddress(key[1 : 1+sdk.AddrLen]) + startHeight := int64(binary.LittleEndian.Uint64(key[1+sdk.AddrLen : 1+sdk.AddrLen+8])) + return ValidatorSlashingPeriod{ + ValidatorAddr: address, + StartHeight: startHeight, + EndHeight: slashingPeriodValue.EndHeight, + SlashedSoFar: slashingPeriodValue.SlashedSoFar, + } +} + +// Construct a new `ValidatorSlashingPeriod` struct +func NewValidatorSlashingPeriod(startHeight int64, endHeight int64, slashedSoFar sdk.Dec) ValidatorSlashingPeriod { + return ValidatorSlashingPeriod{ + StartHeight: startHeight, + EndHeight: endHeight, + SlashedSoFar: slashedSoFar, + } +} + +// Slashing period for a validator +type ValidatorSlashingPeriod struct { + ValidatorAddr sdk.ConsAddress `json:"validator_addr"` // validator which this slashing period is for + StartHeight int64 `json:"start_height"` // starting height of the slashing period + EndHeight int64 `json:"end_height"` // ending height of the slashing period, or sentinel value of 0 for in-progress + SlashedSoFar sdk.Dec `json:"slashed_so_far"` // fraction of validator stake slashed so far in this slashing period +} + +// Value part of slashing period (validator address & start height are stored in the key) +type ValidatorSlashingPeriodValue struct { + EndHeight int64 `json:"end_height"` + SlashedSoFar sdk.Dec `json:"slashed_so_far"` +} + +// Return human readable slashing period +func (p ValidatorSlashingPeriod) HumanReadableString() string { + return fmt.Sprintf("Start height: %d, end height: %d, slashed so far: %v", + p.StartHeight, p.EndHeight, p.SlashedSoFar) +} diff --git a/x/slashing/slashing_period_test.go b/x/slashing/slashing_period_test.go new file mode 100644 index 000000000000..54157bb9cc0e --- /dev/null +++ b/x/slashing/slashing_period_test.go @@ -0,0 +1,86 @@ +package slashing + +import ( + "testing" + + "github.com/stretchr/testify/require" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +func TestGetSetValidatorSlashingPeriod(t *testing.T) { + ctx, _, _, _, keeper := createTestInput(t) + addr := sdk.ConsAddress(addrs[0]) + height := int64(5) + require.Panics(t, func() { keeper.getValidatorSlashingPeriodForHeight(ctx, addr, height) }) + newPeriod := ValidatorSlashingPeriod{ + ValidatorAddr: addr, + StartHeight: height, + EndHeight: height + 10, + SlashedSoFar: sdk.ZeroDec(), + } + keeper.addOrUpdateValidatorSlashingPeriod(ctx, newPeriod) + + // Get at start height + retrieved := keeper.getValidatorSlashingPeriodForHeight(ctx, addr, height) + require.Equal(t, newPeriod, retrieved) + + // Get after start height (works) + retrieved = keeper.getValidatorSlashingPeriodForHeight(ctx, addr, int64(6)) + require.Equal(t, newPeriod, retrieved) + + // Get before start height (panic) + require.Panics(t, func() { keeper.getValidatorSlashingPeriodForHeight(ctx, addr, int64(0)) }) + + // Get after end height (panic) + newPeriod.EndHeight = int64(4) + keeper.addOrUpdateValidatorSlashingPeriod(ctx, newPeriod) + require.Panics(t, func() { keeper.capBySlashingPeriod(ctx, addr, sdk.ZeroDec(), height) }) + + // Back to old end height + newPeriod.EndHeight = height + 10 + keeper.addOrUpdateValidatorSlashingPeriod(ctx, newPeriod) + + // Set a new, later period + anotherPeriod := ValidatorSlashingPeriod{ + ValidatorAddr: addr, + StartHeight: height + 1, + EndHeight: height + 11, + SlashedSoFar: sdk.ZeroDec(), + } + keeper.addOrUpdateValidatorSlashingPeriod(ctx, anotherPeriod) + + // Old period retrieved for prior height + retrieved = keeper.getValidatorSlashingPeriodForHeight(ctx, addr, height) + require.Equal(t, newPeriod, retrieved) + + // New period retrieved at new height + retrieved = keeper.getValidatorSlashingPeriodForHeight(ctx, addr, height+1) + require.Equal(t, anotherPeriod, retrieved) +} + +func TestValidatorSlashingPeriodCap(t *testing.T) { + ctx, _, _, _, keeper := createTestInput(t) + addr := sdk.ConsAddress(addrs[0]) + height := int64(5) + newPeriod := ValidatorSlashingPeriod{ + ValidatorAddr: addr, + StartHeight: height, + EndHeight: height + 10, + SlashedSoFar: sdk.ZeroDec(), + } + keeper.addOrUpdateValidatorSlashingPeriod(ctx, newPeriod) + half := sdk.NewDec(1).Quo(sdk.NewDec(2)) + + // First slash should be full + fractionA := keeper.capBySlashingPeriod(ctx, addr, half, height) + require.True(t, fractionA.Equal(half)) + + // Second slash should be capped + fractionB := keeper.capBySlashingPeriod(ctx, addr, half, height) + require.True(t, fractionB.Equal(sdk.ZeroDec())) + + // Third slash should be capped to difference + fractionC := keeper.capBySlashingPeriod(ctx, addr, sdk.OneDec(), height) + require.True(t, fractionC.Equal(half)) +} diff --git a/x/slashing/test_common.go b/x/slashing/test_common.go index 1053823786c2..82af340f76ca 100644 --- a/x/slashing/test_common.go +++ b/x/slashing/test_common.go @@ -30,10 +30,10 @@ var ( newPubKey("0B485CFC0EECC619440448436F8FC9DF40566F2369E72400281454CB552AFB51"), newPubKey("0B485CFC0EECC619440448436F8FC9DF40566F2369E72400281454CB552AFB52"), } - addrs = []sdk.AccAddress{ - sdk.AccAddress(pks[0].Address()), - sdk.AccAddress(pks[1].Address()), - sdk.AccAddress(pks[2].Address()), + addrs = []sdk.ValAddress{ + sdk.ValAddress(pks[0].Address()), + sdk.ValAddress(pks[1].Address()), + sdk.ValAddress(pks[2].Address()), } initCoins = sdk.NewInt(200) ) @@ -75,7 +75,7 @@ func createTestInput(t *testing.T) (sdk.Context, bank.Keeper, stake.Keeper, para require.Nil(t, err) for _, addr := range addrs { - _, _, err = ck.AddCoins(ctx, addr, sdk.Coins{ + _, _, err = ck.AddCoins(ctx, sdk.AccAddress(addr), sdk.Coins{ {sk.GetParams(ctx).BondDenom, initCoins}, }) } diff --git a/x/slashing/tick_test.go b/x/slashing/tick_test.go index 9eb956e6710a..40705cf515b1 100644 --- a/x/slashing/tick_test.go +++ b/x/slashing/tick_test.go @@ -21,7 +21,7 @@ func TestBeginBlocker(t *testing.T) { require.True(t, got.IsOK()) validatorUpdates := stake.EndBlocker(ctx, sk) keeper.AddValidators(ctx, validatorUpdates) - require.Equal(t, ck.GetCoins(ctx, addr), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) + require.Equal(t, ck.GetCoins(ctx, sdk.AccAddress(addr)), sdk.Coins{{sk.GetParams(ctx).BondDenom, initCoins.Sub(amt)}}) require.True(t, sdk.NewDecFromInt(amt).Equal(sk.Validator(ctx, sdk.ValAddress(addr)).GetPower())) val := abci.Validator{ @@ -40,7 +40,7 @@ func TestBeginBlocker(t *testing.T) { } BeginBlocker(ctx, req, keeper) - info, found := keeper.getValidatorSigningInfo(ctx, sdk.ValAddress(pk.Address())) + info, found := keeper.getValidatorSigningInfo(ctx, sdk.ConsAddress(pk.Address())) require.True(t, found) require.Equal(t, ctx.BlockHeight(), info.StartHeight) require.Equal(t, int64(1), info.IndexOffset) diff --git a/x/stake/keeper/keeper.go b/x/stake/keeper/keeper.go index 187649c5f766..14c834387898 100644 --- a/x/stake/keeper/keeper.go +++ b/x/stake/keeper/keeper.go @@ -10,9 +10,10 @@ import ( // keeper of the stake store type Keeper struct { - storeKey sdk.StoreKey - cdc *wire.Codec - coinKeeper bank.Keeper + storeKey sdk.StoreKey + cdc *wire.Codec + coinKeeper bank.Keeper + validatorHooks sdk.ValidatorHooks // codespace codespace sdk.CodespaceType @@ -20,14 +21,24 @@ type Keeper struct { func NewKeeper(cdc *wire.Codec, key sdk.StoreKey, ck bank.Keeper, codespace sdk.CodespaceType) Keeper { keeper := Keeper{ - storeKey: key, - cdc: cdc, - coinKeeper: ck, - codespace: codespace, + storeKey: key, + cdc: cdc, + coinKeeper: ck, + validatorHooks: nil, + codespace: codespace, } return keeper } +// Set the validator hooks +func (k Keeper) WithValidatorHooks(v sdk.ValidatorHooks) Keeper { + if k.validatorHooks != nil { + panic("cannot set validator hooks twice") + } + k.validatorHooks = v + return k +} + //_________________________________________________________________________ // return the codespace diff --git a/x/stake/keeper/validator.go b/x/stake/keeper/validator.go index cb225df6cc77..0ea24e639672 100644 --- a/x/stake/keeper/validator.go +++ b/x/stake/keeper/validator.go @@ -591,6 +591,13 @@ func (k Keeper) unbondValidator(ctx sdk.Context, validator types.Validator) type // also remove from the Bonded types.Validators Store store.Delete(GetValidatorsBondedIndexKey(validator.Operator)) + + // call the unbond hook if present + if k.validatorHooks != nil { + k.validatorHooks.OnValidatorBeginUnbonding(ctx, validator.ConsAddress()) + } + + // return updated validator return validator } @@ -617,6 +624,12 @@ func (k Keeper) bondValidator(ctx sdk.Context, validator types.Validator) types. bzABCI := k.cdc.MustMarshalBinary(validator.ABCIValidator()) store.Set(GetTendermintUpdatesKey(validator.Operator), bzABCI) + // call the bond hook if present + if k.validatorHooks != nil { + k.validatorHooks.OnValidatorBonded(ctx, validator.ConsAddress()) + } + + // return updated validator return validator } diff --git a/x/stake/types/validator.go b/x/stake/types/validator.go index 2cb952db2859..6a53965ca859 100644 --- a/x/stake/types/validator.go +++ b/x/stake/types/validator.go @@ -246,6 +246,11 @@ func (v Validator) Equal(c2 Validator) bool { v.LastBondedTokens.Equal(c2.LastBondedTokens) } +// return the TM validator address +func (v Validator) ConsAddress() sdk.ConsAddress { + return sdk.ConsAddress(v.PubKey.Address()) +} + // constant used in flags to indicate that description field should not be updated const DoNotModifyDesc = "[do-not-modify]"