Skip to content

Commit

Permalink
Create a helper trait CheckedMultiply
Browse files Browse the repository at this point in the history
  • Loading branch information
maneva3 committed Dec 19, 2024
1 parent 8ae5efe commit 5b6cf56
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 20 deletions.
28 changes: 27 additions & 1 deletion platform/packages/finance/src/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use sdk::{
};

use crate::{
coin::Coin,
fraction::Fraction,
fractionable::{Fractionable, TimeSliceable},
fractionable::{CheckedMultiply, Fractionable, TimeSliceable},
ratio::Rational,
zero::Zero,
};
Expand Down Expand Up @@ -108,6 +109,31 @@ impl Duration {
{
Rational::new(amount, annual_amount).of(self)
}

/// Implementation note: This method uses the checked_mul method to safely perform the multiplication.
/// Returns None if the result exceeds the limits of the type.
pub fn into_slice_per_ratio_checked<U>(self, amount: U, annual_amount: U) -> Option<Self>
where
Self: Fractionable<U> + CheckedMultiply<U>,
U: Zero + Debug + PartialEq + Copy,
{
self.checked_mul(amount, annual_amount)
}
}

impl<C> CheckedMultiply<Coin<C>> for Duration {
#[track_caller]
fn checked_mul(self, parts: Coin<C>, total: Coin<C>) -> Option<Self>
where
Coin<C>: Zero + Debug + PartialEq<Coin<C>>,
Self: Sized,
{
let d128: u128 = self.into();

CheckedMultiply::<Coin<C>>::checked_mul(d128, parts, total)
.and_then(|res| res.try_into().ok())
.map(Self::from_nanos)
}
}

impl From<Duration> for u128 {
Expand Down
46 changes: 34 additions & 12 deletions platform/packages/finance/src/fractionable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,53 @@ pub trait HigherRank<T> {
type Intermediate;
}

impl<T, D, DIntermediate, U> Fractionable<U> for T
pub trait CheckedMultiply<U> {
#[track_caller]
fn checked_mul(self, parts: U, total: U) -> Option<Self>
where
U: Zero + Debug + PartialEq<U>,
Self: Sized;
}

impl<T, D, DIntermediate, U> CheckedMultiply<U> for T
where
T: HigherRank<U, Type = D, Intermediate = DIntermediate> + Into<D>,
D: TryInto<DIntermediate>,
DIntermediate: Into<T>,
D: Mul<D, Output = D> + Div<D, Output = D>,
U: Zero + PartialEq + Into<D> + Debug,
{
fn checked_mul(self, parts: U, total: U) -> Option<Self> {
if parts == total {
Some(self)
} else {
let res_double: D = self.into() * parts.into();
let res_double = res_double / total.into();
res_double
.try_into()
.ok()
.map(|res_intermediate: DIntermediate| res_intermediate.into())
}
}
}

impl<T, D, DIntermediate, U> Fractionable<U> for T
where
T: HigherRank<U, Type = D, Intermediate = DIntermediate> + Into<D> + CheckedMultiply<U>,
D: TryInto<DIntermediate>,
<D as TryInto<DIntermediate>>::Error: Debug,
DIntermediate: Into<T>,
D: Mul<D, Output = D> + Div<D, Output = D>,
U: Zero + PartialEq + Into<D>,
U: Zero + PartialEq + Into<D> + Debug,
{
#[track_caller]
fn safe_mul<R>(self, ratio: &R) -> Self
where
R: Ratio<U>,
{
// TODO debug_assert_eq!(T::BITS * 2, D::BITS);

if ratio.parts() == ratio.total() {
self
} else {
let res_double: D = self.into() * ratio.parts().into();
let res_double = res_double / ratio.total().into();
let res_intermediate: DIntermediate =
res_double.try_into().expect("unexpected overflow");
res_intermediate.into()
}
self.checked_mul(ratio.parts(), ratio.total())
.expect("unexpected overflow")
}
}

Expand Down
13 changes: 6 additions & 7 deletions protocol/contracts/lease/src/lease/due.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ impl DueTrait for State {
Duration::YEAR,
);

// FIX for PR#294
let ratio_threshold = Duration::MAX.nanos() / Duration::YEAR.nanos();
if Some(overdue_left) >= total_interest_a_year.checked_mul(ratio_threshold.into()) {
Duration::MAX
} else {
Duration::YEAR.into_slice_per_ratio(overdue_left, total_interest_a_year)
}
// FIX for PR#370
Duration::YEAR
.into_slice_per_ratio_checked(overdue_left, total_interest_a_year)
.map_or(Duration::MAX, |time_to_accrue_min_amount| {
time_to_accrue_min_amount
})
};
let time_to_collect = self.overdue.start_in().max(time_to_accrue_min_amount);
if time_to_collect == Duration::default() {
Expand Down

0 comments on commit 5b6cf56

Please sign in to comment.