Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(shares-math): add util to get shares from amount #221

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/libraries/BlueLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import {Id, Market, IBlue} from "../interfaces/IBlue.sol";

import {MarketLib} from "./MarketLib.sol";
import {SharesMath} from "./SharesMath.sol";

library BlueLib {
using MarketLib for Market;
using SharesMath for uint256;

function withdrawAmount(IBlue blue, Market memory market, uint256 amount, address onBehalf, address receiver)
external
returns (uint256 shares)
{
Id id = market.id();
shares = amount.toWithdrawShares(blue.totalSupply(id), blue.totalSupplyShares(id));

uint256 maxShares = blue.supplyShares(id, address(this));
if (shares > maxShares) shares = maxShares;

blue.withdraw(market, shares, onBehalf, receiver);
}

function repayAmount(IBlue blue, Market memory market, uint256 amount, address onBehalf, bytes calldata data)
external
returns (uint256 shares)
{
Id id = market.id();
shares = amount.toRepayShares(blue.totalBorrow(id), blue.totalBorrowShares(id));

uint256 maxShares = blue.borrowShares(id, address(this));
if (shares > maxShares) shares = maxShares;

blue.repay(market, shares, onBehalf, data);
}
}
28 changes: 28 additions & 0 deletions src/libraries/SharesMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,32 @@ library SharesMath {
function toAssetsUp(uint256 shares, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) {
return shares.mulDivUp(totalAssets + VIRTUAL_ASSETS, totalShares + VIRTUAL_SHARES);
}

/// @dev Calculates the amount of shares corresponding to an exact amount of supply to withdraw.
/// Note: only works as long as totalSupplyShares + VIRTUAL_SHARES >= totalSupply + VIRTUAL_ASSETS.
function toWithdrawShares(uint256 amount, uint256 totalSupply, uint256 totalSupplyShares)
internal
pure
returns (uint256)
{
uint256 sharesMin = toSharesDown(amount, totalSupply, totalSupplyShares);
uint256 sharesMax = toSharesUp(amount + 1, totalSupply, totalSupplyShares);

return (sharesMin + sharesMax) / 2;
}

/// @dev Calculates the amount of shares corresponding to an exact amount of debt to repay.
/// Note: only works as long as totalBorrowShares + VIRTUAL_SHARES >= totalBorrow + VIRTUAL_ASSETS.
function toRepayShares(uint256 amount, uint256 totalBorrow, uint256 totalBorrowShares)
internal
pure
returns (uint256)
{
if (amount == 0) return 0;

uint256 sharesMin = toSharesDown(amount - 1, totalBorrow, totalBorrowShares);
uint256 sharesMax = toSharesUp(amount, totalBorrow, totalBorrowShares);

return (sharesMin + sharesMax) / 2;
}
}
91 changes: 6 additions & 85 deletions test/forge/Blue.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {SigUtils} from "./helpers/SigUtils.sol";

import "src/Blue.sol";
import {SharesMath} from "src/libraries/SharesMath.sol";
import {BlueLib} from "src/libraries/BlueLib.sol";
import {
IBlueLiquidateCallback,
IBlueRepayCallback,
Expand All @@ -25,6 +26,7 @@ contract BlueTest is
IBlueRepayCallback,
IBlueLiquidateCallback
{
using BlueLib for IBlue;
using MarketLib for Market;
using SharesMath for uint256;
using stdStorage for StdStorage;
Expand Down Expand Up @@ -400,57 +402,15 @@ contract BlueTest is
);
}

function testWithdrawMinAmount(uint256 amountLent, uint256 minAmountWithdrawn) public {
_testWithdrawCommon(amountLent);

uint256 totalSupplyBefore = blue.totalSupply(id);
uint256 supplySharesBefore = blue.supplyShares(id, address(this));
minAmountWithdrawn = bound(
minAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id))
);
uint256 sharesWithdrawn = (minAmountWithdrawn + 1).toSharesUp(blue.totalSupply(id), blue.totalSupplyShares(id));
sharesWithdrawn = sharesWithdrawn > supplySharesBefore ? supplySharesBefore : sharesWithdrawn;
uint256 realAmountWithdrawn = sharesWithdrawn.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id));
if (sharesWithdrawn > 0) blue.withdraw(market, sharesWithdrawn, address(this), address(this));

assertGe(realAmountWithdrawn, minAmountWithdrawn, "realAmountWithdrawn");
assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share");
assertEq(borrowableAsset.balanceOf(address(this)), realAmountWithdrawn, "this balance");
assertEq(borrowableAsset.balanceOf(address(blue)), totalSupplyBefore - realAmountWithdrawn, "blue balance");
}

function testWithdrawMaxAmount(uint256 amountLent, uint256 maxAmountWithdrawn) public {
_testWithdrawCommon(amountLent);

uint256 totalSupplyBefore = blue.totalSupply(id);
uint256 supplySharesBefore = blue.supplyShares(id, address(this));
maxAmountWithdrawn = bound(
maxAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id))
);
uint256 sharesWithdrawn = maxAmountWithdrawn.toSharesDown(blue.totalSupply(id), blue.totalSupplyShares(id));
uint256 realAmountWithdrawn = sharesWithdrawn.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id));
if (sharesWithdrawn > 0) blue.withdraw(market, sharesWithdrawn, address(this), address(this));

assertLe(realAmountWithdrawn, maxAmountWithdrawn, "realAmountWithdrawn");
assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share");
assertEq(borrowableAsset.balanceOf(address(this)), realAmountWithdrawn, "this balance");
assertEq(borrowableAsset.balanceOf(address(blue)), totalSupplyBefore - realAmountWithdrawn, "blue balance");
}

function testWithdrawExactAmount(uint256 amountLent, uint256 exactAmountWithdrawn) public {
function testWithdrawAmount(uint256 amountLent, uint256 exactAmountWithdrawn) public {
_testWithdrawCommon(amountLent);

uint256 totalSupplyBefore = blue.totalSupply(id);
uint256 supplySharesBefore = blue.supplyShares(id, address(this));
exactAmountWithdrawn = bound(
exactAmountWithdrawn, 1, supplySharesBefore.toAssetsDown(blue.totalSupply(id), blue.totalSupplyShares(id))
);
uint256 sharesWithdrawnMin = exactAmountWithdrawn.toSharesDown(blue.totalSupply(id), blue.totalSupplyShares(id));
uint256 sharesWithdrawnMax =
(exactAmountWithdrawn + 1).toSharesUp(blue.totalSupply(id), blue.totalSupplyShares(id));
uint256 sharesWithdrawn = (sharesWithdrawnMin + sharesWithdrawnMax) / 2;
sharesWithdrawn = sharesWithdrawn > supplySharesBefore ? supplySharesBefore : sharesWithdrawn;
blue.withdraw(market, sharesWithdrawn, address(this), address(this));
uint256 sharesWithdrawn = blue.withdrawAmount(market, exactAmountWithdrawn, address(this), address(this));

assertEq(blue.supplyShares(id, address(this)), supplySharesBefore - sharesWithdrawn, "supply share");
assertEq(borrowableAsset.balanceOf(address(this)), exactAmountWithdrawn, "this balance");
Expand Down Expand Up @@ -501,53 +461,14 @@ contract BlueTest is
assertEq(borrowableAsset.balanceOf(address(blue)), amountRepaid, "blue balance");
}

function testRepayMinAmount(uint256 amountBorrowed, uint256 minAmountRepaid) public {
_testRepayCommon(amountBorrowed, address(this));

uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this));
uint256 borrowSharesBefore = blue.borrowShares(id, address(this));
minAmountRepaid =
bound(minAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id)));
uint256 sharesRepaid = (minAmountRepaid - 1).toSharesDown(blue.totalBorrow(id), blue.totalBorrowShares(id));
uint256 realAmountRepaid = sharesRepaid.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id));
if (sharesRepaid > 0) blue.repay(market, sharesRepaid, address(this), hex"");

assertLe(realAmountRepaid, minAmountRepaid, "real amount repaid");
assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share");
assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - realAmountRepaid, "this balance");
assertEq(borrowableAsset.balanceOf(address(blue)), realAmountRepaid, "blue balance");
}

function testRepayMaxAmount(uint256 amountBorrowed, uint256 maxAmountRepaid) public {
_testRepayCommon(amountBorrowed, address(this));

uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this));
uint256 borrowSharesBefore = blue.borrowShares(id, address(this));
maxAmountRepaid =
bound(maxAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id)));
uint256 sharesRepaid = maxAmountRepaid.toSharesUp(blue.totalBorrow(id), blue.totalBorrowShares(id));
sharesRepaid = sharesRepaid > borrowSharesBefore ? borrowSharesBefore : sharesRepaid;
uint256 realAmountRepaid = sharesRepaid.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id));
if (sharesRepaid > 0) blue.repay(market, sharesRepaid, address(this), hex"");

assertGe(realAmountRepaid, maxAmountRepaid, "real amount repaid");
assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share");
assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - realAmountRepaid, "this balance");
assertEq(borrowableAsset.balanceOf(address(blue)), realAmountRepaid, "blue balance");
}

function testRepayExactAmount(uint256 amountBorrowed, uint256 exactAmountRepaid) public {
function testRepayAmount(uint256 amountBorrowed, uint256 exactAmountRepaid) public {
_testRepayCommon(amountBorrowed, address(this));

uint256 thisBalanceBefore = borrowableAsset.balanceOf(address(this));
uint256 borrowSharesBefore = blue.borrowShares(id, address(this));
exactAmountRepaid =
bound(exactAmountRepaid, 1, borrowSharesBefore.toAssetsUp(blue.totalBorrow(id), blue.totalBorrowShares(id)));
uint256 sharesRepaidMin = (exactAmountRepaid - 1).toSharesDown(blue.totalBorrow(id), blue.totalBorrowShares(id));
uint256 sharesRepaidMax = exactAmountRepaid.toSharesUp(blue.totalBorrow(id), blue.totalBorrowShares(id));
uint256 sharesRepaid = (sharesRepaidMin + sharesRepaidMax + 1) / 2;
sharesRepaid = sharesRepaid > borrowSharesBefore ? borrowSharesBefore : sharesRepaid;
blue.repay(market, sharesRepaid, address(this), hex"");
uint256 sharesRepaid = blue.repayAmount(market, exactAmountRepaid, address(this), hex"");

assertEq(blue.borrowShares(id, address(this)), borrowSharesBefore - sharesRepaid, "borrow share");
assertEq(borrowableAsset.balanceOf(address(this)), thisBalanceBefore - exactAmountRepaid, "this balance");
Expand Down
41 changes: 41 additions & 0 deletions test/forge/SharesMath.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "forge-std/Test.sol";
import "forge-std/console2.sol";

import {SharesMath} from "src/libraries/SharesMath.sol";

contract SharesMathTest is Test {
using SharesMath for uint256;

function testToSupplyShares(uint256 amount, uint256 supplyShares, uint256 totalSupply, uint256 totalSupplyShares)
public
{
totalSupplyShares = bound(totalSupplyShares, SharesMath.VIRTUAL_SHARES, type(uint128).max);
totalSupply = bound(totalSupply, 0, totalSupplyShares);

supplyShares = bound(supplyShares, 0, totalSupplyShares);

amount = bound(amount, 0, supplyShares.toAssetsDown(totalSupply, totalSupplyShares));

assertEq(
amount, amount.toWithdrawShares(totalSupply, totalSupplyShares).toAssetsDown(totalSupply, totalSupplyShares)
);
}

function testToBorrowShares(uint256 amount, uint256 borrowShares, uint256 totalBorrow, uint256 totalBorrowShares)
public
{
totalBorrowShares = bound(totalBorrowShares, SharesMath.VIRTUAL_SHARES, type(uint128).max);
totalBorrow = bound(totalBorrow, 0, totalBorrowShares);

borrowShares = bound(borrowShares, 0, totalBorrowShares);

amount = bound(amount, 0, borrowShares.toAssetsDown(totalBorrow, totalBorrowShares));

assertEq(
amount, amount.toRepayShares(totalBorrow, totalBorrowShares).toAssetsUp(totalBorrow, totalBorrowShares)
);
}
}