forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.rs
101 lines (86 loc) · 3.29 KB
/
training.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use crate::{data::MnistBatcher, model::Model};
use burn::{
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
optim::{decay::WeightDecayConfig, AdamConfig},
prelude::*,
record::{CompactRecorder, NoStdTrainingRecorder},
tensor::backend::AutodiffBackend,
train::{
metric::{
store::{Aggregate, Direction, Split},
AccuracyMetric, CpuMemory, CpuTemperature, CpuUse, LossMetric,
},
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
},
};
/// Directory where all training artifacts are written: the serialized
/// config, epoch checkpoints, and the final trained model weights.
static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";

/// Hyperparameters for the MNIST training run.
///
/// `#[derive(Config)]` (burn) generates a `new(optimizer)` constructor that
/// fills every `#[config(default = ...)]` field with its default, plus
/// `save`/`load` helpers used at the end of `run`.
#[derive(Config)]
pub struct MnistTrainingConfig {
    // Number of passes over the training set.
    #[config(default = 10)]
    pub num_epochs: usize,
    // Samples per mini-batch for both training and validation loaders.
    #[config(default = 64)]
    pub batch_size: usize,
    // Worker threads used by each data loader.
    #[config(default = 4)]
    pub num_workers: usize,
    // Seed for both the backend RNG and data-loader shuffling,
    // making runs reproducible.
    #[config(default = 42)]
    pub seed: u64,
    // No default: the optimizer config must be supplied by the caller.
    pub optimizer: AdamConfig,
}
/// Reset the artifact directory so every run starts from a clean slate and
/// the learner summary reflects only the current run.
fn create_artifact_dir(artifact_dir: &str) {
    // Best-effort removal: the directory may simply not exist on a first run,
    // so an error here is expected and safe to ignore.
    std::fs::remove_dir_all(artifact_dir).ok();
    // Creation failure is fatal: every later save into this directory
    // (config.json, checkpoints, model weights) would fail anyway, so fail
    // loudly here instead of silently swallowing the error.
    std::fs::create_dir_all(artifact_dir)
        .unwrap_or_else(|e| panic!("failed to create artifact dir '{artifact_dir}': {e}"));
}
/// Train the MNIST model on `device` and persist the run's artifacts
/// (config, checkpoints, trained weights) under `ARTIFACT_DIR`.
pub fn run<B: AutodiffBackend>(device: B::Device) {
    create_artifact_dir(ARTIFACT_DIR);

    // Configuration: Adam with a small weight decay; everything else comes
    // from the config defaults. Seed the backend for reproducibility.
    let optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));
    let config = MnistTrainingConfig::new(optimizer);
    B::seed(config.seed);

    // Data: the training batcher runs on the autodiff backend, while the
    // validation batcher uses the inner (non-autodiff) backend.
    let train_batcher = MnistBatcher::<B>::new(device.clone());
    let valid_batcher = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let train_loader = DataLoaderBuilder::new(train_batcher)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let test_loader = DataLoaderBuilder::new(valid_batcher)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    // Learner: track accuracy, CPU resource metrics, and loss on both splits;
    // checkpoint each epoch and stop early once the mean validation loss
    // fails to improve for an epoch.
    let learner = LearnerBuilder::new(ARTIFACT_DIR)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(CpuUse::new())
        .metric_valid_numeric(CpuUse::new())
        .metric_train_numeric(CpuMemory::new())
        .metric_valid_numeric(CpuMemory::new())
        .metric_train_numeric(CpuTemperature::new())
        .metric_valid_numeric(CpuTemperature::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Valid,
            StoppingCondition::NoImprovementSince { n_epochs: 1 },
        ))
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(Model::new(&device), config.optimizer.init(), 1e-4);

    let trained = learner.fit(train_loader, test_loader);

    // Persist the configuration and the trained weights next to the
    // checkpoints so the run can be inspected or reloaded later.
    config
        .save(format!("{ARTIFACT_DIR}/config.json").as_str())
        .unwrap();
    trained
        .save_file(
            format!("{ARTIFACT_DIR}/model"),
            &NoStdTrainingRecorder::new(),
        )
        .expect("Failed to save trained model");
}