Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dex): add single asset join #1038

Merged
merged 10 commits into from
Nov 1, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* [#985](https://github.com/NibiruChain/nibiru/pull/985) - feat: query all active positions for a trader
* [#997](https://github.com/NibiruChain/nibiru/pull/997) - feat: emit `ReserveSnapshotSavedEvent` in vpool EndBlocker
* [#1011](https://github.com/NibiruChain/nibiru/pull/1011) - feat(perp): add DonateToEF cli command
* [#1038](https://github.com/NibiruChain/nibiru/pull/1038) - feat(dex): add single asset join

### Fixes
* [#1023](https://github.com/NibiruChain/nibiru/pull/1023) - collections: golang compiler bug
Expand Down
2 changes: 2 additions & 0 deletions proto/dex/v1/tx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ message MsgJoinPool {
(gogoproto.moretags) = "yaml:\"tokens_in\"",
(gogoproto.nullable) = false
];

bool use_all_coins = 4 [(gogoproto.moretags) = "yaml:\"use_all_coins\""];
}

/*
Expand Down
4 changes: 4 additions & 0 deletions x/dex/client/cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ const (
// FlagPoolId Will be parsed to uint64.
FlagPoolId = "pool-id"

// FlagUseAllCoins Will be parsed to uint64.
FlagUseAllCoins = "use-all-coins"

// FlagTokensIn Will be parsed to []sdk.Coin.
FlagTokensIn = "tokens-in"

Expand Down Expand Up @@ -43,6 +46,7 @@ func FlagSetJoinPool() *flag.FlagSet {

fs.Uint64(FlagPoolId, 0, "The id of pool")
fs.StringArray(FlagTokensIn, []string{""}, "Amount of each denom to send into the pool (specify multiple denoms with: --tokens-in=1uusdc --tokens-in=1unusd)")
fs.Bool(FlagUseAllCoins, false, "Whether to use all the tokens in tokens-in to maximize shares out with a swap first")
return fs
}

Expand Down
7 changes: 7 additions & 0 deletions x/dex/client/cli/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,16 @@ func CmdJoinPool() *cobra.Command {
tokensIn = tokensIn.Add(parsed...)
}

useAllCoins, err := flagSet.GetBool(FlagUseAllCoins)
if err != nil {
return err
}

msg := types.NewMsgJoinPool(
/*sender=*/ clientCtx.GetFromAddress().String(),
poolId,
tokensIn,
useAllCoins,
)

return tx.GenerateOrBroadcastTxCLI(clientCtx, flagSet, msg)
Expand All @@ -147,6 +153,7 @@ func CmdJoinPool() *cobra.Command {

_ = cmd.MarkFlagRequired(FlagPoolId)
_ = cmd.MarkFlagRequired(FlagTokensIn)
_ = cmd.MarkFlagRequired(FlagUseAllCoins)

return cmd
}
Expand Down
2 changes: 1 addition & 1 deletion x/dex/client/testutil/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (s *IntegrationTestSuite) TestNewJoinPoolCmd() {
s.Run(tc.name, func() {
ctx := val.ClientCtx

out, err := ExecMsgJoinPool(ctx, tc.poolId, val.Address, tc.tokensIn)
out, err := ExecMsgJoinPool(ctx, tc.poolId, val.Address, tc.tokensIn, "false")
if tc.expectErr {
s.Require().Error(err)
} else {
Expand Down
2 changes: 2 additions & 0 deletions x/dex/client/testutil/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ func ExecMsgJoinPool(
poolId uint64,
sender fmt.Stringer,
tokensIn string,
useAllCoins string,
extraArgs ...string,
) (testutil.BufferWriter, error) {
args := []string{
fmt.Sprintf("--%s=%d", cli.FlagPoolId, poolId),
fmt.Sprintf("--%s=%s", cli.FlagTokensIn, tokensIn),
fmt.Sprintf("--%s=%s", cli.FlagUseAllCoins, useAllCoins),
fmt.Sprintf("--%s=%s", flags.FlagFrom, sender.String()),
fmt.Sprintf("--%s=%d", flags.FlagGas, 300000),
}
Expand Down
2 changes: 1 addition & 1 deletion x/dex/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (k queryServer) EstimateSwapExactAmountIn(
return nil, err
}

tokenOut, err := pool.CalcOutAmtGivenIn(req.TokenIn, req.TokenOutDenom)
tokenOut, err := pool.CalcOutAmtGivenIn(req.TokenIn, req.TokenOutDenom, false)
if err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion x/dex/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ func (k Keeper) JoinPool(
joinerAddr sdk.AccAddress,
poolId uint64,
tokensIn sdk.Coins,
shouldSwap bool,
) (pool types.Pool, numSharesOut sdk.Coin, remCoins sdk.Coins, err error) {
pool, _ = k.FetchPool(ctx, poolId)

Expand All @@ -445,7 +446,12 @@ func (k Keeper) JoinPool(

poolAddr := pool.GetAddress()

numShares, remCoins, err := pool.AddTokensToPool(tokensIn)
var numShares sdk.Int
if !shouldSwap {
numShares, remCoins, err = pool.AddTokensToPool(tokensIn)
} else {
numShares, remCoins, err = pool.AddAllTokensToPool(tokensIn)
}
if err != nil {
return types.Pool{}, sdk.Coin{}, sdk.Coins{}, err
}
Expand Down
131 changes: 130 additions & 1 deletion x/dex/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,136 @@ func TestJoinPool(t *testing.T) {
joinerAddr := testutil.AccAddress()
require.NoError(t, simapp.FundAccount(app.BankKeeper, ctx, joinerAddr, tc.joinerInitialFunds))

pool, numSharesOut, remCoins, err := app.DexKeeper.JoinPool(ctx, joinerAddr, 1, tc.tokensIn)
pool, numSharesOut, remCoins, err := app.DexKeeper.JoinPool(ctx, joinerAddr, 1, tc.tokensIn, false)
require.NoError(t, err)
require.Equal(t, tc.expectedFinalPool, pool)
require.Equal(t, tc.expectedNumSharesOut, numSharesOut)
require.Equal(t, tc.expectedRemCoins, remCoins)
})
}
}

func TestJoinPoolAllAssets(t *testing.T) {
const shareDenom = "nibiru/pool/1"

tests := []struct {
name string
joinerInitialFunds sdk.Coins
initialPool types.Pool
tokensIn sdk.Coins
expectedNumSharesOut sdk.Coin
expectedRemCoins sdk.Coins
expectedJoinerFinalFunds sdk.Coins
expectedFinalPool types.Pool
}{
{
name: "join with all assets",
joinerInitialFunds: sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
initialPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
/*shares=*/ 100),
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
expectedNumSharesOut: sdk.NewInt64Coin(shareDenom, 100),
expectedRemCoins: sdk.NewCoins(),
expectedJoinerFinalFunds: sdk.NewCoins(sdk.NewInt64Coin(shareDenom, 100)),
expectedFinalPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 200),
sdk.NewInt64Coin("foo", 200),
),
/*shares=*/ 200),
},
{
name: "join with some assets, none remaining",
joinerInitialFunds: sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
initialPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
/*shares=*/ 100),
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("bar", 50),
sdk.NewInt64Coin("foo", 50),
),
expectedNumSharesOut: sdk.NewInt64Coin(shareDenom, 50),
expectedRemCoins: sdk.NewCoins(),
expectedJoinerFinalFunds: sdk.NewCoins(
sdk.NewInt64Coin(shareDenom, 50),
sdk.NewInt64Coin("bar", 50),
sdk.NewInt64Coin("foo", 50),
),
expectedFinalPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 150),
sdk.NewInt64Coin("foo", 150),
),
/*shares=*/ 150),
},
{
name: "join with some assets, but swap done",
joinerInitialFunds: sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
initialPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 100),
sdk.NewInt64Coin("foo", 100),
),
/*shares=*/ 100),
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("bar", 50),
sdk.NewInt64Coin("foo", 75),
),
expectedNumSharesOut: sdk.NewInt64Coin(shareDenom, 61),
expectedRemCoins: sdk.NewCoins(),
expectedJoinerFinalFunds: sdk.NewCoins(
sdk.NewInt64Coin(shareDenom, 50),
sdk.NewInt64Coin("bar", 35),
sdk.NewInt64Coin("foo", 35),
),
expectedFinalPool: mock.DexPool(
/*poolId=*/ 1,
/*assets=*/ sdk.NewCoins(
sdk.NewInt64Coin("bar", 150),
sdk.NewInt64Coin("foo", 175),
),
/*shares=*/ 161),
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
app, ctx := simapp2.NewTestNibiruAppAndContext(true)

poolAddr := testutil.AccAddress()
tc.initialPool.Address = poolAddr.String()
tc.expectedFinalPool.Address = poolAddr.String()
app.DexKeeper.SetPool(ctx, tc.initialPool)

joinerAddr := testutil.AccAddress()
require.NoError(t, simapp.FundAccount(app.BankKeeper, ctx, joinerAddr, tc.joinerInitialFunds))

pool, numSharesOut, remCoins, err := app.DexKeeper.JoinPool(ctx, joinerAddr, 1, tc.tokensIn, true)
require.NoError(t, err)
require.Equal(t, tc.expectedFinalPool, pool)
require.Equal(t, tc.expectedNumSharesOut, numSharesOut)
Expand Down
1 change: 1 addition & 0 deletions x/dex/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func (k msgServer) JoinPool(ctx context.Context, msg *types.MsgJoinPool) (*types
sender,
msg.PoolId,
msg.TokensIn,
msg.UseAllCoins,
)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion x/dex/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func TestMsgServerJoinPool(t *testing.T) {
msgServer := keeper.NewMsgServerImpl(app.DexKeeper)
resp, err := msgServer.JoinPool(
sdk.WrapSDKContext(ctx),
types.NewMsgJoinPool(joinerAddr.String(), tc.initialPool.Id, tc.tokensIn),
types.NewMsgJoinPool(joinerAddr.String(), tc.initialPool.Id, tc.tokensIn, false),
)

require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion x/dex/keeper/swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (k Keeper) SwapExactAmountIn(
}

// calculate tokenOut and validate
tokenOut, err = pool.CalcOutAmtGivenIn(tokenIn, tokenOutDenom)
tokenOut, err = pool.CalcOutAmtGivenIn(tokenIn, tokenOutDenom, false)
if err != nil {
return sdk.Coin{}, err
}
Expand Down
9 changes: 5 additions & 4 deletions x/dex/types/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ func (msg *MsgExitPool) ValidateBasic() error {

var _ sdk.Msg = &MsgJoinPool{}

func NewMsgJoinPool(sender string, poolId uint64, tokensIn sdk.Coins) *MsgJoinPool {
func NewMsgJoinPool(sender string, poolId uint64, tokensIn sdk.Coins, useAllCoins bool) *MsgJoinPool {
return &MsgJoinPool{
Sender: sender,
PoolId: poolId,
TokensIn: tokensIn,
Sender: sender,
PoolId: poolId,
TokensIn: tokensIn,
UseAllCoins: useAllCoins,
}
}

Expand Down
59 changes: 59 additions & 0 deletions x/dex/types/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,65 @@ func (pool *Pool) AddTokensToPool(tokensIn sdk.Coins) (
return numShares, remCoins, nil
}

/*
Adds tokens to a pool optimizing the amount of shares (swap + join) and updates the pool balances (i.e. liquidity).
We compute the swap and then join the pool.

args:
- tokensIn: the tokens to add to the pool

ret:
- numShares: the number of LP shares given to the user for the deposit
- remCoins: the number of coins remaining after the deposit
- err: error if any
*/
func (pool *Pool) AddAllTokensToPool(tokensIn sdk.Coins) (
numShares sdk.Int, remCoins sdk.Coins, err error,
) {
swapToken, err := pool.SwapForSwapAndJoin(tokensIn)
if err != nil {
return
}
if swapToken.Amount.LT(sdk.OneInt()) {
return pool.AddTokensToPool(tokensIn)
}

index, _, err := pool.getPoolAssetAndIndex(swapToken.Denom)

if err != nil {
return
}

otherDenom := pool.PoolAssets[1-index].Token.Denom
tokenOut, err := pool.CalcOutAmtGivenIn(
/*tokenIn=*/ swapToken,
/*tokenOutDenom=*/ otherDenom,
/*noFee=*/ true,
)

if err != nil {
return
}

err = pool.ApplySwap(swapToken, tokenOut)

if err != nil {
return
}

tokensIn = sdk.Coins{
{
Denom: swapToken.Denom,
Amount: tokensIn.AmountOfNoDenomValidation(swapToken.Denom).Sub(swapToken.Amount),
},
{
Denom: otherDenom,
Amount: tokensIn.AmountOfNoDenomValidation(otherDenom).Add(tokenOut.Amount),
},
}.Sort()
return pool.AddTokensToPool(tokensIn)
}

/*
Fetch the pool's address as an sdk.Address.
*/
Expand Down
Loading