-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathmain.rs
79 lines (72 loc) · 2.27 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
extern crate csv;
extern crate itertools;
extern crate lightgbm;
extern crate serde_json;
use itertools::zip;
use lightgbm::{Booster, Dataset};
use serde_json::json;
/// Loads a headerless, tab-separated data file into feature rows and labels.
///
/// The first field of every record is parsed as the `f32` label; each
/// remaining field is parsed as an `f64` feature value.
///
/// Returns `(features, labels)`, where `features[i]` and `labels[i]` come
/// from the same record.
///
/// # Panics
///
/// Panics if the file cannot be opened, a record cannot be read, a record has
/// no fields, or any field fails to parse as a number.
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(false)
        .delimiter(b'\t')
        .from_path(file_path)
        .unwrap_or_else(|e| panic!("failed to open {}: {}", file_path, e));

    let mut labels: Vec<f32> = Vec::new();
    let mut features: Vec<Vec<f64>> = Vec::new();

    for result in reader.records() {
        let record =
            result.unwrap_or_else(|e| panic!("failed to read record from {}: {}", file_path, e));

        // Walk the fields once: the first is the label, the rest are features.
        // This avoids parsing the label twice and copying the row twice.
        let mut fields = record.iter();
        let label = fields
            .next()
            .unwrap_or_else(|| panic!("empty record in {}", file_path))
            .parse::<f32>()
            .unwrap_or_else(|e| panic!("invalid label in {}: {}", file_path, e));
        let feature: Vec<f64> = fields
            .map(|x| {
                x.parse::<f64>()
                    .unwrap_or_else(|e| panic!("invalid feature {:?} in {}: {}", x, file_path, e))
            })
            .collect();

        labels.push(label);
        features.push(feature);
    }

    (features, labels)
}
/// Returns the index of the largest element of `xs`.
///
/// When several elements tie for the maximum, the lowest index wins. An
/// element that compares neither greater than nor equal to the current
/// maximum (such as a floating-point NaN) never replaces it.
///
/// # Panics
///
/// Panics if `xs` is empty, since there is no index to return.
fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    assert!(!xs.is_empty(), "argmax requires a non-empty slice");

    // Only a strict `>` moves the index, so the first maximum is kept on ties
    // without having to collect every tied position.
    let mut best = 0;
    for (i, x) in xs.iter().enumerate().skip(1) {
        if x > &xs[best] {
            best = i;
        }
    }
    best
}
/// Trains a 5-class LightGBM model on the bundled multiclass example data and
/// prints the predictions for the test set.
///
/// For each test row it prints the true label, the predicted class (the index
/// of the highest score) and the raw score vector, followed by a final
/// `correct / total` line.
///
/// # Panics
///
/// Panics if the data files cannot be loaded or if dataset construction,
/// training or prediction fails.
fn main() -> std::io::Result<()> {
    let (train_features, train_labels) = load_file(
        "../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.train",
    );
    let (test_features, test_labels) =
        load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.test");

    let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();

    let params = json! {
        {
            "num_iterations": 100,
            "objective": "multiclass",
            "metric": "multi_logloss",
            "num_class": 5,
        }
    };

    let booster = Booster::train(train_dataset, &params).unwrap();
    let result = booster.predict(test_features).unwrap();

    // Count the rows whose highest-scoring class equals the true label.
    let mut tp = 0;
    for (label, pred) in zip(&test_labels, &result) {
        let argmax_pred = argmax(pred);
        // Labels are stored as floats, so compare against the class index as f32.
        if *label == argmax_pred as f32 {
            tp += 1;
        }
        println!("{}, {}, {:?}", label, argmax_pred, pred);
    }
    println!("{} / {}", tp, result.len());
    Ok(())
}