diff --git a/src/libraries/BlueLib.sol b/src/libraries/BlueLib.sol new file mode 100644 index 000000000..a9783fae1 --- /dev/null +++ b/src/libraries/BlueLib.sol @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import {Id, Market, IBlue} from "../interfaces/IBlue.sol"; + +import {MarketLib} from "./MarketLib.sol"; +import {SharesMath} from "./SharesMath.sol"; + +library BlueLib { + using MarketLib for Market; + using SharesMath for uint256; + + function withdrawAmount(IBlue blue, Market memory market, uint256 amount, address onBehalf, address receiver) + external + returns (uint256 shares) + { + Id id = market.id(); + shares = amount.toWithdrawShares(blue.totalSupply(id), blue.totalSupplyShares(id)); + + uint256 maxShares = blue.supplyShares(id, address(this)); + if (shares > maxShares) shares = maxShares; + + blue.withdraw(market, shares, onBehalf, receiver); + } + + function repayAmount(IBlue blue, Market memory market, uint256 amount, address onBehalf, bytes calldata data) + external + returns (uint256 shares) + { + Id id = market.id(); + shares = amount.toRepayShares(blue.totalBorrow(id), blue.totalBorrowShares(id)); + + uint256 maxShares = blue.borrowShares(id, address(this)); + if (shares > maxShares) shares = maxShares; + + blue.repay(market, shares, onBehalf, data); + } +} diff --git a/src/libraries/SharesMath.sol b/src/libraries/SharesMath.sol index 3d9728c3d..9f1dab24f 100644 --- a/src/libraries/SharesMath.sol +++ b/src/libraries/SharesMath.sol @@ -34,4 +34,32 @@ library SharesMath { function toAssetsUp(uint256 shares, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) { return shares.mulDivUp(totalAssets + VIRTUAL_ASSETS, totalShares + VIRTUAL_SHARES); } + + /// @dev Calculates the amount of shares corresponding to an exact amount of supply to withdraw. + /// Note: only works as long as totalSupplyShares + VIRTUAL_SHARES >= totalSupply + VIRTUAL_ASSETS. + function toWithdrawShares(uint256 amount, uint256 totalSupply, uint256 totalSupplyShares) + internal + pure + returns (uint256) + { + uint256 sharesMin = toSharesDown(amount, totalSupply, totalSupplyShares); + uint256 sharesMax = toSharesUp(amount + 1, totalSupply, totalSupplyShares); + + return (sharesMin + sharesMax) / 2; + } + + /// @dev Calculates the amount of shares corresponding to an exact amount of debt to repay. + /// Note: only works as long as totalBorrowShares + VIRTUAL_SHARES >= totalBorrow + VIRTUAL_ASSETS. + function toRepayShares(uint256 amount, uint256 totalBorrow, uint256 totalBorrowShares) + internal + pure + returns (uint256) + { + if (amount == 0) return 0; + + uint256 sharesMin = toSharesDown(amount - 1, totalBorrow, totalBorrowShares); + uint256 sharesMax = toSharesUp(amount, totalBorrow, totalBorrowShares); + + return (sharesMin + sharesMax) / 2; + } } diff --git a/test/forge/Blue.t.sol b/test/forge/Blue.t.sol index 52ce9a666..c0940f99c 100644 --- a/test/forge/Blue.t.sol +++ b/test/forge/Blue.t.sol @@ -8,6 +8,7 @@ import {SigUtils} from "./helpers/SigUtils.sol"; import "src/Blue.sol"; import {SharesMath} from "src/libraries/SharesMath.sol"; +import {BlueLib} from "src/libraries/BlueLib.sol"; import { IBlueLiquidateCallback, IBlueRepayCallback, @@ -25,6 +26,7 @@ contract BlueTest is IBlueRepayCallback, IBlueLiquidateCallback { + using BlueLib for IBlue; using MarketLib for Market; using SharesMath for uint256; using stdStorage for StdStorage; @@ -400,44 +402,7 @@ contract BlueTest is ); } - function testWithdrawMinAmount(uint256 amountLent, uint256 minAmountWithdrawn) public { - _testWithdrawCommon(amountLent); - - uint256 totalSupplyBefore = blue.totalSupply(id); - uint256 supplySharesBefore = blue.supplyShares(id, address(this)); - minAmountWithdrawn = bound( - minAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id)) - ); - uint256 sharesWithdrawn = (minAmountWithdrawn + 1).toSharesUp(blue.totalSupply(id), blue.totalSupplyShares(id)); - sharesWithdrawn = sharesWithdrawn > supplySharesBefore ? supplySharesBefore : sharesWithdrawn; - uint256 realAmountWithdrawn = sharesWithdrawn.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id)); - if (sharesWithdrawn > 0) blue.withdraw(market, sharesWithdrawn, address(this), address(this)); - - assertGe(realAmountWithdrawn, minAmountWithdrawn, "realAmountWithdrawn"); - assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share"); - assertEq(borrowableAsset.balanceOf(address(this)), realAmountWithdrawn, "this balance"); - assertEq(borrowableAsset.balanceOf(address(blue)), totalSupplyBefore - realAmountWithdrawn, "blue balance"); - } - - function testWithdrawMaxAmount(uint256 amountLent, uint256 maxAmountWithdrawn) public { - _testWithdrawCommon(amountLent); - - uint256 totalSupplyBefore = blue.totalSupply(id); - uint256 supplySharesBefore = blue.supplyShares(id, address(this)); - maxAmountWithdrawn = bound( - maxAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id)) - ); - uint256 sharesWithdrawn = maxAmountWithdrawn.toSharesDown(blue.totalSupply(id), blue.totalSupplyShares(id)); - uint256 realAmountWithdrawn = sharesWithdrawn.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id)); - if (sharesWithdrawn > 0) blue.withdraw(market, sharesWithdrawn, address(this), address(this)); - - assertLe(realAmountWithdrawn, maxAmountWithdrawn, "realAmountWithdrawn"); - assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share"); - assertEq(borrowableAsset.balanceOf(address(this)), realAmountWithdrawn, "this balance"); - assertEq(borrowableAsset.balanceOf(address(blue)), totalSupplyBefore - realAmountWithdrawn, "blue balance"); - } - - function testWithdrawExactAmount(uint256 amountLent, uint256 exactAmountWithdrawn) public { + function testWithdrawAmount(uint256 amountLent, uint256 exactAmountWithdrawn) public { _testWithdrawCommon(amountLent); uint256 totalSupplyBefore = blue.totalSupply(id); @@ -445,12 +410,7 @@ contract BlueTest is exactAmountWithdrawn = bound( exactAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id)) ); - uint256 sharesWithdrawnMin = exactAmountWithdrawn.toSharesDown(blue.totalSupply(id), blue.totalSupplyShares(id)); - uint256 sharesWithdrawnMax = - (exactAmountWithdrawn + 1).toSharesUp(blue.totalSupply(id), blue.totalSupplyShares(id)); - uint256 sharesWithdrawn = (sharesWithdrawnMin + sharesWithdrawnMax) / 2; - sharesWithdrawn = sharesWithdrawn > supplySharesBefore ? supplySharesBefore : sharesWithdrawn; - blue.withdraw(market, sharesWithdrawn, address(this), address(this)); + uint256 sharesWithdrawn = blue.withdrawAmount(market, exactAmountWithdrawn, address(this), address(this)); assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share"); assertEq(borrowableAsset.balanceOf(address(this)), exactAmountWithdrawn, "this balance"); @@ -501,53 +461,14 @@ contract BlueTest is assertEq(borrowableAsset.balanceOf(address(blue)), amountRepaid, "blue balance"); } - function testRepayMinAmount(uint256 amountBorrowed, uint256 minAmountRepaid) public { - _testRepayCommon(amountBorrowed, address(this)); - - uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this)); - uint256 borrowSharesBefore = blue.borrowShares(id, address(this)); - minAmountRepaid = - bound(minAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id))); - uint256 sharesRepaid = (minAmountRepaid - 1).toSharesDown(blue.totalBorrow(id), blue.totalBorrowShares(id)); - uint256 realAmountRepaid = sharesRepaid.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id)); - if (sharesRepaid > 0) blue.repay(market, sharesRepaid, address(this), hex""); - - assertLe(realAmountRepaid, minAmountRepaid, "real amount repaid"); - assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share"); - assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - realAmountRepaid, "this balance"); - assertEq(borrowableAsset.balanceOf(address(blue)), realAmountRepaid, "blue balance"); - } - - function testRepayMaxAmount(uint256 amountBorrowed, uint256 maxAmountRepaid) public { - _testRepayCommon(amountBorrowed, address(this)); - - uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this)); - uint256 borrowSharesBefore = blue.borrowShares(id, address(this)); - maxAmountRepaid = - bound(maxAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id))); - uint256 sharesRepaid = maxAmountRepaid.toSharesUp(blue.totalBorrow(id), blue.totalBorrowShares(id)); - sharesRepaid = sharesRepaid > borrowSharesBefore ? borrowSharesBefore : sharesRepaid; - uint256 realAmountRepaid = sharesRepaid.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id)); - if (sharesRepaid > 0) blue.repay(market, sharesRepaid, address(this), hex""); - - assertGe(realAmountRepaid, maxAmountRepaid, "real amount repaid"); - assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share"); - assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - realAmountRepaid, "this balance"); - assertEq(borrowableAsset.balanceOf(address(blue)), realAmountRepaid, "blue balance"); - } - - function testRepayExactAmount(uint256 amountBorrowed, uint256 exactAmountRepaid) public { + function testRepayAmount(uint256 amountBorrowed, uint256 exactAmountRepaid) public { _testRepayCommon(amountBorrowed, address(this)); uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this)); uint256 borrowSharesBefore = blue.borrowShares(id, address(this)); exactAmountRepaid = bound(exactAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id))); - uint256 sharesRepaidMin = (exactAmountRepaid - 1).toSharesDown(blue.totalBorrow(id), blue.totalBorrowShares(id)); - uint256 sharesRepaidMax = exactAmountRepaid.toSharesUp(blue.totalBorrow(id), blue.totalBorrowShares(id)); - uint256 sharesRepaid = (sharesRepaidMin + sharesRepaidMax + 1) / 2; - sharesRepaid = sharesRepaid > borrowSharesBefore ? borrowSharesBefore : sharesRepaid; - blue.repay(market, sharesRepaid, address(this), hex""); + uint256 sharesRepaid = blue.repayAmount(market, exactAmountRepaid, address(this), hex""); assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share"); assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - exactAmountRepaid, "this balance"); diff --git a/test/forge/SharesMath.t.sol b/test/forge/SharesMath.t.sol new file mode 100644 index 000000000..5165bab90 --- /dev/null +++ b/test/forge/SharesMath.t.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import "forge-std/console2.sol"; + +import {SharesMath} from "src/libraries/SharesMath.sol"; + +contract SharesMathTest is Test { + using SharesMath for uint256; + + function testToSupplyShares(uint256 amount, uint256 supplyShares, uint256 totalSupply, uint256 totalSupplyShares) + public + { + totalSupplyShares = bound(totalSupplyShares, SharesMath.VIRTUAL_SHARES, type(uint128).max); + totalSupply = bound(totalSupply, 0, totalSupplyShares); + + supplyShares = bound(supplyShares, 0, totalSupplyShares); + + amount = bound(amount, 0, supplyShares.toAssetsDown(totalSupply, totalSupplyShares)); + + assertEq( + amount, amount.toWithdrawShares(totalSupply, totalSupplyShares).toAssetsDown(totalSupply, totalSupplyShares) + ); + } + + function testToBorrowShares(uint256 amount, uint256 borrowShares, uint256 totalBorrow, uint256 totalBorrowShares) + public + { + totalBorrowShares = bound(totalBorrowShares, SharesMath.VIRTUAL_SHARES, type(uint128).max); + totalBorrow = bound(totalBorrow, 0, totalBorrowShares); + + borrowShares = bound(borrowShares, 0, totalBorrowShares); + + amount = bound(amount, 0, borrowShares.toAssetsDown(totalBorrow, totalBorrowShares)); + + assertEq( + amount, amount.toRepayShares(totalBorrow, totalBorrowShares).toAssetsUp(totalBorrow, totalBorrowShares) + ); + } +}