Skip to content

Commit

Permalink
feat: decouple keeper from servers in pricefeed module (#889)
Browse files Browse the repository at this point in the history
* add server decoupling

* add server decoupling

* update changelog

* fix linter
  • Loading branch information
AgentSmithMatrix authored Sep 13, 2022
1 parent 8a69139 commit b19072e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* [#865](https://github.com/NibiruChain/nibiru/pull/865) - refactor(vpool): clean up interface for CmdGetBaseAssetPrice to use add and remove as directions
* [#868](https://github.com/NibiruChain/nibiru/pull/868) - refactor dex integration tests to be independent between them
* [#876](https://github.com/NibiruChain/nibiru/pull/876) - chore(deps): bump github.com/spf13/viper from 1.12.0 to 1.13.0
* [#889](https://github.com/NibiruChain/nibiru/pull/889) - feat: decouple keeper from servers in pricefeed module
* [#886](https://github.com/NibiruChain/nibiru/pull/886) - feat: decouple keeper from servers in perp module

### Features
Expand Down
46 changes: 27 additions & 19 deletions x/pricefeed/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ import (
"github.com/NibiruChain/nibiru/x/pricefeed/types"
)

var _ types.QueryServer = Keeper{}
type queryServer struct {
k Keeper
}

func NewQuerier(k Keeper) types.QueryServer {
return queryServer{k: k}
}

var _ types.QueryServer = queryServer{}

func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) (*types.QueryPriceResponse, error) {
func (q queryServer) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest) (*types.QueryPriceResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
}
Expand All @@ -24,21 +32,21 @@ func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest)
if err != nil {
return nil, err
}
if !k.GetPairs(ctx).Contains(pair) {
if !q.k.GetPairs(ctx).Contains(pair) {
return nil, status.Error(codes.NotFound, "pair not in module params")
}
if !k.ActivePairsStore().getKV(ctx).Has([]byte(pair.String())) {
if !q.k.ActivePairsStore().getKV(ctx).Has([]byte(pair.String())) {
return nil, status.Error(codes.NotFound, "invalid market ID")
}

tokens := common.DenomsFromPoolName(req.PairId)
token0, token1 := tokens[0], tokens[1]
currentPrice, err := k.GetCurrentPrice(ctx, token0, token1)
currentPrice, err := q.k.GetCurrentPrice(ctx, token0, token1)
if err != nil {
return nil, err
}

twap, err := k.GetCurrentTWAP(ctx, token0, token1)
twap, err := q.k.GetCurrentTWAP(ctx, token0, token1)
if err != nil {
return nil, err
}
Expand All @@ -52,7 +60,7 @@ func (k Keeper) QueryPrice(goCtx context.Context, req *types.QueryPriceRequest)
}, nil
}

func (k Keeper) QueryRawPrices(
func (q queryServer) QueryRawPrices(
goCtx context.Context, req *types.QueryRawPricesRequest,
) (*types.QueryRawPricesResponse, error) {
if req == nil {
Expand All @@ -61,12 +69,12 @@ func (k Keeper) QueryRawPrices(

ctx := sdk.UnwrapSDKContext(goCtx)

if !k.IsActivePair(ctx, req.PairId) {
if !q.k.IsActivePair(ctx, req.PairId) {
return nil, status.Error(codes.NotFound, "invalid market ID")
}

var prices types.PostedPriceResponses
for _, rp := range k.GetRawPrices(ctx, req.PairId) {
for _, rp := range q.k.GetRawPrices(ctx, req.PairId) {
prices = append(prices, types.PostedPriceResponse{
PairID: rp.PairID,
OracleAddress: rp.Oracle,
Expand All @@ -80,15 +88,15 @@ func (k Keeper) QueryRawPrices(
}, nil
}

func (k Keeper) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest) (*types.QueryPricesResponse, error) {
func (q queryServer) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest) (*types.QueryPricesResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
}

ctx := sdk.UnwrapSDKContext(goCtx)

var currentPrices types.CurrentPriceResponses
for _, currentPrice := range k.GetCurrentPrices(ctx) {
for _, currentPrice := range q.k.GetCurrentPrices(ctx) {
if currentPrice.PairID != "" {
currentPrices = append(currentPrices, types.CurrentPriceResponse{
PairID: currentPrice.PairID,
Expand All @@ -102,7 +110,7 @@ func (k Keeper) QueryPrices(goCtx context.Context, req *types.QueryPricesRequest
}, nil
}

func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesRequest) (*types.QueryOraclesResponse, error) {
func (q queryServer) QueryOracles(goCtx context.Context, req *types.QueryOraclesRequest) (*types.QueryOraclesResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
}
Expand All @@ -114,7 +122,7 @@ func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesReque
return nil, status.Error(codes.NotFound, "invalid market ID")
}

oracles := k.GetOraclesForPair(ctx, req.PairId)
oracles := q.k.GetOraclesForPair(ctx, req.PairId)
if len(oracles) == 0 {
return &types.QueryOraclesResponse{}, nil
}
Expand All @@ -129,17 +137,17 @@ func (k Keeper) QueryOracles(goCtx context.Context, req *types.QueryOraclesReque
}, nil
}

func (k Keeper) QueryParams(c context.Context, req *types.QueryParamsRequest,
func (q queryServer) QueryParams(c context.Context, req *types.QueryParamsRequest,
) (*types.QueryParamsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
}
ctx := sdk.UnwrapSDKContext(c)

return &types.QueryParamsResponse{Params: k.GetParams(ctx)}, nil
return &types.QueryParamsResponse{Params: q.k.GetParams(ctx)}, nil
}

func (k Keeper) QueryMarkets(goCtx context.Context, req *types.QueryMarketsRequest,
func (q queryServer) QueryMarkets(goCtx context.Context, req *types.QueryMarketsRequest,
) (*types.QueryMarketsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
Expand All @@ -148,16 +156,16 @@ func (k Keeper) QueryMarkets(goCtx context.Context, req *types.QueryMarketsReque
ctx := sdk.UnwrapSDKContext(goCtx)

var markets types.Markets
for _, pair := range k.GetParams(ctx).Pairs {
for _, pair := range q.k.GetParams(ctx).Pairs {
var oracleStrings []string
for _, oracle := range k.OraclesStore().Get(ctx, pair) {
for _, oracle := range q.k.OraclesStore().Get(ctx, pair) {
oracleStrings = append(oracleStrings, oracle.String())
}

markets = append(markets, types.Market{
PairID: pair.String(),
Oracles: oracleStrings,
Active: k.IsActivePair(ctx, pair.String()),
Active: q.k.IsActivePair(ctx, pair.String()),
})
}

Expand Down
15 changes: 10 additions & 5 deletions x/pricefeed/keeper/grpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,27 @@ import (
"github.com/stretchr/testify/require"

"github.com/NibiruChain/nibiru/x/common"
keeper2 "github.com/NibiruChain/nibiru/x/pricefeed/keeper"
"github.com/NibiruChain/nibiru/x/pricefeed/types"
testutilkeeper "github.com/NibiruChain/nibiru/x/testutil/keeper"
"github.com/NibiruChain/nibiru/x/testutil/sample"
)

func TestParamsQuery(t *testing.T) {
keeper, ctx := testutilkeeper.PricefeedKeeper(t)
querier := keeper2.NewQuerier(keeper)
wctx := sdk.WrapSDKContext(ctx)
params := types.Params{Pairs: common.NewAssetPairs("btc:usd", "xrp:usd")}
keeper.SetParams(ctx, params)

response, err := keeper.QueryParams(wctx, &types.QueryParamsRequest{})
response, err := querier.QueryParams(wctx, &types.QueryParamsRequest{})
require.NoError(t, err)
require.Equal(t, &types.QueryParamsResponse{Params: params}, response)
}

func TestOraclesQuery(t *testing.T) {
keeper, ctx := testutilkeeper.PricefeedKeeper(t)
querier := keeper2.NewQuerier(keeper)
wctx := sdk.WrapSDKContext(ctx)
pairs := common.NewAssetPairs("usd:btc", "usd:xrp", "usd:ada", "usd:eth")
params := types.Params{Pairs: pairs}
Expand All @@ -48,14 +51,14 @@ func TestOraclesQuery(t *testing.T) {
/*pairs=*/ []common.AssetPair{pairs[3]})

t.Log("Query for pair 2 oracles | ADA")
response, err := keeper.QueryOracles(wctx, &types.QueryOraclesRequest{
response, err := querier.QueryOracles(wctx, &types.QueryOraclesRequest{
PairId: pairs[2].String()})
require.NoError(t, err)
require.Equal(t, &types.QueryOraclesResponse{
Oracles: []string{oracleA.String(), oracleB.String()}}, response)

t.Log("Query for pair 3 oracles | ETH")
response, err = keeper.QueryOracles(wctx, &types.QueryOraclesRequest{
response, err = querier.QueryOracles(wctx, &types.QueryOraclesRequest{
PairId: pairs[3].String()})
require.NoError(t, err)
require.Equal(t, &types.QueryOraclesResponse{
Expand All @@ -64,6 +67,7 @@ func TestOraclesQuery(t *testing.T) {

func TestMarketsQuery(t *testing.T) {
keeper, ctx := testutilkeeper.PricefeedKeeper(t)
querier := keeper2.NewQuerier(keeper)
wctx := sdk.WrapSDKContext(ctx)
pairs := common.NewAssetPairs("btc:usd", "xrp:usd", "ada:usd", "eth:usd")
params := types.Params{Pairs: pairs}
Expand All @@ -78,7 +82,7 @@ func TestMarketsQuery(t *testing.T) {
keeper.ActivePairsStore().SetMany(ctx, pairs[:3], true)
keeper.ActivePairsStore().SetMany(ctx, common.AssetPairs{pairs[3]}, false)

queryResp, err := keeper.QueryMarkets(wctx, &types.QueryMarketsRequest{})
queryResp, err := querier.QueryMarkets(wctx, &types.QueryMarketsRequest{})
require.NoError(t, err)
wantQueryResponse := &types.QueryMarketsResponse{
Markets: []types.Market{
Expand Down Expand Up @@ -112,6 +116,7 @@ func TestMarketsQuery(t *testing.T) {
func TestQueryPrice(t *testing.T) {
pair := common.MustNewAssetPair("ubtc:uusd")
keeper, ctx := testutilkeeper.PricefeedKeeper(t)
querier := keeper2.NewQuerier(keeper)
keeper.SetParams(ctx, types.Params{
Pairs: common.AssetPairs{pair},
TwapLookbackWindow: time.Minute * 15,
Expand All @@ -130,7 +135,7 @@ func TestQueryPrice(t *testing.T) {
require.NoError(t, keeper.GatherRawPrices(ctx, "ubtc", "uusd"))

ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Second * 5)).WithBlockHeight(2)
resp, err := keeper.QueryPrice(sdk.WrapSDKContext(ctx), &types.QueryPriceRequest{
resp, err := querier.QueryPrice(sdk.WrapSDKContext(ctx), &types.QueryPriceRequest{
PairId: "ubtc:uusd",
})

Expand Down
12 changes: 6 additions & 6 deletions x/pricefeed/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (
)

type msgServer struct {
Keeper
k Keeper
}

// NewMsgServerImpl returns an implementation of the MsgServer interface
// for the provided Keeper.
func NewMsgServerImpl(keeper Keeper) types.MsgServer {
return &msgServer{Keeper: keeper}
return &msgServer{k: keeper}
}

var _ types.MsgServer = msgServer{}
Expand All @@ -27,7 +27,7 @@ var _ types.MsgServer = msgServer{}
// PostPrice
// ---------------------------------------------------------------

func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice,
func (ms msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice,
) (*types.MsgPostPriceResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

Expand All @@ -38,8 +38,8 @@ func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice,

pair := common.AssetPair{Token0: msg.Token0, Token1: msg.Token1}

isWhitelisted := k.IsWhitelistedOracle(ctx, pair.String(), from)
isWhitelistedForInverse := k.IsWhitelistedOracle(
isWhitelisted := ms.k.IsWhitelistedOracle(ctx, pair.String(), from)
isWhitelistedForInverse := ms.k.IsWhitelistedOracle(
ctx, pair.Inverse().String(), from)
if !(isWhitelisted || isWhitelistedForInverse) {
return nil, sdkerrors.Wrapf(types.ErrInvalidOracle,
Expand All @@ -55,7 +55,7 @@ func (k msgServer) PostPrice(goCtx context.Context, msg *types.MsgPostPrice,
postedPrice = msg.Price
}

if err = k.PostRawPrice(ctx, from, pair.String(), postedPrice, msg.Expiry); err != nil {
if err = ms.k.PostRawPrice(ctx, from, pair.String(), postedPrice, msg.Expiry); err != nil {
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion x/pricefeed/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (am AppModule) LegacyQuerierHandler(legacyQuerierCdc *codec.LegacyAmino) sd
// RegisterServices registers a GRPC query service to respond to the
// module-specific GRPC queries.
func (am AppModule) RegisterServices(cfg module.Configurator) {
types.RegisterQueryServer(cfg.QueryServer(), am.keeper)
types.RegisterQueryServer(cfg.QueryServer(), keeper.NewQuerier(am.keeper))
}

// RegisterInvariants registers the capability module's invariants.
Expand Down

0 comments on commit b19072e

Please sign in to comment.