diff --git a/platform/packages/finance/src/duration.rs b/platform/packages/finance/src/duration.rs index 149ce7112..9c12651a3 100644 --- a/platform/packages/finance/src/duration.rs +++ b/platform/packages/finance/src/duration.rs @@ -11,8 +11,9 @@ use sdk::{ }; use crate::{ + coin::Coin, fraction::Fraction, - fractionable::{Fractionable, TimeSliceable}, + fractionable::{CheckedMultiply, Fractionable, TimeSliceable}, ratio::Rational, zero::Zero, }; @@ -108,6 +109,31 @@ impl Duration { { Rational::new(amount, annual_amount).of(self) } + + /// Implementation note: This method uses the checked_mul method to safely perform the multiplication. + /// Returns None if the result exceeds the limits of the type. + pub fn into_slice_per_ratio_checked(self, amount: U, annual_amount: U) -> Option + where + Self: Fractionable + CheckedMultiply, + U: Zero + Debug + PartialEq + Copy, + { + self.checked_mul(amount, annual_amount) + } +} + +impl CheckedMultiply> for Duration { + #[track_caller] + fn checked_mul(self, parts: Coin, total: Coin) -> Option + where + Coin: Zero + Debug + PartialEq>, + Self: Sized, + { + let d128: u128 = self.into(); + + CheckedMultiply::>::checked_mul(d128, parts, total) + .and_then(|res| res.try_into().ok()) + .map(Self::from_nanos) + } } impl From for u128 { diff --git a/platform/packages/finance/src/fractionable/mod.rs b/platform/packages/finance/src/fractionable/mod.rs index b48a080cd..5f4100ddc 100644 --- a/platform/packages/finance/src/fractionable/mod.rs +++ b/platform/packages/finance/src/fractionable/mod.rs @@ -31,14 +31,44 @@ pub trait HigherRank { type Intermediate; } -impl Fractionable for T +pub trait CheckedMultiply { + #[track_caller] + fn checked_mul(self, parts: U, total: U) -> Option + where + U: Zero + Debug + PartialEq, + Self: Sized; +} + +impl CheckedMultiply for T where T: HigherRank + Into, D: TryInto, + DIntermediate: Into, + D: Mul + Div, + U: Zero + PartialEq + Into + Debug, +{ + fn checked_mul(self, parts: U, total: U) -> Option { + if parts == total { + Some(self) + } else { + let res_double: D = self.into() * parts.into(); + let res_double = res_double / total.into(); + res_double + .try_into() + .ok() + .map(|res_intermediate: DIntermediate| res_intermediate.into()) + } + } +} + +impl Fractionable for T +where + T: HigherRank + Into + CheckedMultiply, + D: TryInto, >::Error: Debug, DIntermediate: Into, D: Mul + Div, - U: Zero + PartialEq + Into, + U: Zero + PartialEq + Into + Debug, { #[track_caller] fn safe_mul(self, ratio: &R) -> Self @@ -46,16 +76,8 @@ where R: Ratio, { // TODO debug_assert_eq!(T::BITS * 2, D::BITS); - - if ratio.parts() == ratio.total() { - self - } else { - let res_double: D = self.into() * ratio.parts().into(); - let res_double = res_double / ratio.total().into(); - let res_intermediate: DIntermediate = - res_double.try_into().expect("unexpected overflow"); - res_intermediate.into() - } + self.checked_mul(ratio.parts(), ratio.total()) + .expect("unexpected overflow") } } diff --git a/protocol/contracts/lease/src/lease/due.rs b/protocol/contracts/lease/src/lease/due.rs index 028d37731..956a19461 100644 --- a/protocol/contracts/lease/src/lease/due.rs +++ b/protocol/contracts/lease/src/lease/due.rs @@ -23,11 +23,13 @@ impl DueTrait for State { self.principal_due, Duration::YEAR, ); - if total_interest_a_year.is_zero() { - Duration::MAX - } else { - Duration::YEAR.into_slice_per_ratio(overdue_left, total_interest_a_year) - } + + // FIX for PR#370 + Duration::YEAR + .into_slice_per_ratio_checked(overdue_left, total_interest_a_year) + .map_or(Duration::MAX, |time_to_accrue_min_amount| { + time_to_accrue_min_amount + }) }; let time_to_collect = self.overdue.start_in().max(time_to_accrue_min_amount); if time_to_collect == Duration::default() { @@ -53,7 +55,7 @@ mod test { use crate::{ loan::{Overdue, State}, - position::DueTrait, + position::{DueTrait, OverdueCollection}, }; #[test] @@ -230,4 +232,24 @@ mod test { assert_eq!(Coin::ZERO, overdue_collection.amount()); assert_eq!(principal_due + total_interest, s.total_due()); } + + #[test] + fn test_large_interest_accrual_period() { + let principal_due = 20.into(); + let due_interest = 5.into(); + let due_margin_interest = 1.into(); + let till_due_end = Duration::from_days(1); + let s = State { + annual_interest: Percent::from_percent(15), + annual_interest_margin: Percent::from_percent(0), + principal_due, + due_interest, + due_margin_interest, + overdue: Overdue::StartIn(till_due_end), + }; + assert_eq!( + OverdueCollection::StartIn(Duration::MAX), + s.overdue_collection(1_800.into()) + ); + } } diff --git a/protocol/contracts/lease/src/position/interest.rs b/protocol/contracts/lease/src/position/interest.rs index 7965a971d..d5b8a4017 100644 --- a/protocol/contracts/lease/src/position/interest.rs +++ b/protocol/contracts/lease/src/position/interest.rs @@ -17,7 +17,7 @@ pub trait Due { /// When overdue interest amount goes above a configured minimum then the interest becomes collectable. fn overdue_collection(&self, min_amount: LpnCoin) -> OverdueCollection; } - +#[derive(PartialEq, Debug)] pub enum OverdueCollection { /// No collectable overdue interest yet ///