diff --git a/platform/packages/finance/src/duration.rs b/platform/packages/finance/src/duration.rs index 149ce7112..9c12651a3 100644 --- a/platform/packages/finance/src/duration.rs +++ b/platform/packages/finance/src/duration.rs @@ -11,8 +11,9 @@ use sdk::{ }; use crate::{ + coin::Coin, fraction::Fraction, - fractionable::{Fractionable, TimeSliceable}, + fractionable::{CheckedMultiply, Fractionable, TimeSliceable}, ratio::Rational, zero::Zero, }; @@ -108,6 +109,31 @@ impl Duration { { Rational::new(amount, annual_amount).of(self) } + + /// Implementation note: This method uses the checked_mul method to safely perform the multiplication. + /// Returns None if the result exceeds the limits of the type. + pub fn into_slice_per_ratio_checked(self, amount: U, annual_amount: U) -> Option + where + Self: Fractionable + CheckedMultiply, + U: Zero + Debug + PartialEq + Copy, + { + self.checked_mul(amount, annual_amount) + } +} + +impl CheckedMultiply> for Duration { + #[track_caller] + fn checked_mul(self, parts: Coin, total: Coin) -> Option + where + Coin: Zero + Debug + PartialEq>, + Self: Sized, + { + let d128: u128 = self.into(); + + CheckedMultiply::>::checked_mul(d128, parts, total) + .and_then(|res| res.try_into().ok()) + .map(Self::from_nanos) + } } impl From for u128 { diff --git a/platform/packages/finance/src/fractionable/mod.rs b/platform/packages/finance/src/fractionable/mod.rs index b48a080cd..5f4100ddc 100644 --- a/platform/packages/finance/src/fractionable/mod.rs +++ b/platform/packages/finance/src/fractionable/mod.rs @@ -31,14 +31,44 @@ pub trait HigherRank { type Intermediate; } -impl Fractionable for T +pub trait CheckedMultiply { + #[track_caller] + fn checked_mul(self, parts: U, total: U) -> Option + where + U: Zero + Debug + PartialEq, + Self: Sized; +} + +impl CheckedMultiply for T where T: HigherRank + Into, D: TryInto, + DIntermediate: Into, + D: Mul + Div, + U: Zero + PartialEq + Into + Debug, +{ + fn checked_mul(self, parts: U, total: U) -> Option { + if parts == total { + Some(self) + } else { + let res_double: D = self.into() * parts.into(); + let res_double = res_double / total.into(); + res_double + .try_into() + .ok() + .map(|res_intermediate: DIntermediate| res_intermediate.into()) + } + } +} + +impl Fractionable for T +where + T: HigherRank + Into + CheckedMultiply, + D: TryInto, >::Error: Debug, DIntermediate: Into, D: Mul + Div, - U: Zero + PartialEq + Into, + U: Zero + PartialEq + Into + Debug, { #[track_caller] fn safe_mul(self, ratio: &R) -> Self @@ -46,16 +76,8 @@ where R: Ratio, { // TODO debug_assert_eq!(T::BITS * 2, D::BITS); - - if ratio.parts() == ratio.total() { - self - } else { - let res_double: D = self.into() * ratio.parts().into(); - let res_double = res_double / ratio.total().into(); - let res_intermediate: DIntermediate = - res_double.try_into().expect("unexpected overflow"); - res_intermediate.into() - } + self.checked_mul(ratio.parts(), ratio.total()) + .expect("unexpected overflow") } } diff --git a/protocol/contracts/lease/src/lease/due.rs b/protocol/contracts/lease/src/lease/due.rs index b4b1f1526..956a19461 100644 --- a/protocol/contracts/lease/src/lease/due.rs +++ b/protocol/contracts/lease/src/lease/due.rs @@ -24,13 +24,12 @@ impl DueTrait for State { Duration::YEAR, ); - // FIX for PR#294 - let ratio_threshold = Duration::MAX.nanos() / Duration::YEAR.nanos(); - if Some(overdue_left) >= total_interest_a_year.checked_mul(ratio_threshold.into()) { - Duration::MAX - } else { - Duration::YEAR.into_slice_per_ratio(overdue_left, total_interest_a_year) - } + // FIX for PR#370 + Duration::YEAR + .into_slice_per_ratio_checked(overdue_left, total_interest_a_year) + .map_or(Duration::MAX, |time_to_accrue_min_amount| { + time_to_accrue_min_amount + }) }; let time_to_collect = self.overdue.start_in().max(time_to_accrue_min_amount); if time_to_collect == Duration::default() {