Skip to content

Commit

Permalink
Gamm stableswap improvements (backport #3839) (#3933)
Browse files Browse the repository at this point in the history
* Gamm stableswap improvements (#3839)

* Stableswap: README pseudocode fixes

* Stableswap: Binary search code improvement

* Stableswap: minor code improvements

* Stableswap: minor code improvements 2

* Binary search: Check potential division by zero

(cherry picked from commit 2ac5d35)

# Conflicts:
#	x/gamm/pool-models/stableswap/util_test.go

* Fix conflict

* Update osmomath go mod

Co-authored-by: Aleksandar Ljahović <[email protected]>
Co-authored-by: Dev Ojha <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2023
1 parent 243faec commit 1067f1e
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 137 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.16
github.com/ory/dockertest/v3 v3.9.1
github.com/osmosis-labs/go-mutesting v0.0.0-20221208041716-b43bcd97b3b3
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106110532-e17f2f459464
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106133904-bf95f2df4908
github.com/osmosis-labs/osmosis/osmoutils v0.0.0-20230106095152-4f77cc5e42af
github.com/osmosis-labs/osmosis/x/ibc-hooks v0.0.0-20230106110415-61e4300ada92
github.com/pkg/errors v0.9.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,8 @@ github.com/osmosis-labs/cosmos-sdk v0.45.1-0.20221118211718-545aed73e94e h1:A3by
github.com/osmosis-labs/cosmos-sdk v0.45.1-0.20221118211718-545aed73e94e/go.mod h1:rud0OaBIuq3+qOqtwT4SR7Q7iSzRp7w41fjninTjfnQ=
github.com/osmosis-labs/go-mutesting v0.0.0-20221208041716-b43bcd97b3b3 h1:YlmchqTmlwdWSmrRmXKR+PcU96ntOd8u10vTaTZdcNY=
github.com/osmosis-labs/go-mutesting v0.0.0-20221208041716-b43bcd97b3b3/go.mod h1:lV6KnqXYD/ayTe7310MHtM3I2q8Z6bBfMAi+bhwPYtI=
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106110532-e17f2f459464 h1:2562qeTuCCb1IQBcbKjQgVD+cpPxVMzi//Dr/1uQbGc=
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106110532-e17f2f459464/go.mod h1:KrzYoNtnWUH75rj1XAsSR4nymlHFU7jeVOx7/1KMe0k=
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106133904-bf95f2df4908 h1:N7JvlXT8N82iN7NDzX6CXFOFrpFTWqSfJuAZcInQ+/Y=
github.com/osmosis-labs/osmosis/osmomath v0.0.0-20230106133904-bf95f2df4908/go.mod h1:KrzYoNtnWUH75rj1XAsSR4nymlHFU7jeVOx7/1KMe0k=
github.com/osmosis-labs/osmosis/osmoutils v0.0.0-20230106095152-4f77cc5e42af h1:/UyyIUTH2FZaN7xULcA2SQ6+bKD9fsau4PJkYWMir3g=
github.com/osmosis-labs/osmosis/osmoutils v0.0.0-20230106095152-4f77cc5e42af/go.mod h1:K4de+n3DtLdueen98dOzaRXZvqMd8JvigL8O1xW445o=
github.com/osmosis-labs/osmosis/x/ibc-hooks v0.0.0-20230106110415-61e4300ada92 h1:aXAru0jzeTjrSmFEcqIXUI23yaE0RqpwKg04HK1sYcs=
Expand Down
56 changes: 31 additions & 25 deletions osmomath/binary_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(sdk.MinInt(expected.Abs(), actual.Abs()).ToDec())
minValue := sdk.MinInt(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue.ToDec())
if errTerm.GT(e.MultiplicativeTolerance) {
return comparisonSign
}
Expand Down Expand Up @@ -121,7 +126,12 @@ func (e ErrTolerance) CompareBigDec(expected BigDec, actual BigDec) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(MinDec(expected.Abs(), actual.Abs()))
minValue := MinDec(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue)
// fmt.Printf("err term %v\n", errTerm)
if errTerm.GT(BigDecFromSDKDec(e.MultiplicativeTolerance)) {
return comparisonSign
Expand All @@ -141,14 +151,19 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
errTolerance ErrTolerance,
maxIterations int,
) (sdk.Int, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err := f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
var (
curEstimate, curOutput sdk.Int
err error
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}

compRes := errTolerance.Compare(targetOutput, curOutput)
if compRes < 0 {
upperbound = curEstimate
Expand All @@ -157,11 +172,6 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
}

return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand All @@ -182,21 +192,22 @@ type SdkDec[D any] interface {
//
// It binary searches on the input range, until it finds an input y s.t. f(y) meets the err tolerance constraints for how close it is to x.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
func BinarySearchBigDec(f func(input BigDec) BigDec,
lowerbound BigDec,
upperbound BigDec,
targetOutput BigDec,
errTolerance ErrTolerance,
maxIterations int,
) (BigDec, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err := f(curEstimate)
if err != nil {
return BigDec{}, err
}
var (
curEstimate, curOutput BigDec
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput = f(curEstimate)

// fmt.Println("binary search, input, target output, cur output", curEstimate, targetOutput, curOutput)
compRes := errTolerance.CompareBigDec(targetOutput, curOutput)
if compRes < 0 {
Expand All @@ -206,11 +217,6 @@ func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err = f(curEstimate)
if err != nil {
return BigDec{}, err
}
}

return BigDec{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand Down
16 changes: 8 additions & 8 deletions osmomath/binary_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func TestBinarySearch(t *testing.T) {

// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
func lineF(a BigDec) (BigDec, error) {
return a, nil
func lineF(a BigDec) BigDec {
return a
}
func cubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3), nil
func cubicF(a BigDec) BigDec {
return a.PowerInteger(3)
}

var negCubicFConstant BigDec
Expand All @@ -89,11 +89,11 @@ func init() {
negCubicFConstant = NewBigDec(1 << 62).PowerInteger(3).Neg()
}

func negCubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3).Add(negCubicFConstant), nil
func negCubicF(a BigDec) BigDec {
return a.PowerInteger(3).Add(negCubicFConstant)
}

type searchFn func(BigDec) (BigDec, error)
type searchFn func(BigDec) BigDec

type binarySearchTestCase struct {
f searchFn
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestIterationDepthRandValue(t *testing.T) {
errTolerance ErrTolerance, maxNumIters int, errToleranceName string) {
targetF := fnMap[fnName]
targetX := int64(rand.Intn(int(upperbound-lowerbound-1))) + lowerbound + 1
target, _ := targetF(NewBigDec(targetX))
target := targetF(NewBigDec(targetX))
testCase := binarySearchTestCase{
f: lineF,
lowerbound: NewBigDec(lowerbound), upperbound: NewBigDec(upperbound),
Expand Down
4 changes: 2 additions & 2 deletions x/gamm/pool-models/stableswap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def CalcOutAmountGivenExactAmountIn(pool, in_coin, out_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_coin, out_denom, RoundingMode.RoundDown)
in_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundDown)
amm_in = in_amt_scaled * (1 - swap_fee)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, in_amt_scaled)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, amm_in)
out_amt = pool.DescaleToken(out_amt_scaled, out_denom)
return out_amt
```
Expand All @@ -308,7 +308,7 @@ We do this by having `token_in = amm_in / (1 - swapfee)`.
```python
def CalcInAmountGivenExactAmountOut(pool, out_coin, in_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_denom, out_coin, RoundingMode.RoundDown)
out_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundUp)
out_amt_scaled = pool.ScaleToken(out_coin, RoundingMode.RoundUp)

amm_in_scaled = solve_y(out_reserve, in_reserve, remReserves, -out_amt_scaled)
swap_in_scaled = ceil(amm_in_scaled / (1 - swapfee))
Expand Down
35 changes: 13 additions & 22 deletions x/gamm/pool-models/stableswap/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ func cfmmConstant(xReserve, yReserve osmomath.BigDec) osmomath.BigDec {
// We use this version for calculations since the u
// term in the full CFMM is constant.
func cfmmConstantMultiNoV(xReserve, yReserve, wSumSquares osmomath.BigDec) osmomath.BigDec {
if !xReserve.IsPositive() || !yReserve.IsPositive() || wSumSquares.IsNegative() {
panic("invalid input: reserves must be positive")
}

return cfmmConstantMultiNoVY(xReserve, yReserve, wSumSquares).Mul(yReserve)
}

Expand Down Expand Up @@ -263,19 +259,19 @@ func targetKCalculator(x0, y0, w, yf osmomath.BigDec) osmomath.BigDec {

// $$k_{iter}(x_f) = -x_{out}^3 + 3 x_0 x_{out}^2 - (y_f^2 + w + 3x_0^2)x_{out}$$
// where x_out = x_0 - x_f
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) (osmomath.BigDec, error) {
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) osmomath.BigDec {
// compute coefficients first
cubicCoeff := osmomath.OneDec().Neg()
quadraticCoeff := x0.MulInt64(3)
linearCoeff := quadraticCoeff.Mul(x0).Add(w).Add(yf.Mul(yf)).Neg()
return func(xf osmomath.BigDec) (osmomath.BigDec, error) {
return func(xf osmomath.BigDec) osmomath.BigDec {
xOut := x0.Sub(xf)
// horners method
// ax^3 + bx^2 + cx = x(c + x(b + ax))
res := cubicCoeff.Mul(xOut)
res = res.Add(quadraticCoeff).Mul(xOut)
res = res.Add(linearCoeff).Mul(xOut)
return res, nil
return res
}
}

Expand Down Expand Up @@ -488,34 +484,29 @@ func (p *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapF
if !tokensIn.DenomsSubsetOf(p.GetTotalPoolLiquidity(ctx)) {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New("attempted joining pool with assets that do not exist in pool")
}

if len(tokensIn) == 1 && tokensIn[0].Amount.GT(sdk.OneInt()) {
numShares, err = p.calcSingleAssetJoinShares(tokensIn[0], swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

tokensJoined = tokensIn

p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

return numShares, tokensJoined, nil
} else if len(tokensIn) != p.NumAssets() {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New(
"stableswap pool only supports LP'ing with one asset, or all assets in pool")
}
} else {
// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
var remCoins sdk.Coins
numShares, remCoins, err = cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
tokensJoined = tokensIn.Sub(remCoins)
}
p.updatePoolForJoin(tokensIn.Sub(remCoins), numShares)

tokensJoined = tokensIn.Sub(remCoins)
p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
Expand Down
Loading

0 comments on commit 1067f1e

Please sign in to comment.