diff --git a/CHANGELOG.md b/CHANGELOG.md index c7e42544d..c0c9731f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * [#865](https://github.com/NibiruChain/nibiru/pull/865) - refactor(vpool): clean up interface for CmdGetBaseAssetPrice to use add and remove as directions * [#868](https://github.com/NibiruChain/nibiru/pull/868) - refactor dex integration tests to be independent between them * [#876](https://github.com/NibiruChain/nibiru/pull/876) - chore(deps): bump github.com/spf13/viper from 1.12.0 to 1.13.0 +* [#889](https://github.com/NibiruChain/nibiru/pull/889) - feat: decouple keeper from servers in pricefeed module * [#886](https://github.com/NibiruChain/nibiru/pull/886) - feat: decouple keeper from servers in perp module ### Features diff --git a/x/pricefeed/keeper/grpc_query.go b/x/pricefeed/keeper/grpc_query.go index d10852a7b..be27ce869 100644 --- a/x/pricefeed/keeper/grpc_query.go +++ b/x/pricefeed/keeper/grpc_query.go @@ -11,9 +11,17 @@ import ( "github.com/NibiruChain/nibiru/x/pricefeed/types" ) -var _ types.QueryServer = Keeper{} +type queryServer struct { + k Keeper +} + +func NewQuerier(k Keeper) types.QueryServer { + return queryServer{k: k} +} + +var _ types.QueryServer = queryServer{} -func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) (*types.QueryPriceResponse, error) { +func (q queryServer) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) (*types.QueryPriceResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "invalid request") } @@ -24,21 +32,21 @@ func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) if err != nil { return nil, err } - if !k.GetPairs(ctx).Contains(pair) { + if !q.k.GetPairs(ctx).Contains(pair) { return nil, status.Error(codes.NotFound, "pair not in module params") } - if !k.ActivePairsStore().getKV(ctx).Has([]byte(pair.String())) { + if !q.k.ActivePairsStore().getKV(ctx).Has([]byte(pair.String())) { return nil, status.Error(codes.NotFound, "invalid market ID") } tokens := common.DenomsFromPoolName(req.PairId) token0, token1 := tokens[0], tokens[1] - currentPrice, err := k.GetCurrentPrice(ctx, token0, token1) + currentPrice, err := q.k.GetCurrentPrice(ctx, token0, token1) if err != nil { return nil, err } - twap, err := k.GetCurrentTWAP(ctx, token0, token1) + twap, err := q.k.GetCurrentTWAP(ctx, token0, token1) if err != nil { return nil, err } @@ -52,7 +60,7 @@ func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) }, nil } -func (k Keeper) QueryRawPrices( +func (q queryServer) QueryRawPrices( goCtx context.Context, req *types.QueryRawPricesRequest, ) (*types.QueryRawPricesResponse, error) { if req == nil { @@ -61,12 +69,12 @@ func (k Keeper) QueryRawPrices( ctx := sdk.UnwrapSDKContext(goCtx) - if !k.IsActivePair(ctx, req.PairId) { + if !q.k.IsActivePair(ctx, req.PairId) { return nil, status.Error(codes.NotFound, "invalid market ID") } var prices types.PostedPriceResponses - for _, rp := range k.GetRawPrices(ctx, req.PairId) { + for _, rp := range q.k.GetRawPrices(ctx, req.PairId) { prices = append(prices, types.PostedPriceResponse{ PairID: rp.PairID, OracleAddress: rp.Oracle, @@ -80,7 +88,7 @@ func (k Keeper) QueryRawPrices( }, nil } -func (k Keeper) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest) (*types.QueryPricesResponse, error) { +func (q queryServer) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest) (*types.QueryPricesResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "invalid request") } @@ -88,7 +96,7 @@ func (k Keeper) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest ctx := sdk.UnwrapSDKContext(goCtx) var currentPrices types.CurrentPriceResponses - for _, currentPrice := range k.GetCurrentPrices(ctx) { + for _, currentPrice := range q.k.GetCurrentPrices(ctx) { if currentPrice.PairID != "" { currentPrices = append(currentPrices, types.CurrentPriceResponse{ PairID: currentPrice.PairID, @@ -102,7 +110,7 @@ func (k Keeper) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest }, nil } -func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesRequest) (*types.QueryOraclesResponse, error) { +func (q queryServer) QueryOracles(goCtx context.Context, req *types.QueryOraclesRequest) (*types.QueryOraclesResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "invalid request") } @@ -114,7 +122,7 @@ func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesReque return nil, status.Error(codes.NotFound, "invalid market ID") } - oracles := k.GetOraclesForPair(ctx, req.PairId) + oracles := q.k.GetOraclesForPair(ctx, req.PairId) if len(oracles) == 0 { return &types.QueryOraclesResponse{}, nil } @@ -129,17 +137,17 @@ func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesReque }, nil } -func (k Keeper) QueryParams(c context.Context, req *types.QueryParamsRequest, +func (q queryServer) QueryParams(c context.Context, req *types.QueryParamsRequest, ) (*types.QueryParamsResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "invalid request") } ctx := sdk.UnwrapSDKContext(c) - return &types.QueryParamsResponse{Params: k.GetParams(ctx)}, nil + return &types.QueryParamsResponse{Params: q.k.GetParams(ctx)}, nil } -func (k Keeper) QueryMarkets(goCtx context.Context, req *types.QueryMarketsRequest, +func (q queryServer) QueryMarkets(goCtx context.Context, req *types.QueryMarketsRequest, ) (*types.QueryMarketsResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "invalid request") @@ -148,16 +156,16 @@ func (k Keeper) QueryMarkets(goCtx context.Context, req *types.QueryMarketsReque ctx := sdk.UnwrapSDKContext(goCtx) var markets types.Markets - for _, pair := range k.GetParams(ctx).Pairs { + for _, pair := range q.k.GetParams(ctx).Pairs { var oracleStrings []string - for _, oracle := range k.OraclesStore().Get(ctx, pair) { + for _, oracle := range q.k.OraclesStore().Get(ctx, pair) { oracleStrings = append(oracleStrings, oracle.String()) } markets = append(markets, types.Market{ PairID: pair.String(), Oracles: oracleStrings, - Active: k.IsActivePair(ctx, pair.String()), + Active: q.k.IsActivePair(ctx, pair.String()), }) } diff --git a/x/pricefeed/keeper/grpc_query_test.go b/x/pricefeed/keeper/grpc_query_test.go index 0488caaf4..c67dc970e 100644 --- a/x/pricefeed/keeper/grpc_query_test.go +++ b/x/pricefeed/keeper/grpc_query_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NibiruChain/nibiru/x/common" + keeper2 "github.com/NibiruChain/nibiru/x/pricefeed/keeper" "github.com/NibiruChain/nibiru/x/pricefeed/types" testutilkeeper "github.com/NibiruChain/nibiru/x/testutil/keeper" "github.com/NibiruChain/nibiru/x/testutil/sample" @@ -16,17 +17,19 @@ import ( func TestParamsQuery(t *testing.T) { keeper, ctx := testutilkeeper.PricefeedKeeper(t) + querier := keeper2.NewQuerier(keeper) wctx := sdk.WrapSDKContext(ctx) params := types.Params{Pairs: common.NewAssetPairs("btc:usd", "xrp:usd")} keeper.SetParams(ctx, params) - response, err := keeper.QueryParams(wctx, &types.QueryParamsRequest{}) + response, err := querier.QueryParams(wctx, &types.QueryParamsRequest{}) require.NoError(t, err) require.Equal(t, &types.QueryParamsResponse{Params: params}, response) } func TestOraclesQuery(t *testing.T) { keeper, ctx := testutilkeeper.PricefeedKeeper(t) + querier := keeper2.NewQuerier(keeper) wctx := sdk.WrapSDKContext(ctx) pairs := common.NewAssetPairs("usd:btc", "usd:xrp", "usd:ada", "usd:eth") params := types.Params{Pairs: pairs} @@ -48,14 +51,14 @@ func TestOraclesQuery(t *testing.T) { /*pairs=*/ []common.AssetPair{pairs[3]}) t.Log("Query for pair 2 oracles | ADA") - response, err := keeper.QueryOracles(wctx, &types.QueryOraclesRequest{ + response, err := querier.QueryOracles(wctx, &types.QueryOraclesRequest{ PairId: pairs[2].String()}) require.NoError(t, err) require.Equal(t, &types.QueryOraclesResponse{ Oracles: []string{oracleA.String(), oracleB.String()}}, response) t.Log("Query for pair 3 oracles | ETH") - response, err = keeper.QueryOracles(wctx, &types.QueryOraclesRequest{ + response, err = querier.QueryOracles(wctx, &types.QueryOraclesRequest{ PairId: pairs[3].String()}) require.NoError(t, err) require.Equal(t, &types.QueryOraclesResponse{ @@ -64,6 +67,7 @@ func TestOraclesQuery(t *testing.T) { func TestMarketsQuery(t *testing.T) { keeper, ctx := testutilkeeper.PricefeedKeeper(t) + querier := keeper2.NewQuerier(keeper) wctx := sdk.WrapSDKContext(ctx) pairs := common.NewAssetPairs("btc:usd", "xrp:usd", "ada:usd", "eth:usd") params := types.Params{Pairs: pairs} @@ -78,7 +82,7 @@ func TestMarketsQuery(t *testing.T) { keeper.ActivePairsStore().SetMany(ctx, pairs[:3], true) keeper.ActivePairsStore().SetMany(ctx, common.AssetPairs{pairs[3]}, false) - queryResp, err := keeper.QueryMarkets(wctx, &types.QueryMarketsRequest{}) + queryResp, err := querier.QueryMarkets(wctx, &types.QueryMarketsRequest{}) require.NoError(t, err) wantQueryResponse := &types.QueryMarketsResponse{ Markets: []types.Market{ @@ -112,6 +116,7 @@ func TestMarketsQuery(t *testing.T) { func TestQueryPrice(t *testing.T) { pair := common.MustNewAssetPair("ubtc:uusd") keeper, ctx := testutilkeeper.PricefeedKeeper(t) + querier := keeper2.NewQuerier(keeper) keeper.SetParams(ctx, types.Params{ Pairs: common.AssetPairs{pair}, TwapLookbackWindow: time.Minute * 15, @@ -130,7 +135,7 @@ func TestQueryPrice(t *testing.T) { require.NoError(t, keeper.GatherRawPrices(ctx, "ubtc", "uusd")) ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Second * 5)).WithBlockHeight(2) - resp, err := keeper.QueryPrice(sdk.WrapSDKContext(ctx), &types.QueryPriceRequest{ + resp, err := querier.QueryPrice(sdk.WrapSDKContext(ctx), &types.QueryPriceRequest{ PairId: "ubtc:uusd", }) diff --git a/x/pricefeed/keeper/msg_server.go b/x/pricefeed/keeper/msg_server.go index c132773fb..ee08bf1f3 100644 --- a/x/pricefeed/keeper/msg_server.go +++ b/x/pricefeed/keeper/msg_server.go @@ -12,13 +12,13 @@ import ( ) type msgServer struct { - Keeper + k Keeper } // NewMsgServerImpl returns an implementation of the MsgServer interface // for the provided Keeper. func NewMsgServerImpl(keeper Keeper) types.MsgServer { - return &msgServer{Keeper: keeper} + return &msgServer{k: keeper} } var _ types.MsgServer = msgServer{} @@ -27,7 +27,7 @@ var _ types.MsgServer = msgServer{} // PostPrice // --------------------------------------------------------------- -func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice, +func (ms msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice, ) (*types.MsgPostPriceResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) @@ -38,8 +38,8 @@ func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice, pair := common.AssetPair{Token0: msg.Token0, Token1: msg.Token1} - isWhitelisted := k.IsWhitelistedOracle(ctx, pair.String(), from) - isWhitelistedForInverse := k.IsWhitelistedOracle( + isWhitelisted := ms.k.IsWhitelistedOracle(ctx, pair.String(), from) + isWhitelistedForInverse := ms.k.IsWhitelistedOracle( ctx, pair.Inverse().String(), from) if !(isWhitelisted || isWhitelistedForInverse) { return nil, sdkerrors.Wrapf(types.ErrInvalidOracle, @@ -55,7 +55,7 @@ func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice, postedPrice = msg.Price } - if err = k.PostRawPrice(ctx, from, pair.String(), postedPrice, msg.Expiry); err != nil { + if err = ms.k.PostRawPrice(ctx, from, pair.String(), postedPrice, msg.Expiry); err != nil { return nil, err } diff --git a/x/pricefeed/module.go b/x/pricefeed/module.go index 98da84a35..c92fbfb14 100644 --- a/x/pricefeed/module.go +++ b/x/pricefeed/module.go @@ -141,7 +141,7 @@ func (am AppModule) LegacyQuerierHandler(legacyQuerierCdc *codec.LegacyAmino) sd // RegisterServices registers a GRPC query service to respond to the // module-specific GRPC queries. func (am AppModule) RegisterServices(cfg module.Configurator) { - types.RegisterQueryServer(cfg.QueryServer(), am.keeper) + types.RegisterQueryServer(cfg.QueryServer(), keeper.NewQuerier(am.keeper)) } // RegisterInvariants registers the capability module's invariants.