Skip to content

Commit

Permalink
Merge pull request #4 from twitchax/twitchax/model_fiddling
Browse files Browse the repository at this point in the history
Twitchax/model fiddling
  • Loading branch information
twitchax authored Mar 24, 2023
2 parents 0a5bfca + 6ff078c commit 2144b5c
Show file tree
Hide file tree
Showing 218 changed files with 855 additions and 369 deletions.
504 changes: 271 additions & 233 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "kord"
version = "0.5.2"
version = "0.5.3"
edition = "2021"
authors = ["Aaron Roney <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -29,7 +29,7 @@ analyze_file_alac = ["symphonia/alac", "symphonia/isomp4"]

ml = ["ml_train", "ml_infer"]
ml_base = ["serde", "byteorder", "bincode"]
ml_train = ["ml_base", "rand", "burn-autodiff", "burn/train", "burn/std", "burn-ndarray/std"]
ml_train = ["ml_base", "rand", "rayon", "burn-autodiff", "burn/train", "burn/std", "burn-ndarray/std"]
ml_infer = ["ml_base", "burn", "burn-ndarray"]
ml_gpu = ["ml_train", "burn-tch"]

Expand Down Expand Up @@ -76,10 +76,11 @@ serde = { version = "1.0.152", features = ["derive"], optional = true }
rand = { version = "0.8.4", optional = true }
byteorder = { version = "1.4.3", optional = true }
bincode = { version = "2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", default-features = false, optional = true, features = ["alloc", "serde"] }
burn = { git = "https://github.com/burn-rs/burn", default-features = false, optional = true }
burn-autodiff = { git = "https://github.com/burn-rs/burn", optional = true }
burn-tch = { git = "https://github.com/burn-rs/burn", optional = true }
burn-ndarray = { git = "https://github.com/burn-rs/burn", default-features = false, optional = true }
rayon = { version = "1.7.0", optional = true }
burn = { version = "0.6.0", default-features = false, optional = true }
burn-autodiff = { version = "0.6.0", optional = true }
burn-tch = { version = "0.6.0", optional = true }
burn-ndarray = { version = "0.6.0", default-features = false, optional = true }

# plot
plotters = { version = "0.3.4", optional = true }
Expand Down
18 changes: 11 additions & 7 deletions model/model_config.json
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
{
"source": ".hidden/samples",
"source": "samples",
"destination": "model",
"log": ".hidden/train_log",
"simulation_size": 100,
"mlp_layers": 0,
"simulation_peak_radius": 1.0,
"simulation_harmonic_decay": 0.1,
"simulation_frequency_wobble": 0.4,
"mlp_layers": 3,
"mlp_size": 1024,
"mlp_dropout": 0.3,
"mlp_dropout": 0.1,
"model_epochs": 32,
"model_batch_size": 100,
"model_workers": 32,
"model_workers": 64,
"model_seed": 76980,
"adam_learning_rate": 0.0001,
"adam_weight_decay": 0.00005,
"adam_learning_rate": 0.00001,
"adam_weight_decay": 0.0005,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"adam_epsilon": 1.1920929e-7,
"sigmoid_strength": 1.0
"sigmoid_strength": 1.0,
"no_plots": false
}
Binary file modified model/state.bincode
Binary file not shown.
Binary file modified model/state.json.gz
Binary file not shown.
Binary file added samples/A2_6939561008255810116.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/A2_C♯3_E3_135673607488100431.bin
Binary file not shown.
Binary file added samples/A3_15525224679269967542.bin
Binary file not shown.
Binary file added samples/A3_2114422690770739153.bin
Binary file not shown.
Binary file added samples/A3_C4_E4_3709226405532565024.bin
Binary file not shown.
Binary file added samples/A3_C4_E♭4_7073774578769237268.bin
Binary file not shown.
Binary file added samples/A3_C♯4_E4_13607019354188392587.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/A4_15568704343601286976.bin
Binary file not shown.
Binary file added samples/A4_4604221548488691769.bin
Binary file not shown.
Binary file added samples/A4_C♯5_E5_4202887347195583068.bin
Binary file not shown.
Binary file added samples/A4_C♯5_F5_6603647987771049246.bin
Binary file not shown.
Binary file added samples/A5_12013511172100247510.bin
Binary file not shown.
Binary file added samples/A♭3_8435545193679993968.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/A♭4_6959536689947473480.bin
Binary file not shown.
Binary file added samples/A♭4_C5_E5_18402162344830372805.bin
Binary file not shown.
Binary file added samples/A♭5_7319853225296889603.bin
Binary file not shown.
Binary file added samples/A♯2_15153254466844930854.bin
Binary file not shown.
Binary file added samples/A♯3_8223072861416917522.bin
Binary file not shown.
Binary file added samples/A♯4_18054394618689881644.bin
Binary file not shown.
Binary file added samples/B2_3266109487991583972.bin
Binary file not shown.
Binary file added samples/B2_D3_F3_2655044190809377492.bin
Binary file not shown.
Binary file added samples/B3_14415700770550177593.bin
Binary file not shown.
Binary file added samples/B3_14619477925112100779.bin
Binary file not shown.
Binary file added samples/B3_D4_F4_5360805271218901947.bin
Binary file not shown.
Binary file added samples/B3_D4_F♯4_8233293503881210985.bin
Binary file not shown.
Binary file added samples/B3_D♯4_F♯4_9421294715005877386.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/B4_10348381276932749476.bin
Binary file not shown.
Binary file added samples/B4_8469490325632819679.bin
Binary file not shown.
Binary file added samples/B4_D♯5_G5_16938390119246995197.bin
Binary file not shown.
Binary file added samples/B5_5414276455454342213.bin
Binary file not shown.
Binary file added samples/B♭3_1364161979617313941.bin
Binary file not shown.
Binary file added samples/B♭3_D4_F4_13256197258577637181.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/B♭3_D♭4_F4_839474655859676554.bin
Binary file not shown.
Binary file added samples/B♭4_17045939762980485249.bin
Binary file not shown.
Binary file added samples/B♭4_D5_F5_1353221556208210567.bin
Binary file not shown.
Binary file added samples/B♭4_D5_F♯5_4920510611447780909.bin
Binary file not shown.
Binary file added samples/C3_15324773906739422009.bin
Binary file not shown.
Binary file added samples/C3_2415525817263198274.bin
Binary file not shown.
Binary file added samples/C3_E3_G3_13402213481065715212.bin
Binary file not shown.
Binary file added samples/C3_E3_G3_1851553282790634491.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/C3_E3_G♯3_B3_9924736445809682317.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/C3_E♭3_G3_12069400533193527062.bin
Binary file not shown.
Binary file added samples/C3_E♭3_G3_3313011917094755589.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/C4_13717375330106042948.bin
Binary file not shown.
Binary file added samples/C4_17266515140012857991.bin
Binary file not shown.
Binary file added samples/C4_5200762332132119656.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/C4_D4_E4_G4_16864881080154718167.bin
Binary file not shown.
Binary file added samples/C4_D4_E4_G4_16901159761326542920.bin
Binary file not shown.
Binary file added samples/C4_D4_E4_G4_9937784476176028191.bin
Binary file not shown.
Binary file added samples/C4_E4_G4_18305297924631365159.bin
Binary file not shown.
Binary file added samples/C4_E4_G4_680358085253083288.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/C4_E4_G♯4_12195248690323345897.bin
Binary file not shown.
Binary file added samples/C5_12706616267101778521.bin
Binary file not shown.
Binary file added samples/C5_E5_G5_12896055740239505954.bin
Binary file not shown.
Binary file added samples/C7_11609981053982937465.bin
Binary file not shown.
Binary file added samples/C♯3_13825068807320527062.bin
Binary file not shown.
Binary file added samples/C♯3_8262311341105636167.bin
Binary file not shown.
Binary file added samples/C♯3_D3_14896717167203722579.bin
Binary file not shown.
Binary file added samples/C♯3_E3_G♯3_2210022060339843949.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/C♯4_14997267989228333932.bin
Binary file not shown.
Binary file added samples/C♯4_17451611629002635919.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/C♯5_13400097735607166231.bin
Binary file not shown.
Binary file added samples/D3_16196631068266624711.bin
Binary file not shown.
Binary file added samples/D3_18444562409007216198.bin
Binary file not shown.
Binary file added samples/D3_F3_A3_B3_2369237327170881937.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/D3_F3_A♭3_4400437206425957090.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/D3_F♯3_A3_15577080424539191294.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/D3_F♯3_A3_C4_892564687042373561.bin
Binary file not shown.
Binary file added samples/D3_G3_A3_C4_6498646774839009254.bin
Binary file not shown.
Binary file added samples/D4_2656459150908401698.bin
Binary file not shown.
Binary file added samples/D4_8014362495789169513.bin
Binary file not shown.
Binary file added samples/D4_F♯4_A4_9394056113901936681.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/D5_6018679753790929562.bin
Binary file not shown.
Binary file added samples/D♭3_10354072429227894497.bin
Binary file not shown.
Binary file added samples/D♭4_2186653291238513958.bin
Binary file not shown.
Binary file added samples/D♯3_6752755128074780241.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/D♯4_15621939029397038288.bin
Binary file not shown.
Binary file added samples/E3_6860504213111142350.bin
Binary file not shown.
Binary file added samples/E3_A3_B3_D4_13839322509583485144.bin
Binary file not shown.
Binary file added samples/E3_F3_17463920029253683972.bin
Binary file not shown.
Binary file added samples/E3_F3_F♯3_16972531701729065295.bin
Binary file not shown.
Binary file added samples/E3_G3_B3_17298895633196751641.bin
Binary file not shown.
Binary file added samples/E3_G3_B♭3_14330504753783189247.bin
Binary file not shown.
Binary file added samples/E3_G3_B♭3_D4_639443606553743795.bin
Binary file not shown.
Binary file added samples/E3_G♯3_B3_10913487611829576083.bin
Binary file not shown.
Binary file added samples/E3_G♯3_B3_D4_7226055498008574751.bin
Binary file not shown.
Binary file added samples/E4_15820456051683712460.bin
Binary file not shown.
Binary file added samples/E4_9329485045940018310.bin
Binary file not shown.
Binary file added samples/E4_G♯4_B4_17059164590282932723.bin
Binary file not shown.
Binary file added samples/E4_G♯4_C5_9503586855027572466.bin
Binary file not shown.
Binary file added samples/E5_5994639921338472529.bin
Binary file not shown.
Binary file added samples/E♭3_17235143537953841061.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/E♭3_G♭3_A3_9259349447861020843.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/E♭4_2060572021328877904.bin
Binary file not shown.
Binary file added samples/E♭4_G4_B4_4832356197990273626.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F3_16539520742415952783.bin
Binary file not shown.
Binary file added samples/F3_7950801072785575133.bin
Binary file not shown.
Binary file added samples/F3_A3_C4_11307864579835325574.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F3_A♭3_B3_4332497849137612450.bin
Binary file not shown.
Binary file added samples/F3_A♭3_C4_1355102359697209995.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F3_B♭3_D4_16692027902500825945.bin
Binary file not shown.
Binary file added samples/F3_B♭3_D4_17744299578108796808.bin
Binary file not shown.
Binary file added samples/F3_G3_A3_B3_5305941064774865297.bin
Binary file not shown.
Binary file added samples/F4_10904763293709692540.bin
Binary file not shown.
Binary file added samples/F4_6571238277525569326.bin
Binary file not shown.
Binary file added samples/F4_A4_C5_1248608058186282462.bin
Binary file not shown.
Binary file added samples/F4_A4_C5_9947940578599357514.bin
Binary file not shown.
Binary file added samples/F4_A4_C♯5_3042930637524323314.bin
Binary file not shown.
Binary file added samples/F4_A♭4_D♭5_4213406208658044324.bin
Binary file not shown.
Binary file added samples/F4_B♭4_D5_4624166178754337478.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F5_3747367150457817029.bin
Binary file not shown.
Binary file added samples/F5_A♭5_C6_7301784480138511550.bin
Binary file not shown.
Binary file added samples/F♯3_7477442491868692976.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F♯3_A3_C4_1591071203735525464.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/F♯4_16824356764662142236.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/F♯4_A♯4_D5_7814351117781437041.bin
Binary file not shown.
Binary file added samples/F♯5_16104568988042523506.bin
Binary file not shown.
Binary file added samples/G3_3316245545651686790.bin
Binary file not shown.
Binary file added samples/G3_A3_B3_D4_12524365298676383491.bin
Binary file not shown.
Binary file added samples/G3_B3_D4_11467983826783945055.bin
Binary file not shown.
Binary file added samples/G3_B3_D4_2112711659097138676.bin
Binary file not shown.
Binary file added samples/G3_B3_D4_F4_13338229000746797908.bin
Binary file not shown.
Binary file not shown.
Binary file added samples/G3_B♭3_D4_10606109863528597598.bin
Binary file not shown.
Binary file added samples/G3_B♭3_D4_5574413140515868152.bin
Binary file not shown.
Binary file added samples/G3_B♭3_D♭4_3572619472176912413.bin
Binary file not shown.
Binary file added samples/G3_B♭3_E♭4_5545839377034539323.bin
Binary file not shown.
Binary file added samples/G3_B♭3_E♭4_825251809988502114.bin
Binary file not shown.
Binary file added samples/G3_C4_E♭4_5834158247099421966.bin
Binary file not shown.
Binary file added samples/G4_4654945295144228095.bin
Binary file not shown.
Binary file added samples/G4_7760857318399794168.bin
Binary file not shown.
Binary file added samples/G4_B4_D5_14226150188569230476.bin
Binary file not shown.
Binary file added samples/G4_B4_D♯5_9158029766180372297.bin
Binary file not shown.
Binary file added samples/G4_G♯4_A4_12494624235026406024.bin
Binary file not shown.
Binary file added samples/G5_2281724682067383580.bin
Binary file not shown.
Binary file added samples/G♭3_17206518042906139789.bin
Binary file not shown.
Binary file added samples/G♭4_8742558092056985343.bin
Binary file not shown.
Binary file added samples/G♯3_14454117319529149317.bin
Binary file not shown.
Binary file added samples/G♯3_B3_D4_17256496026329066508.bin
Binary file not shown.
Binary file added samples/G♯4_4127244235251280031.bin
Binary file not shown.
Binary file added samples/G♯5_9131614849553198404.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added samples/quarantine/G3_13701419381069431781.bin
Binary file not shown.
Binary file added samples/quarantine/G3_753513970140615125.bin
Binary file not shown.
Binary file not shown.
Binary file not shown.
53 changes: 52 additions & 1 deletion src/analyze/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,56 @@ pub fn get_time_space(data: &[f32]) -> Vec<(f32, f32)> {
buffer.into_iter().enumerate().map(|(k, d)| (k as f32, d.abs())).collect::<Vec<_>>()
}

pub fn compute_cqt(frequency_space: &[f32]) -> Vec<f32> {
const Q_FACTOR: f32 = 24.7; // Q-factor for the CQT
const MIN_FREQ: f32 = 65.41; // minimum frequency for the CQT
const MAX_FREQ: f32 = 2093.0; // maximum frequency for the CQT
const N_BINS: usize = 60; // number of frequency bins for the CQT

let mut cqt_output = vec![vec![0.0; frequency_space.len()]; N_BINS];

let log_min_freq = MIN_FREQ.log2();
let log_max_freq = MAX_FREQ.log2();
let log_freq_step = (log_max_freq - log_min_freq) / (N_BINS as f32 - 1.0);

for i in 0..N_BINS {
let log_freq_center = log_min_freq + i as f32 * log_freq_step;
let freq_center = 2.0f32.powf(log_freq_center);
let freq_bw = freq_center / Q_FACTOR;
let fft_freq_step = 1.0;

let start_bin = (freq_center - freq_bw / 2.0) / fft_freq_step;
let end_bin = (freq_center + freq_bw / 2.0) / fft_freq_step;

let mut cqt_bin = vec![rustfft::num_complex::Complex::new(0.0, 0.0); frequency_space.len()];

for j in start_bin as usize..=end_bin as usize {
let weight = (j as f32 - freq_center / fft_freq_step) / freq_bw;
let weight = weight * std::f32::consts::PI * 2.0;
let fft_bin = frequency_space[j];
cqt_bin[j] = rustfft::num_complex::Complex::new(fft_bin * weight.sin(), 0.0);
}

let ifft = rustfft::FftPlanner::<f32>::new().plan_fft_inverse(cqt_bin.len());
ifft.process(&mut cqt_bin);

for j in 0..frequency_space.len() {
cqt_output[i][j] = cqt_bin[j].abs();
}
}

let mut result = vec![];
for k in 0..N_BINS {
let mut sum = 0.0;
for j in 0..frequency_space.len() {
sum += cqt_output[k][j];
}
result.push(sum);
}

result
}

/// Calculates the "smoothed" frequency space by normalizing to 1.0 seconds of playback.
pub fn get_smoothed_frequency_space(frequency_space: &[(f32, f32)], length_in_seconds: u8) -> Vec<(f32, f32)> {
let mut smoothed_frequency_space = Vec::new();
Expand Down Expand Up @@ -232,6 +282,7 @@ fn reduce_notes_by_harmonic_series(notes: &[(Note, f32)], cutoff: f32) -> Vec<No
/// the one before, and the next one.
///
/// Returns a vector of tuples, where the first element is the note, and the second element is the frequency window as a (low, high) tuple.
/// The first and the last note supplied are ignored, so this method returns `notes.len() - 2` elements.
pub fn get_frequency_bins(notes: &[Note]) -> Vec<(Note, (f32, f32))> {
let mut bins = Vec::new();

Expand Down Expand Up @@ -344,7 +395,7 @@ pub(crate) mod tests {

#[test]
fn test_get_frequency_bins() {
let bins = get_frequency_bins(&ALL_PITCH_NOTES.iter().skip(23).take(62).cloned().collect::<Vec<_>>());
let bins = get_frequency_bins(&ALL_PITCH_NOTES.iter().skip(24).take(62).cloned().collect::<Vec<_>>());

assert_eq!(bins.len(), 60);
}
Expand Down
94 changes: 78 additions & 16 deletions src/bin.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::path::PathBuf;

use clap::{ArgAction, Parser, Subcommand};
use klib::core::{
use klib::{core::{
base::{Parsable, Res, Void},
chord::{Chord, Chordable},
note::Note,
octave::Octave,
};
}};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -145,7 +145,7 @@ enum MlCommand {
#[cfg(feature = "ml_train")]
Train {
/// The source directory for the gathered samples.
#[arg(long, default_value = ".hidden/samples")]
#[arg(long, default_value = "samples")]
source: String,

/// The destination directory for the trained model.
Expand All @@ -156,24 +156,36 @@ enum MlCommand {
#[arg(long, default_value = ".hidden/train_log")]
log: String,

/// The device to use for training.
#[arg(long, default_value = "gpu")]
device: String,

/// Simulation data set size.
#[arg(long, default_value_t = 100)]
simulation_size: usize,

/// The device to use for training.
#[arg(long, default_value = "gpu")]
device: String,
/// Simulation peak radius.
#[arg(long, default_value_t = 1.0)]
simulation_peak_radius: f32,

/// Simulation harmonic decay.
#[arg(long, default_value_t = 0.1)]
simulation_harmonic_decay: f32,

/// Simulation frequency wobble.
#[arg(long, default_value_t = 0.4)]
simulation_frequency_wobble: f32,

/// The number of Multi Layer Perceptron (MLP) layers.
#[arg(long, default_value_t = 0)]
#[arg(long, default_value_t = 3)]
mlp_layers: usize,

/// The number of neurons in each Multi Layer Perceptron (MLP) layer.
#[arg(long, default_value_t = 1024)]
mlp_size: usize,

/// The Multi Layer Perceptron (MLP) dropout rate.
#[arg(long, default_value_t = 0.3)]
#[arg(long, default_value_t = 0.1)]
mlp_dropout: f64,

/// The number of epochs to train for.
Expand All @@ -185,19 +197,19 @@ enum MlCommand {
model_batch_size: usize,

/// The number of workers to use for training.
#[arg(long, default_value_t = 32)]
#[arg(long, default_value_t = 64)]
model_workers: usize,

/// The seed used for training.
#[arg(long, default_value_t = 76980)]
model_seed: u64,

/// The Adam optimizer learning rate.
#[arg(long, default_value_t = 1e-4)]
#[arg(long, default_value_t = 1e-5)]
adam_learning_rate: f64,

/// The Adam optimizer weight decay.
#[arg(long, default_value_t = 5e-5)]
#[arg(long, default_value_t = 5e-4)]
adam_weight_decay: f64,

/// The Adam optimizer beta1.
Expand All @@ -215,6 +227,10 @@ enum MlCommand {
/// The "sigmoid strength" of the final pass.
#[arg(long, default_value_t = 1.0)]
sigmoid_strength: f32,

/// Suppresses the training plots.
#[arg(long, action=ArgAction::SetTrue, default_value_t = false)]
no_plots: bool,
},

/// Records audio from the microphone, and using the trained model, guesses the chord.
Expand All @@ -238,6 +254,26 @@ enum MlCommand {
#[arg(long, default_value_t = 8192.0)]
x_max: f32,
},

/// Runs the ML trainer across various hyperparameters, and outputs the results.
#[cfg(feature = "ml_train")]
Hpt {
/// The source directory for the gathered samples.
#[arg(long, default_value = "samples")]
source: String,

/// The destination directory for the trained model.
#[arg(long, default_value = "model")]
destination: String,

/// The log directory for training.
#[arg(long, default_value = ".hidden/train_log")]
log: String,

/// The device to use for training.
#[arg(long, default_value = "gpu")]
device: String,
}
}

#[derive(Subcommand, Debug)]
Expand Down Expand Up @@ -360,6 +396,9 @@ fn start(args: Args) -> Void {
log,
simulation_size,
device,
simulation_peak_radius,
simulation_harmonic_decay,
simulation_frequency_wobble,
mlp_layers,
mlp_size,
mlp_dropout,
Expand All @@ -373,6 +412,7 @@ fn start(args: Args) -> Void {
adam_beta2,
adam_epsilon,
sigmoid_strength,
no_plots
}) => {
use burn_autodiff::ADBackendDecorator;
use klib::ml::base::TrainConfig;
Expand All @@ -382,6 +422,9 @@ fn start(args: Args) -> Void {
destination,
log,
simulation_size,
simulation_peak_radius,
simulation_harmonic_decay,
simulation_frequency_wobble,
mlp_layers,
mlp_size,
mlp_dropout,
Expand All @@ -395,23 +438,27 @@ fn start(args: Args) -> Void {
adam_beta2,
adam_epsilon,
sigmoid_strength,
no_plots,
};

match device.as_str() {
#[cfg(feature = "ml_gpu")]
"gpu" => {
use burn_tch::{TchBackend, TchDevice};

#[cfg(not(target_os = "macos"))]
let device = TchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = TchDevice::Mps;

klib::ml::train::run_training::<ADBackendDecorator<TchBackend<f32>>>(device, &config, true)?;
klib::ml::train::run_training::<ADBackendDecorator<TchBackend<f32>>>(device, &config, true, true)?;
}
"cpu" => {
use burn_ndarray::{NdArrayBackend, NdArrayDevice};

let device = NdArrayDevice::Cpu;

klib::ml::train::run_training::<ADBackendDecorator<NdArrayBackend<f32>>>(device, &config, true)?;
klib::ml::train::run_training::<ADBackendDecorator<NdArrayBackend<f32>>>(device, &config, true, true)?;
}
_ => {
return Err(anyhow::Error::msg("Invalid device (must choose either `gpu` [requires `ml_gpu` feature] or `cpu`)."));
Expand Down Expand Up @@ -464,11 +511,10 @@ fn start(args: Args) -> Void {
Some(MlCommand::Plot { source, x_min, x_max }) => {
use anyhow::Context;
use klib::{
analyze::base::translate_frequency_space_to_peak_space,
analyze::base::{translate_frequency_space_to_peak_space, compute_cqt},
helpers::plot_frequency_space,
ml::{
base::MEL_SPACE_SIZE,
train::helpers::{load_kord_item, mel_filter_banks_from},
base::{MEL_SPACE_SIZE, helpers::{load_kord_item, mel_filter_banks_from, harmonic_convolution}},
},
};

Expand All @@ -482,6 +528,16 @@ fn start(args: Args) -> Void {
let frequency_space = kord_item.frequency_space.into_iter().enumerate().map(|(k, v)| (k as f32, v)).collect::<Vec<_>>();
plot_frequency_space(&frequency_space, "KordItem Frequency Space", &frequency_file_name, x_min, x_max);

// Plot harmonic convolution.
let harmonic_file_name = format!("{}_harmonic", name);
let harmonic_space = harmonic_convolution(&kord_item.frequency_space).into_iter().enumerate().map(|(k, v)| (k as f32, v)).collect::<Vec<_>>();
plot_frequency_space(&harmonic_space, "KordItem Harmonic Space", &harmonic_file_name, x_min, x_max);

// Plot CQT space.
let cqt_file_name = format!("{}_cqt", name);
let cqt_space = compute_cqt(&kord_item.frequency_space).into_iter().enumerate().map(|(k, v)| (k as f32, v)).collect::<Vec<_>>();
plot_frequency_space(&cqt_space, "KordItem CQT Space", &cqt_file_name, 0.0, 256.0);

// Plot mel space.
let mel_file_name = format!("{}_mel", name);
let mel_space = mel_filter_banks_from(&kord_item.frequency_space)
Expand Down Expand Up @@ -521,6 +577,12 @@ fn start(args: Args) -> Void {
let time_space = klib::analyze::base::get_time_space(&peak_space);
plot_frequency_space(&time_space, "KordItem Time Space", &harmonic_file_name, x_min, x_max);
}
#[cfg(feature = "ml_train")]
Some(MlCommand::Hpt { source, destination, log, device }) => {
use klib::ml::train::execute::hyper_parameter_tuning;

hyper_parameter_tuning(source, destination, log, device)?;
}
None => {
return Err(anyhow::Error::msg("No subcommand given for `ml`."));
}
Expand Down
Loading

0 comments on commit 2144b5c

Please sign in to comment.