Skip to content

Commit

Permalink
Merge branch 'development' into feature/improve-display-for-naive-bayes
Browse files Browse the repository at this point in the history
  • Loading branch information
Mec-iS authored Mar 21, 2023
2 parents c40725e + d15ea43 commit 19e5115
Show file tree
Hide file tree
Showing 56 changed files with 266 additions and 268 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ jobs:
- name: Upload to codecov.io
uses: codecov/codecov-action@v2
with:
fail_ci_if_error: true
fail_ci_if_error: false
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "smartcore"
description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org"
version = "0.3.0"
version = "0.3.1"
authors = ["smartcore Developers"]
edition = "2021"
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
-----
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)

To start getting familiar with the new smartcore v0.5 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
To start getting familiar with the new smartcore v0.3 API, a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter) is now available. Please see the instructions there. Contributions are welcome; see [CONTRIBUTING](.github/CONTRIBUTING.md).
22 changes: 9 additions & 13 deletions src/algorithm/neighbour/fastpair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ mod tests_fastpair {
let distances = fastpair.distances;
let neighbours = fastpair.neighbours;

assert!(distances.len() != 0);
assert!(neighbours.len() != 0);
assert!(!distances.is_empty());
assert!(!neighbours.is_empty());

assert_eq!(10, neighbours.len());
assert_eq!(10, distances.len());
Expand All @@ -276,17 +276,13 @@ mod tests_fastpair {
// We expect an error when we run `FastPair` on this dataset,
// because `FastPair` currently only works on a minimum of 3
// points.
let _fastpair = FastPair::new(&dataset);
let fastpair = FastPair::new(&dataset);
assert!(fastpair.is_err());

match _fastpair {
Err(e) => {
let expected_error =
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
assert_eq!(e, expected_error)
}
_ => {
assert!(false);
}
if let Err(e) = fastpair {
let expected_error =
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
assert_eq!(e, expected_error)
}
}

Expand Down Expand Up @@ -582,7 +578,7 @@ mod tests_fastpair {
};
for p in dissimilarities.iter() {
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
min_dissimilarity = p.clone()
min_dissimilarity = *p
}
}

Expand Down
9 changes: 2 additions & 7 deletions src/algorithm/neighbour/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,15 @@ pub mod linear_search;
/// Both the KNN classifier and regressor benefit from underlying search algorithms that help to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
#[default]
CoverTree,
}

impl Default for KNNAlgorithmName {
fn default() -> Self {
KNNAlgorithmName::CoverTree
}
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
Expand Down
4 changes: 2 additions & 2 deletions src/cluster/dbscan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//!
//! Example:
//!
//! ```
//! ```ignore
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::Array2;
//! use smartcore::cluster::dbscan::*;
Expand Down Expand Up @@ -511,6 +511,6 @@ mod tests {
.and_then(|dbscan| dbscan.predict(&x))
.unwrap();

println!("{:?}", labels);
println!("{labels:?}");
}
}
4 changes: 2 additions & 2 deletions src/cluster/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ mod tests {

let y: Vec<usize> = kmeans.predict(&x).unwrap();

for i in 0..y.len() {
assert_eq!(y[i] as usize, kmeans._y[i]);
for (i, _y_i) in y.iter().enumerate() {
assert_eq!({ y[i] }, kmeans._y[i]);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/dataset/boston.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("boston.xy"))
{
Err(why) => panic!("Can't deserialize boston.xy. {}", why),
Err(why) => panic!("Can't deserialize boston.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};

Expand Down
2 changes: 1 addition & 1 deletion src/dataset/breast_cancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("breast_cancer.xy")) {
Err(why) => panic!("Can't deserialize breast_cancer.xy. {}", why),
Err(why) => panic!("Can't deserialize breast_cancer.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/diabetes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("diabetes.xy")) {
Err(why) => panic!("Can't deserialize diabetes.xy. {}", why),
Err(why) => panic!("Can't deserialize diabetes.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/digits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
{
Err(why) => panic!("Can't deserialize digits.xy. {}", why),
Err(why) => panic!("Can't deserialize digits.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};

Expand Down
2 changes: 1 addition & 1 deletion src/dataset/iris.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features): (Vec<f32>, Vec<u32>, usize, usize) =
match deserialize_data(std::include_bytes!("iris.xy")) {
Err(why) => panic!("Can't deserialize iris.xy. {}", why),
Err(why) => panic!("Can't deserialize iris.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
.collect();
file.write_all(&y)?;
}
Err(why) => panic!("couldn't create {}: {}", filename, why),
Err(why) => panic!("couldn't create {filename}: {why}"),
}
Ok(())
}
Expand Down
20 changes: 9 additions & 11 deletions src/decomposition/pca.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable

if parameters.n_components > n {
return Err(Failed::fit(&format!(
"Number of components, n_components should be <= number of attributes ({})",
n
"Number of components, n_components should be <= number of attributes ({n})"
)));
}

Expand Down Expand Up @@ -374,21 +373,20 @@ mod tests {
let parameters = PCASearchParameters {
n_components: vec![2, 4],
use_correlation_matrix: vec![true, false],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert_eq!(next.use_correlation_matrix, true);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert_eq!(next.use_correlation_matrix, true);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert_eq!(next.use_correlation_matrix, false);
assert!(!next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert_eq!(next.use_correlation_matrix, false);
assert!(!next.use_correlation_matrix);
assert!(iter.next().is_none());
}

Expand Down Expand Up @@ -572,8 +570,8 @@ mod tests {
epsilon = 1e-4
));

for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}

let us_arrests_t = pca.transform(&us_arrests).unwrap();
Expand Down Expand Up @@ -694,8 +692,8 @@ mod tests {
epsilon = 1e-4
));

for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}

let us_arrests_t = pca.transform(&us_arrests).unwrap();
Expand Down
7 changes: 2 additions & 5 deletions src/decomposition/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable

if parameters.n_components >= p {
return Err(Failed::fit(&format!(
"Number of components, n_components should be < number of attributes ({})",
p
"Number of components, n_components should be < number of attributes ({p})"
)));
}

Expand All @@ -202,8 +201,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
let (p_c, k) = self.components.shape();
if p_c != p {
return Err(Failed::transform(&format!(
"Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions",
n, p, n, k
"Can not transform a {n}x{p} matrix into {n}x{k} matrix, incorrect input dimentions"
)));
}

Expand All @@ -227,7 +225,6 @@ mod tests {
fn search_parameters() {
let parameters = SVDSearchParameters {
n_components: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
Expand Down
30 changes: 29 additions & 1 deletion src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,12 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
y: &Y,
parameters: RandomForestClassifierParameters,
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
let (_, num_attributes) = x.shape();
let (x_nrows, num_attributes) = x.shape();
let y_ncols = y.shape();
if x_nrows != y_ncols {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}

let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y.unique();

Expand Down Expand Up @@ -678,6 +682,30 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
}

#[test]
fn test_random_matrix_with_wrong_rownum() {
    // `x` is generated with `rand(21, 200)` while `y` holds only 20 labels,
    // so the two inputs disagree on the number of samples.
    let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
    let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];

    let parameters = RandomForestClassifierParameters {
        criterion: SplitCriterion::Gini,
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 100,
        m: None,
        keep_samples: false,
        seed: 87,
    };

    // Fitting on mismatched inputs is expected to report an error.
    let outcome = RandomForestClassifier::fit(&x_rand, &y, parameters);
    assert!(outcome.is_err());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
30 changes: 30 additions & 0 deletions src/ensemble/random_forest_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();

if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}

let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
Expand Down Expand Up @@ -595,6 +599,32 @@ mod tests {
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}

#[test]
fn test_random_matrix_with_wrong_rownum() {
    // `x` is generated with `rand(17, 200)` while `y` holds only 16 targets,
    // so the two inputs disagree on the number of samples.
    let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
    let y = vec![
        83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
        114.2, 115.7, 116.9,
    ];

    let parameters = RandomForestRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 1000,
        m: None,
        keep_samples: false,
        seed: 87,
    };

    // Fitting on mismatched inputs is expected to report an error.
    let outcome = RandomForestRegressor::fit(&x_rand, &y, parameters);
    assert!(outcome.is_err());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
4 changes: 2 additions & 2 deletions src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub enum FailedError {
DecompositionFailed,
/// Can't solve for x
SolutionFailed,
/// Erro in input
/// Error in input parameters
ParametersError,
}

Expand Down Expand Up @@ -98,7 +98,7 @@ impl fmt::Display for FailedError {
FailedError::SolutionFailed => "Can't find solution",
FailedError::ParametersError => "Error in input, check parameters",
};
write!(f, "{}", failed_err_str)
write!(f, "{failed_err_str}")
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
clippy::too_many_arguments,
clippy::many_single_char_names,
clippy::unnecessary_wraps,
clippy::upper_case_acronyms
clippy::upper_case_acronyms,
clippy::approx_constant
)]
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]
Expand Down
Loading

0 comments on commit 19e5115

Please sign in to comment.