Skip to content

Commit

Permalink
New validation for keeper fields (#740)
Browse files Browse the repository at this point in the history
* wip

* add len check

* use reflect's NotZero()

* panics, not tests

* rm old

* fix tests

* Update keeper.go
  • Loading branch information
shaspitz authored Feb 21, 2023
1 parent 558516f commit afd1b2b
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 6 deletions.
5 changes: 3 additions & 2 deletions testutil/keeper/unit_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"time"

authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
abci "github.com/tendermint/tendermint/abci/types"

"github.com/cosmos/cosmos-sdk/codec"
Expand Down Expand Up @@ -122,7 +123,7 @@ func NewInMemProviderKeeper(params InMemKeeperParams, mocks MockedKeepers) provi
mocks.MockSlashingKeeper,
mocks.MockAccountKeeper,
mocks.MockEvidenceKeeper,
"",
authtypes.FeeCollectorName,
)
}

Expand All @@ -142,7 +143,7 @@ func NewInMemConsumerKeeper(params InMemKeeperParams, mocks MockedKeepers) consu
mocks.MockAccountKeeper,
mocks.MockIBCTransferKeeper,
mocks.MockIBCCoreKeeper,
"",
authtypes.FeeCollectorName,
)
}

Expand Down
2 changes: 1 addition & 1 deletion x/ccv/consumer/keeper/distribution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestGetEstimatedNextFeeDistribution(t *testing.T) {

// Setup mock calls
gomock.InOrder(
mockAccountKeeper.EXPECT().GetModuleAccount(ctx, "").
mockAccountKeeper.EXPECT().GetModuleAccount(ctx, authTypes.FeeCollectorName).
Return(mAcc).
Times(1),
mockBankKeeper.EXPECT().GetAllBalances(ctx, mAcc.GetAddress()).
Expand Down
61 changes: 60 additions & 1 deletion x/ccv/consumer/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keeper
import (
"encoding/binary"
"fmt"
"reflect"
"time"

"github.com/cosmos/cosmos-sdk/codec"
Expand Down Expand Up @@ -58,7 +59,7 @@ func NewKeeper(
paramSpace = paramSpace.WithKeyTable(types.ParamKeyTable())
}

return Keeper{
k := Keeper{
storeKey: key,
cdc: cdc,
paramStore: paramSpace,
Expand All @@ -74,6 +75,64 @@ func NewKeeper(
ibcCoreKeeper: ibcCoreKeeper,
feeCollectorName: feeCollectorName,
}

k.mustValidateFields()
return k
}

// Validates that the consumer keeper is initialized with non-zero and
// non-nil values for all its fields. Otherwise this method will panic.
func (k Keeper) mustValidateFields() {

// Ensures no fields are missed in this validation
if reflect.ValueOf(k).NumField() != 15 {
panic("number of fields in provider keeper is not 15")
}

// Note 14 fields will be validated, hooks are explicitly set after the constructor

if reflect.ValueOf(k.storeKey).IsZero() { // 1
panic("storeKey is zero-valued or nil")
}
if reflect.ValueOf(k.cdc).IsZero() { // 2
panic("cdc is zero-valued or nil")
}
if reflect.ValueOf(k.paramStore).IsZero() { // 3
panic("paramStore is zero-valued or nil")
}
if reflect.ValueOf(k.scopedKeeper).IsZero() { // 4
panic("scopedKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.channelKeeper).IsZero() { // 5
panic("channelKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.portKeeper).IsZero() { // 6
panic("portKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.connectionKeeper).IsZero() { // 7
panic("connectionKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.clientKeeper).IsZero() { // 8
panic("clientKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.slashingKeeper).IsZero() { // 9
panic("slashingKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.bankKeeper).IsZero() { // 10
panic("bankKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.authKeeper).IsZero() { // 11
panic("authKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.ibcTransferKeeper).IsZero() { // 12
panic("ibcTransferKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.ibcCoreKeeper).IsZero() { // 13
panic("ibcCoreKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.feeCollectorName).IsZero() { // 14
panic("feeCollectorName is zero-valued or nil")
}
}

// Logger returns a module-specific logger.
Expand Down
2 changes: 1 addition & 1 deletion x/ccv/provider/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestOnChanOpenTry(t *testing.T) {
mocks.MockClientKeeper.EXPECT().GetClientState(ctx, "clientIDToConsumer").Return(
&ibctmtypes.ClientState{ChainId: "consumerChainID"}, true,
).AnyTimes(),
mocks.MockAccountKeeper.EXPECT().GetModuleAccount(ctx, "").Return(&moduleAcct).AnyTimes(),
mocks.MockAccountKeeper.EXPECT().GetModuleAccount(ctx, authtypes.FeeCollectorName).Return(&moduleAcct).AnyTimes(),
)

tc.mutateParams(&params, &providerKeeper)
Expand Down
56 changes: 55 additions & 1 deletion x/ccv/provider/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keeper
import (
"encoding/binary"
"fmt"
"reflect"
"time"

"github.com/cosmos/cosmos-sdk/codec"
Expand Down Expand Up @@ -56,7 +57,7 @@ func NewKeeper(
paramSpace = paramSpace.WithKeyTable(types.ParamKeyTable())
}

return Keeper{
k := Keeper{
cdc: cdc,
storeKey: key,
paramSpace: paramSpace,
Expand All @@ -71,6 +72,59 @@ func NewKeeper(
evidenceKeeper: evidenceKeeper,
feeCollectorName: feeCollectorName,
}

k.mustValidateFields()
return k
}

// Validates that the provider keeper is initialized with non-zero and
// non-nil values for all its fields. Otherwise this method will panic.
func (k Keeper) mustValidateFields() {

// Ensures no fields are missed in this validation
if reflect.ValueOf(k).NumField() != 13 {
panic("number of fields in provider keeper is not 13")
}

if reflect.ValueOf(k.cdc).IsZero() { // 1
panic("cdc is zero-valued or nil")
}
if reflect.ValueOf(k.storeKey).IsZero() { // 2
panic("storeKey is zero-valued or nil")
}
if reflect.ValueOf(k.paramSpace).IsZero() { // 3
panic("paramSpace is zero-valued or nil")
}
if reflect.ValueOf(k.scopedKeeper).IsZero() { // 4
panic("scopedKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.channelKeeper).IsZero() { // 5
panic("channelKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.portKeeper).IsZero() { // 6
panic("portKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.connectionKeeper).IsZero() { // 7
panic("connectionKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.accountKeeper).IsZero() { // 8
panic("accountKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.clientKeeper).IsZero() { // 9
panic("clientKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.stakingKeeper).IsZero() { // 10
panic("stakingKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.slashingKeeper).IsZero() { // 11
panic("slashingKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.evidenceKeeper).IsZero() { // 12
panic("evidenceKeeper is zero-valued or nil")
}
if reflect.ValueOf(k.feeCollectorName).IsZero() { // 13
panic("feeCollectorName is zero-valued or nil")
}
}

// Logger returns a module-specific logger.
Expand Down

0 comments on commit afd1b2b

Please sign in to comment.