Feat/linear damping #239

Merged: 7 commits, Oct 22, 2024

2 changes: 1 addition & 1 deletion Cargo.lock

(Generated file; diff not rendered.)

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.3.4"
version = "1.3.5"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
24 changes: 12 additions & 12 deletions src/inference.rs
@@ -22,8 +22,8 @@ pub type Parameters = [f32];
use itertools::izip;

pub static DEFAULT_PARAMETERS: [f32; 19] = [
- 0.4072, 1.1829, 3.1262, 15.4722, 7.2102, 0.5316, 1.0651, 0.0234, 1.616, 0.1544, 1.0824, 1.9813,
- 0.0953, 0.2975, 2.2042, 0.2407, 2.9466, 0.5034, 0.6567,
+ 0.40255, 1.18385, 3.173, 15.69105, 7.1949, 0.5345, 1.4604, 0.0046, 1.54575, 0.1192, 1.01925,
+ 1.9395, 0.11, 0.29605, 2.2698, 0.2315, 2.9898, 0.51655, 0.6621,
];

fn infer<B: Backend>(
@@ -493,23 +493,23 @@ mod tests {
]))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

- assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.210983, 0.037216]);
+ assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.211007, 0.037216]);

let fsrs = FSRS::new(Some(&[]))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

- assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217689, 0.039710]);
+ assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216326, 0.038727]);

let fsrs = FSRS::new(Some(PARAMETERS))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

- assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203235, 0.026295]);
+ assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203049, 0.027558]);

let (self_by_other, other_by_self) = fsrs
.universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)
.unwrap();

- assert_approx_eq([self_by_other, other_by_self], [0.014476, 0.031874]);
+ assert_approx_eq([self_by_other, other_by_self], [0.016236, 0.031085]);

Ok(())
}
@@ -544,14 +544,14 @@ mod tests {
again: ItemState {
memory: MemoryState {
stability: 2.969144,
- difficulty: 9.520562
+ difficulty: 8.000659
},
interval: 2.9691453
},
hard: ItemState {
memory: MemoryState {
stability: 17.091442,
- difficulty: 8.4513445
+ difficulty: 7.6913934
},
interval: 17.09145
},
@@ -565,7 +565,7 @@ mod tests {
easy: ItemState {
memory: MemoryState {
stability: 71.75015,
- difficulty: 6.3129106
+ difficulty: 7.0728626
},
interval: 71.75018
}
@@ -594,17 +594,17 @@ mod tests {
let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap();
assert_approx_eq(
[memory_state.stability, memory_state.difficulty],
- [9.999996, 7.279789],
+ [9.999996, 7.079161],
);
let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.8).unwrap();
assert_approx_eq(
[memory_state.stability, memory_state.difficulty],
- [4.170096, 9.462736],
+ [4.170096, 9.323614],
);
let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap();
assert_approx_eq(
[memory_state.stability, memory_state.difficulty],
- [21.712555, 2.380_21],
+ [21.712555, 2.174237],
);
let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap();
assert_approx_eq(
27 changes: 14 additions & 13 deletions src/model.rs
@@ -116,8 +116,13 @@ impl<B: Backend> Model<B> {
self.w.get(4) - (self.w.get(5) * (rating - 1)).exp() + 1
}

+ fn linear_damping(&self, delta_d: Tensor<B, 1>, old_d: Tensor<B, 1>) -> Tensor<B, 1> {
+ old_d.neg().add_scalar(10.0) * delta_d.div_scalar(9.0)
+ }

fn next_difficulty(&self, difficulty: Tensor<B, 1>, rating: Tensor<B, 1>) -> Tensor<B, 1> {
- difficulty - self.w.get(6) * (rating - 3)
+ let delta_d = -self.w.get(6) * (rating - 3);
+ difficulty.clone() + self.linear_damping(delta_d, difficulty)
}

pub(crate) fn step(
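In effect, the raw shift delta_d = -w[6] * (rating - 3) is scaled by (10 - D) / 9, so difficulty moves less the closer it already is to the ceiling of 10. A scalar sketch of the same update, before the separate mean-reversion step (the tensor code above is the actual implementation; 1.4604 is w[6] from the new DEFAULT_PARAMETERS):

```rust
// Scalar restatement of next_difficulty above, before mean reversion is applied.
fn next_difficulty_sketch(old_d: f32, rating: f32, w6: f32) -> f32 {
    let delta_d = -w6 * (rating - 3.0);          // raw shift: positive for Again/Hard
    let damped = (10.0 - old_d) / 9.0 * delta_d; // linear damping: shrinks as old_d nears 10
    old_d + damped
}

// With w6 = 1.4604 and old_d = 5.0, ratings 1..=4 give 6.622667, 5.811333, 5.0 and
// 4.188667, matching the updated next_difficulty test further down in this file.
```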
@@ -264,6 +269,7 @@ pub(crate) fn check_and_fill_parameters(parameters: &Parameters) -> Result<Vec<f
let mut parameters = parameters.to_vec();
parameters[4] = parameters[5].mul_add(2.0, parameters[4]);
parameters[5] = parameters[5].mul_add(3.0, 1.0).ln() / 3.0;
+ parameters[6] += 0.5;
parameters.extend_from_slice(&[0.0, 0.0]);
parameters
}
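For older 17-value parameter sets, the fill-in now also adds 0.5 to w[6]; the conversion test below shows 0.86 becoming 1.36. A minimal sketch restating the whole fill-in step shown in this hunk (nothing here is new behaviour):

```rust
// Sketch restating the fill-in above for an older parameter set: rescale w[4] and
// w[5], shift w[6] by 0.5 (0.86 -> 1.36 in the test below), and pad to 19 values.
fn fill_in_sketch(mut w: Vec<f32>) -> Vec<f32> {
    w[4] = w[5].mul_add(2.0, w[4]);           // w4' = w4 + 2 * w5
    w[5] = w[5].mul_add(3.0, 1.0).ln() / 3.0; // w5' = ln(3 * w5 + 1) / 3
    w[6] += 0.5;                              // new in this PR
    w.extend_from_slice(&[0.0, 0.0]);
    w
}
```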
@@ -298,7 +304,7 @@ mod tests {
assert_eq!(
fsrs5_param,
vec![
- 0.4, 0.6, 2.4, 5.8, 6.81, 0.44675013, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05,
+ 0.4, 0.6, 2.4, 5.8, 6.81, 0.44675013, 1.36, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05,
0.34, 1.26, 0.29, 2.61, 0.0, 0.0,
]
)
@@ -387,18 +393,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
- Data::from([
- 5.0 + 2.0 * DEFAULT_PARAMETERS[6],
- 5.0 + DEFAULT_PARAMETERS[6],
- 5.0,
- 5.0 - DEFAULT_PARAMETERS[6]
- ])
+ Data::from([6.622667, 5.811333, 5.0, 4.188667])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
- Data::from([7.040172, 5.999995, 4.959819, 3.9196422])
+ Data::from([6.607035, 5.7994337, 4.9918327, 4.1842318])
)
}

@@ -419,24 +420,24 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
- Data::from([27.43741, 15.276875, 65.24019, 224.3506])
+ Data::from([25.77614, 14.121894, 60.40441, 208.97597])
);
let s_forget = model.stability_after_failure(stability.clone(), difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
- Data::from([1.7390966, 2.029377, 2.433932, 2.9520853])
+ Data::from([1.7028502, 1.9798818, 2.3759942, 2.8885393])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
- Data::from([1.7390966, 15.276875, 65.24019, 224.3506])
+ Data::from([1.7028502, 14.121894, 60.40441, 208.97597])
);
let next_stability = model.stability_short_term(stability, rating);
assert_eq!(
next_stability.to_data(),
- Data::from([2.542685, 4.2064567, 6.958895, 11.512355])
+ Data::from([2.5051427, 4.199207, 7.038856, 11.798775])
)
}

19 changes: 12 additions & 7 deletions src/optimal_retention.rs
@@ -97,8 +97,13 @@ fn init_d_with_short_term(w: &[f32], rating: usize, rating_offset: f32) -> f32 {
new_d.clamp(1.0, 10.0)
}

+ fn linear_damping(delta_d: f32, old_d: f32) -> f32 {
+ (10.0 - old_d) / 9.0 * delta_d
+ }

fn next_d(w: &[f32], d: f32, rating: usize) -> f32 {
- let new_d = d - w[6] * (rating as f32 - 3.0);
+ let delta_d = -w[6] * (rating as f32 - 3.0);
+ let new_d = d + linear_damping(delta_d, d);
mean_reversion(w, init_d(w, 4), new_d).clamp(1.0, 10.0)
}
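The simulator's scalar path then pulls the damped result toward init_d(w, 4) and clamps it. A hedged sketch of the full update follows; mean_reversion and init_d are not part of this diff, so the w[7]-weighted reversion and the init_d formula below are assumptions mirroring model.rs, and the numbers use the new DEFAULT_PARAMETERS:

```rust
// Assumed scalar versions of init_d and mean_reversion, mirroring model.rs;
// next_d above is the actual implementation.
fn init_d_sketch(w: &[f32], rating: usize) -> f32 {
    w[4] - (w[5] * (rating as f32 - 1.0)).exp() + 1.0
}

fn next_d_sketch(w: &[f32], d: f32, rating: usize) -> f32 {
    let delta_d = -w[6] * (rating as f32 - 3.0);
    let new_d = d + (10.0 - d) / 9.0 * delta_d;              // linear damping
    let target = init_d_sketch(w, 4);                        // ~3.22 with the new defaults
    (w[7] * target + (1.0 - w[7]) * new_d).clamp(1.0, 10.0)  // assumed mean reversion
}

// For d = 5.0 and rating = 1 (Again): new_d ~ 6.622667 and the reverted value
// ~ 6.607035, consistent with the next_difficulty/mean_reversion test in model.rs.
```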

@@ -895,7 +900,7 @@ mod tests {
simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
assert_eq!(
memorized_cnt_per_day[memorized_cnt_per_day.len() - 1],
- 6647.9404
+ 6919.944
);
Ok(())
}
@@ -997,8 +1002,8 @@ mod tests {
assert_eq!(
results.1.to_vec(),
vec![
- 0, 13, 23, 30, 58, 81, 83, 81, 86, 88, 88, 106, 120, 110, 122, 133, 132, 121, 131,
- 143, 161, 188, 142, 179, 145, 156, 172, 191, 174, 165
+ 0, 16, 25, 34, 60, 65, 76, 85, 91, 92, 100, 103, 119, 107, 103, 113, 122, 143, 149,
+ 151, 148, 172, 154, 175, 156, 169, 155, 191, 185, 170
]
);
assert_eq!(
@@ -1015,7 +1020,7 @@ mod tests {
..Default::default()
};
let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
- assert_eq!(results.0[results.0.len() - 1], 6284.7783);
+ assert_eq!(results.0[results.0.len() - 1], 6591.4854);
Ok(())
}

@@ -1068,7 +1073,7 @@ mod tests {
..Default::default()
};
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
- assert_eq!(optimal_retention, 0.8451333);
+ assert_eq!(optimal_retention, 0.84499365);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
@@ -1088,7 +1093,7 @@ mod tests {
let mut param = DEFAULT_PARAMETERS[..17].to_vec();
param.extend_from_slice(&[0.0, 0.0]);
let optimal_retention = fsrs.optimal_retention(&config, &param, |_v| true).unwrap();
- assert_eq!(optimal_retention, 0.83150166);
+ assert_eq!(optimal_retention, 0.85450846);
Ok(())
}

18 changes: 9 additions & 9 deletions src/parameter_clipper.rs
@@ -20,16 +20,16 @@ pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
(S_MIN, INIT_S_MAX),
(S_MIN, INIT_S_MAX),
(1.0, 10.0),
- (0.1, 4.0),
- (0.1, 4.0),
- (0.0, 0.75),
+ (0.001, 4.0),
+ (0.001, 4.0),
+ (0.001, 0.75),
(0.0, 4.5),
(0.0, 0.8),
- (0.01, 3.5),
- (0.1, 5.0),
- (0.01, 0.25),
- (0.01, 0.9),
- (0.01, 4.0),
+ (0.001, 3.5),
+ (0.001, 5.0),
+ (0.001, 0.25),
+ (0.001, 0.9),
+ (0.0, 4.0),
(0.0, 1.0),
(1.0, 6.0),
(0.0, 2.0),
Expand Down Expand Up @@ -63,7 +63,7 @@ mod tests {

assert_eq!(
values,
- &[0.01, 0.01, 100.0, 0.01, 10.0, 0.1, 1.0, 0.25, 0.0]
+ &[0.01, 0.01, 100.0, 0.01, 10.0, 0.001, 1.0, 0.25, 0.0]
);
}
}
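The widened or lowered bounds above are applied pairwise, one (low, high) pair per parameter; with the new floors, a w[5] below 0.001 is raised to 0.001, which is what the updated test expects. A minimal sketch of that clamping (clip_parameters in this file is the actual implementation):

```rust
// Sketch of how each parameter is clamped to its (low, high) pair from the table above.
fn clip_sketch(parameters: &[f32], bounds: &[(f32, f32)]) -> Vec<f32> {
    parameters
        .iter()
        .zip(bounds)
        .map(|(&w, &(low, high))| w.clamp(low, high))
        .collect()
}
```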
6 changes: 3 additions & 3 deletions src/pre_training.rs
@@ -303,9 +303,9 @@ mod tests {
let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]);
let default_s0 = DEFAULT_PARAMETERS[0] as f64;
let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0);
- assert_eq!(actual, 280.7497802442413);
+ assert_eq!(actual, 280.75007086903867);
let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0);
- assert_eq!(actual, 280.7494462238896);
+ assert_eq!(actual, 280.74973684868695);
}

#[test]
@@ -370,6 +370,6 @@ mod tests {
let mut rating_stability = HashMap::from([(2, 0.35)]);
let rating_count = HashMap::from([(2, 1)]);
let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap();
- assert_eq!(actual, [0.12048356, 0.35, 0.9249894, 4.577961]);
+ assert_eq!(actual, [0.11901211, 0.35, 0.9380833, 4.638989]);
}
}
18 changes: 13 additions & 5 deletions src/training.rs
@@ -496,16 +496,19 @@ mod tests {
Reduction::Sum,
);

- assert_eq!(loss.clone().into_data().convert::<f32>().value[0], 4.380769);
+ assert_eq!(
+ loss.clone().into_data().convert::<f32>().value[0],
+ 4.4467363
+ );
let gradients = loss.backward();

let w_grad = model.w.grad(&gradients).unwrap();
dbg!(&w_grad);

Data::from([
- -0.044447, -0.004000, -0.002020, 0.009756, -0.036012, 1.126084, 0.101431, -0.888184,
- 0.540923, -2.830812, 0.492003, -0.008362, 0.024086, -0.077360, -0.000585, -0.135484,
- 0.203740, 0.208560, 0.037535,
+ -0.05832, -0.00682, -0.00255, 0.010539, -0.05128, 1.364291, 0.083658, -0.95023,
+ 0.534472, -2.89288, 0.514163, -0.01306, 0.041905, -0.11830, -0.00092, -0.14452,
+ 0.202374, 0.214104, 0.032307,
])
.assert_approx_eq(&w_grad.clone().into_data(), 5);
}
@@ -550,8 +553,13 @@ mod tests {
});

let fsrs = FSRS::new(Some(&[])).unwrap();
- let parameters = fsrs.compute_parameters(items, progress2).unwrap();
+ let parameters = fsrs.compute_parameters(items.clone(), progress2).unwrap();
dbg!(&parameters);

+ // evaluate
+ let model = FSRS::new(Some(&parameters)).unwrap();
+ let metrics = model.evaluate(items, |_| true).unwrap();
+ dbg!(&metrics);
}
}
}