diff --git a/programs/rewards/src/error.rs b/programs/rewards/src/error.rs index 1b79a04..b1de329 100644 --- a/programs/rewards/src/error.rs +++ b/programs/rewards/src/error.rs @@ -75,6 +75,11 @@ pub enum MplxRewardsError { /// No need to transfer zero amount of rewards. #[error("Rewards: rewards amount must be positive")] RewardsMustBeGreaterThanZero, + + /// 13 + /// No need to transfer zero amount of rewards. + #[error("No changes at the date in weighted stake modifiers while they're expected")] + NoWeightedStakeModifiersAtADate, } impl PrintProgramError for MplxRewardsError { diff --git a/programs/rewards/src/state/reward_pool.rs b/programs/rewards/src/state/reward_pool.rs index df453fb..70b6d1b 100644 --- a/programs/rewards/src/state/reward_pool.rs +++ b/programs/rewards/src/state/reward_pool.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{btree_map::Entry, BTreeMap}; use crate::{ error::MplxRewardsError, @@ -215,16 +215,17 @@ impl RewardPool { curr_part_of_weighted_stake_for_flex, )?; - self.calculator - .weighted_stake_diffs - .entry(deposit_old_expiration_ts) - .and_modify(|modifier| *modifier -= weighted_stake_diff); + Self::modify_weighted_stake_diffs( + &mut self.calculator.weighted_stake_diffs, + deposit_old_expiration_ts, + weighted_stake_diff, + )?; - mining - .index - .weighted_stake_diffs - .entry(deposit_old_expiration_ts) - .and_modify(|modifier| *modifier -= weighted_stake_diff); + Self::modify_weighted_stake_diffs( + &mut mining.index.weighted_stake_diffs, + deposit_old_expiration_ts, + weighted_stake_diff, + )?; // also, we need to reduce staking power because we want to extend stake from "scratch" mining.share = mining @@ -268,6 +269,23 @@ impl RewardPool { .checked_sub(amount_multiplied_by_flex) .ok_or(MplxRewardsError::MathOverflow) } + + fn modify_weighted_stake_diffs( + diffs: &mut BTreeMap, + timestamp: u64, + weighted_stake_diff: u64, + ) -> Result<(), MplxRewardsError> { + match diffs.entry(timestamp) { + Entry::Vacant(_) => Err(MplxRewardsError::NoWeightedStakeModifiersAtADate), + Entry::Occupied(mut entry) => { + let modifier = entry.get_mut(); + *modifier = modifier + .checked_sub(weighted_stake_diff) + .ok_or(MplxRewardsError::MathOverflow)?; + Ok(()) + } + } + } } impl Sealed for RewardPool {} diff --git a/programs/rewards/tests/rewards/close_mining.rs b/programs/rewards/tests/rewards/close_mining.rs index ce97bd6..808980f 100644 --- a/programs/rewards/tests/rewards/close_mining.rs +++ b/programs/rewards/tests/rewards/close_mining.rs @@ -1,11 +1,8 @@ -use crate::utils::*; +use crate::utils::{assert_custom_on_chain_error::AssertCustomOnChainErr, *}; use mplx_rewards::{error::MplxRewardsError, utils::LockupPeriod}; use solana_program::pubkey::Pubkey; use solana_program_test::*; -use solana_sdk::{ - clock::SECONDS_PER_DAY, instruction::InstructionError, signature::Keypair, signer::Signer, - transaction::TransactionError, -}; +use solana_sdk::{clock::SECONDS_PER_DAY, signature::Keypair, signer::Signer}; async fn setup() -> (ProgramTestContext, TestRewards, Keypair, Pubkey) { let test = ProgramTest::new( @@ -128,17 +125,8 @@ async fn success_after_not_interacting_for_a_long_time() { advance_clock_by_ts(&mut context, (SECONDS_PER_DAY * 100).try_into().unwrap()).await; - let res = test_rewards + test_rewards .close_mining(&mut context, &mining, &mining_owner, &mining_owner.pubkey()) - .await; - - match res { - Err(BanksClientError::TransactionError(TransactionError::InstructionError( - _, - InstructionError::Custom(code), - ))) => { - assert_eq!(code, MplxRewardsError::RewardsMustBeClaimed as u32); - } - _ => unreachable!(), - } + .await + .assert_on_chain_err(MplxRewardsError::RewardsMustBeClaimed); } diff --git a/programs/rewards/tests/rewards/fill_vault.rs b/programs/rewards/tests/rewards/fill_vault.rs index b110a17..b13ffff 100644 --- a/programs/rewards/tests/rewards/fill_vault.rs +++ b/programs/rewards/tests/rewards/fill_vault.rs @@ -1,8 +1,8 @@ -use crate::utils::*; +use crate::utils::{assert_custom_on_chain_error::AssertCustomOnChainErr, *}; use mplx_rewards::{error::MplxRewardsError, utils::LockupPeriod}; -use solana_program::{instruction::InstructionError, program_pack::Pack}; +use solana_program::program_pack::Pack; use solana_program_test::*; -use solana_sdk::{signature::Keypair, signer::Signer, transaction::TransactionError}; +use solana_sdk::{signature::Keypair, signer::Signer}; use spl_token::state::Account; use std::borrow::Borrow; @@ -133,17 +133,8 @@ async fn zero_amount_of_rewards() { .unix_timestamp as u64 + 86400 * 100; - let res = test_rewards + test_rewards .fill_vault(&mut context, &rewarder.pubkey(), 0, distribution_ends_at) - .await; - - match res { - Err(BanksClientError::TransactionError(TransactionError::InstructionError( - _, - InstructionError::Custom(code), - ))) => { - assert_eq!(code, MplxRewardsError::RewardsMustBeGreaterThanZero as u32); - } - _ => unreachable!(), - } + .await + .assert_on_chain_err(MplxRewardsError::RewardsMustBeGreaterThanZero); } diff --git a/programs/rewards/tests/rewards/utils.rs b/programs/rewards/tests/rewards/utils.rs index e82fbf8..428f2d3 100644 --- a/programs/rewards/tests/rewards/utils.rs +++ b/programs/rewards/tests/rewards/utils.rs @@ -1,14 +1,14 @@ use std::borrow::{Borrow, BorrowMut}; -use mplx_rewards::utils::LockupPeriod; -use solana_program::pubkey::Pubkey; +use mplx_rewards::{error::MplxRewardsError, utils::LockupPeriod}; +use solana_program::{instruction::InstructionError, pubkey::Pubkey}; use solana_program_test::{BanksClientError, ProgramTestContext}; use solana_sdk::{ account::Account, program_pack::Pack, signature::{Keypair, Signer}, system_instruction::{self}, - transaction::Transaction, + transaction::{Transaction, TransactionError}, }; use spl_token::state::Account as SplTokenAccount; @@ -448,3 +448,27 @@ pub async fn claim_and_assert( .unwrap(); assert_tokens(context, user_reward, amount).await; } + +pub mod assert_custom_on_chain_error { + use super::*; + use std::fmt::Debug; + + pub trait AssertCustomOnChainErr { + fn assert_on_chain_err(self, expected_err: MplxRewardsError); + } + + impl AssertCustomOnChainErr for Result { + fn assert_on_chain_err(self, expected_err: MplxRewardsError) { + assert!(self.is_err()); + match self.unwrap_err() { + BanksClientError::TransactionError(TransactionError::InstructionError( + _, + InstructionError::Custom(code), + )) => { + debug_assert_eq!(expected_err as u32, code); + } + _ => unreachable!("BanksClientError has no 'Custom' variant."), + } + } + } +}