From 0053f9945de08f6c3119532c77969afb3198e022 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sun, 20 Oct 2024 12:35:09 +0800 Subject: [PATCH 1/6] Feat/linear damping --- src/inference.rs | 4 ++-- src/model.rs | 7 ++++++- src/optimal_retention.rs | 7 ++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index c593d43..2e4efb9 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -22,8 +22,8 @@ pub type Parameters = [f32]; use itertools::izip; pub static DEFAULT_PARAMETERS: [f32; 19] = [ - 0.4072, 1.1829, 3.1262, 15.4722, 7.2102, 0.5316, 1.0651, 0.0234, 1.616, 0.1544, 1.0824, 1.9813, - 0.0953, 0.2975, 2.2042, 0.2407, 2.9466, 0.5034, 0.6567, + 0.40255, 1.18385, 3.173, 15.69105, 7.1949, 0.5345, 1.4604, 0.0046, 1.54575, 0.1192, 1.01925, + 1.9395, 0.11, 0.29605, 2.2698, 0.2315, 2.9898, 0.51655, 0.6621, ]; fn infer( diff --git a/src/model.rs b/src/model.rs index 83c38bd..f65b0b2 100644 --- a/src/model.rs +++ b/src/model.rs @@ -104,8 +104,13 @@ impl Model { self.w.get(4) - (self.w.get(5) * (rating - 1)).exp() + 1 } + fn linear_damping(&self, delta_d: Tensor, old_d: Tensor) -> Tensor { + old_d.neg().add_scalar(10.0) * delta_d.div_scalar(9.0) + } + fn next_difficulty(&self, difficulty: Tensor, rating: Tensor) -> Tensor { - difficulty - self.w.get(6) * (rating - 3) + let delta_d = -self.w.get(6) * (rating - 3); + difficulty.clone() + self.linear_damping(delta_d, difficulty) } pub(crate) fn step( diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index cd5d825..4f63715 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -97,8 +97,13 @@ fn init_d_with_short_term(w: &[f32], rating: usize, rating_offset: f32) -> f32 { new_d.clamp(1.0, 10.0) } +fn linear_damping(delta_d: f32, old_d: f32) -> f32 { + (10.0 - old_d) / 9.0 * delta_d +} + fn next_d(w: &[f32], d: f32, rating: usize) -> f32 { - let new_d = d - w[6] * (rating as f32 - 3.0); + let delta_d = -w[6] * (rating as f32 - 3.0); + let new_d = d + linear_damping(delta_d, d); mean_reversion(w, init_d(w, 4), new_d).clamp(1.0, 10.0) } From 49671634cee13e2f5beac5920a92d590174b5b4e Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sun, 20 Oct 2024 15:34:06 +0800 Subject: [PATCH 2/6] update tests --- src/inference.rs | 20 ++++++++++---------- src/model.rs | 17 ++++++----------- src/optimal_retention.rs | 12 ++++++------ src/pre_training.rs | 6 +++--- 4 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 2e4efb9..7153344 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -493,23 +493,23 @@ mod tests { ]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.210983, 0.037216]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.211007, 0.037216]); let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217689, 0.039710]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216326, 0.038727]); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203235, 0.026295]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203049, 0.027558]); let (self_by_other, other_by_self) = fsrs .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true) .unwrap(); - assert_approx_eq([self_by_other, other_by_self], [0.014476, 0.031874]); + assert_approx_eq([self_by_other, other_by_self], [0.016236, 0.031085]); Ok(()) } @@ -544,14 +544,14 @@ mod tests { again: ItemState { memory: MemoryState { stability: 2.969144, - difficulty: 9.520562 + difficulty: 8.000659 }, interval: 2.9691453 }, hard: ItemState { memory: MemoryState { stability: 17.091442, - difficulty: 8.4513445 + difficulty: 7.6913934 }, interval: 17.09145 }, @@ -565,7 +565,7 @@ mod tests { easy: ItemState { memory: MemoryState { stability: 71.75015, - difficulty: 6.3129106 + difficulty: 7.0728626 }, interval: 71.75018 } @@ -594,17 +594,17 @@ mod tests { let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(); assert_approx_eq( [memory_state.stability, memory_state.difficulty], - [9.999996, 7.279789], + [9.999996, 7.079161], ); let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.8).unwrap(); assert_approx_eq( [memory_state.stability, memory_state.difficulty], - [4.170096, 9.462736], + [4.170096, 9.323614], ); let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap(); assert_approx_eq( [memory_state.stability, memory_state.difficulty], - [21.712555, 2.380_21], + [21.712555, 2.174237], ); let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap(); assert_approx_eq( diff --git a/src/model.rs b/src/model.rs index f65b0b2..f70172e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -380,18 +380,13 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([ - 5.0 + 2.0 * DEFAULT_PARAMETERS[6], - 5.0 + DEFAULT_PARAMETERS[6], - 5.0, - 5.0 - DEFAULT_PARAMETERS[6] - ]) + Data::from([6.622667, 5.811333, 5.0, 4.188667]) ); let next_difficulty = model.mean_reversion(next_difficulty); next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([7.040172, 5.999995, 4.959819, 3.9196422]) + Data::from([6.607035, 5.7994337, 4.9918327, 4.1842318]) ) } @@ -412,24 +407,24 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([27.43741, 15.276875, 65.24019, 224.3506]) + Data::from([25.77614, 14.121894, 60.40441, 208.97597]) ); let s_forget = model.stability_after_failure(stability.clone(), difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([1.7390966, 2.029377, 2.433932, 2.9520853]) + Data::from([1.7028502, 1.9798818, 2.3759942, 2.8885393]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([1.7390966, 15.276875, 65.24019, 224.3506]) + Data::from([1.7028502, 14.121894, 60.40441, 208.97597]) ); let next_stability = model.stability_short_term(stability, rating); assert_eq!( next_stability.to_data(), - Data::from([2.542685, 4.2064567, 6.958895, 11.512355]) + Data::from([2.5051427, 4.199207, 7.038856, 11.798775]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 4f63715..fc381ce 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -900,7 +900,7 @@ mod tests { simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?; assert_eq!( memorized_cnt_per_day[memorized_cnt_per_day.len() - 1], - 6647.9404 + 6919.944 ); Ok(()) } @@ -1002,8 +1002,8 @@ mod tests { assert_eq!( results.1.to_vec(), vec![ - 0, 13, 23, 30, 58, 81, 83, 81, 86, 88, 88, 106, 120, 110, 122, 133, 132, 121, 131, - 143, 161, 188, 142, 179, 145, 156, 172, 191, 174, 165 + 0, 16, 25, 34, 60, 65, 76, 85, 91, 92, 100, 103, 119, 107, 103, 113, 122, 143, 149, + 151, 148, 172, 154, 175, 156, 169, 155, 191, 185, 170 ] ); assert_eq!( @@ -1020,7 +1020,7 @@ mod tests { ..Default::default() }; let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?; - assert_eq!(results.0[results.0.len() - 1], 6284.7783); + assert_eq!(results.0[results.0.len() - 1], 6591.4854); Ok(()) } @@ -1073,7 +1073,7 @@ mod tests { ..Default::default() }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8451333); + assert_eq!(optimal_retention, 0.84499365); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } @@ -1093,7 +1093,7 @@ mod tests { let mut param = DEFAULT_PARAMETERS[..17].to_vec(); param.extend_from_slice(&[0.0, 0.0]); let optimal_retention = fsrs.optimal_retention(&config, ¶m, |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.83150166); + assert_eq!(optimal_retention, 0.85450846); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 8dc78b9..eae87e5 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -303,9 +303,9 @@ mod tests { let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]); let default_s0 = DEFAULT_PARAMETERS[0] as f64; let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0); - assert_eq!(actual, 280.7497802442413); + assert_eq!(actual, 280.75007086903867); let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0); - assert_eq!(actual, 280.7494462238896); + assert_eq!(actual, 280.74973684868695); } #[test] @@ -370,6 +370,6 @@ mod tests { let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.12048356, 0.35, 0.9249894, 4.577961]); + assert_eq!(actual, [0.11901211, 0.35, 0.9380833, 4.638989]); } } From d4a8289890c0bbb58c4700a1b5520cfcdafcf6d0 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 21 Oct 2024 17:16:09 +0800 Subject: [PATCH 3/6] update results of test_loss_and_grad --- src/parameter_clipper.rs | 18 +++++++++--------- src/training.rs | 11 +++++++---- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/parameter_clipper.rs b/src/parameter_clipper.rs index 501e06e..d84c0e9 100644 --- a/src/parameter_clipper.rs +++ b/src/parameter_clipper.rs @@ -20,16 +20,16 @@ pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec { (S_MIN, INIT_S_MAX), (S_MIN, INIT_S_MAX), (1.0, 10.0), - (0.1, 4.0), - (0.1, 4.0), - (0.0, 0.75), + (0.001, 4.0), + (0.001, 4.0), + (0.001, 0.75), (0.0, 4.5), (0.0, 0.8), - (0.01, 3.5), - (0.1, 5.0), - (0.01, 0.25), - (0.01, 0.9), - (0.01, 4.0), + (0.001, 3.5), + (0.001, 5.0), + (0.001, 0.25), + (0.001, 0.9), + (0.0, 4.0), (0.0, 1.0), (1.0, 6.0), (0.0, 2.0), @@ -63,7 +63,7 @@ mod tests { assert_eq!( values, - &[0.01, 0.01, 100.0, 0.01, 10.0, 0.1, 1.0, 0.25, 0.0] + &[0.01, 0.01, 100.0, 0.01, 10.0, 0.001, 1.0, 0.25, 0.0] ); } } diff --git a/src/training.rs b/src/training.rs index e4d21be..448e708 100644 --- a/src/training.rs +++ b/src/training.rs @@ -496,16 +496,19 @@ mod tests { Reduction::Sum, ); - assert_eq!(loss.clone().into_data().convert::().value[0], 4.380769); + assert_eq!( + loss.clone().into_data().convert::().value[0], + 4.4467363 + ); let gradients = loss.backward(); let w_grad = model.w.grad(&gradients).unwrap(); dbg!(&w_grad); Data::from([ - -0.044447, -0.004000, -0.002020, 0.009756, -0.036012, 1.126084, 0.101431, -0.888184, - 0.540923, -2.830812, 0.492003, -0.008362, 0.024086, -0.077360, -0.000585, -0.135484, - 0.203740, 0.208560, 0.037535, + -0.05832, -0.00682, -0.00255, 0.010539, -0.05128, 1.364291, 0.083658, -0.95023, + 0.534472, -2.89288, 0.514163, -0.01306, 0.041905, -0.11830, -0.00092, -0.14452, + 0.202374, 0.214104, 0.032307, ]) .assert_approx_eq(&w_grad.clone().into_data(), 5); } From 613ebecf91f66cb3e89b934dc9bc9c1da38946a0 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 21 Oct 2024 17:58:56 +0800 Subject: [PATCH 4/6] dbg!(&metrics) in test of training --- src/training.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/training.rs b/src/training.rs index 448e708..1acfc27 100644 --- a/src/training.rs +++ b/src/training.rs @@ -553,8 +553,13 @@ mod tests { }); let fsrs = FSRS::new(Some(&[])).unwrap(); - let parameters = fsrs.compute_parameters(items, progress2).unwrap(); + let parameters = fsrs.compute_parameters(items.clone(), progress2).unwrap(); dbg!(¶meters); + + // evaluate + let model = FSRS::new(Some(¶meters)).unwrap(); + let metrics = model.evaluate(items, |_| true).unwrap(); + dbg!(&metrics); } } } From 39e718b83f41f2d03a6cc2d461a2e32ccc742a7d Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 22 Oct 2024 17:23:35 +0800 Subject: [PATCH 5/6] update check_and_fill_parameters formula --- src/model.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model.rs b/src/model.rs index f1a0692..5cda85c 100644 --- a/src/model.rs +++ b/src/model.rs @@ -269,6 +269,7 @@ pub(crate) fn check_and_fill_parameters(parameters: &Parameters) -> Result Date: Tue, 22 Oct 2024 17:32:10 +0800 Subject: [PATCH 6/6] cargo clippy & bump version --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/model.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 06a2e32..d5629ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "1.3.4" +version = "1.3.5" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index e5d5130..82b01a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.3.4" +version = "1.3.5" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/model.rs b/src/model.rs index 5cda85c..7e6c067 100644 --- a/src/model.rs +++ b/src/model.rs @@ -269,7 +269,7 @@ pub(crate) fn check_and_fill_parameters(parameters: &Parameters) -> Result