From 57a20e8289c48916e9e9dd6e9a0289007dc8c922 Mon Sep 17 00:00:00 2001 From: John Lees Date: Thu, 1 Jun 2023 16:28:55 +0100 Subject: [PATCH 1/8] Converting R optimisation code into rust --- Cargo.toml | 8 ++++- src/cli.rs | 22 ++++++++++++ src/coverage.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 16 +++++++++ 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 src/coverage.rs diff --git a/Cargo.toml b/Cargo.toml index 73209de..00f87ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,13 @@ [package] name = "ska" -version = "0.2.5" +version = "0.3.0" authors = [ "John Lees ", "Simon Harris ", "Johanna von Wachsmann ", + "Tommi Maklin ", + "Joel Hellewell ", + "Timothy Russell " ] edition = "2021" description = "Split k-mer analysis" @@ -44,6 +47,9 @@ hashbrown = "0.12" ahash = "0.8.2" ndarray = { version = "0.15.6", features = ["serde"] } num-traits = "0.2.15" +# coverage model +libm = "0.2.7" +cached = "0.43.0" [dev-dependencies] # testing diff --git a/src/cli.rs b/src/cli.rs index 260bd3a..afb46c7 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -249,6 +249,28 @@ pub enum Commands { #[arg(long, default_value_t = false)] full_info: bool, }, + /// Estimate a coverage cutoff using a k-mer count profile (FASTQ only) + Cov { + /// FASTQ file (or .fastq.gz) with forward reads + #[arg(short)] + fastq_fwd: String, + + /// FASTQ file (or .fastq.gz) with reverse reads + #[arg(short)] + fastq_rev: String, + + /// K-mer size + #[arg(short, value_parser = valid_kmer, default_value_t = DEFAULT_KMER)] + k: usize, + + /// Ignore reverse complement (all contigs are oriented along same strand) + #[arg(long, default_value_t = DEFAULT_STRAND)] + single_strand: bool, + + /// Number of CPU threads + #[arg(long, value_parser = valid_cpus, default_value_t = 1)] + threads: usize, + } } /// Function to parse command line args into [`Args`] struct diff --git a/src/coverage.rs b/src/coverage.rs new file mode 100644 index 0000000..e669129 --- /dev/null +++ b/src/coverage.rs @@ -0,0 +1,93 @@ + +use hashbrown::HashMap; +extern crate needletail; +use ndarray::Array1; +use needletail::{parse_fastx_file, parser::Format}; +use libm::lgamma; +use cached::proc_macro::cached; + +use crate::ska_dict::bit_encoding::UInt; + +pub struct CoverageHistogram { + /// Dictionary of k-mers + kmer_counts: HashMap, + /// Count histogram + counts: Vec, + /// K-mer size + k: usize, + /// Whether reverse complement split k-mers were used + rc: bool, +} + +impl CoverageHistogram +where + IntT: for<'a> UInt<'a>, +{ + // e.g. 
f = |x: &Array1| log_likelihood(x, self.counts); + +} + +// Called by lib.rs +pub fn coverage_cutoff UInt<'a>>(fastq1: &String, fastq2: &String, k: usize, rc: bool, threads: usize) { + +} + +// log-sum-exp +fn lse(a: f64, b: f64) -> f64 { + let xstar = f64::max(a, b); + xstar + f64::ln(f64::exp(a - xstar) + f64::exp(b - xstar)) +} + +// TODO: 64 needs to be cached as bytes + +// Natural log of Poisson density +#[cached] +fn ln_dpois(x: u64, lambda: [u8; 8]) -> f64 { + let lambda_f = f64::from_le_bytes(lambda); + x as f64 * (lambda_f) - lgamma(x as f64 + 1.0) - lambda_f +} + +// error component +#[cached] +fn a(w0: f64, i: f64) -> f64 { + f64::ln(w0) + ln_dpois(i, 1.0) +} + +// coverage component +#[cached] +fn b(w0: f64, c: f64, i: f64) -> f64 { + f64::ln(1.0 - w0) + ln_dpois(i, c) +} + +// Mixture likelihood +fn log_likelihood(pars: &Array1, counts: &[f64]) -> f64 { + let w0 = pars[0]; + let c = pars[1]; + let mut ll = 0.0; + if w0 > 1.0 || w0 < 0.0 || c < 1.0 { + ll = f64::NEG_INFINITY; + } else { + for (i, c) in counts.iter().enumerate() { + ll += counts[i] * lse(a(w0, i as f64), b(w0, *c, i as f64)); + } + } + ll +} + +fn grad_ll(pars: &Array1, counts: &[f64]) -> Array1 { + let w0 = pars[0]; + let c = pars[1]; + + let mut grad_w0 = 0.0; + let mut grad_c = 0.0; + for (i, c) in counts.iter().enumerate() { + let i_f64 = i as f64; + let a_val = a(w0, i_f64); + let b_val = b(w0, *c, i_f64); + let dlda = 1.0 / (1.0 + f64::exp(b_val - a_val)); + let dldb = 1.0 / (1.0 + f64::exp(a_val - b_val)); + grad_w0 += counts[i] * (dlda/w0 - dldb/(1.0 - w0)); + grad_c += counts[i] * (dldb*(i_f64/c - 1.0)); + } + Array1::from_vec(vec![grad_w0, grad_c]) +} diff --git a/src/lib.rs b/src/lib.rs index 83bb5f0..76a26ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -321,6 +321,9 @@ use crate::cli::*; pub mod io_utils; use crate::io_utils::*; +pub mod coverage; +use crate::coverage::*; + /// Possible quality score filters when building with reads #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] pub enum QualFilter { @@ -545,6 +548,19 @@ pub fn main() { panic!("Could not read input file: {skf_file}"); } } + Commands::Cov { fastq_fwd, fastq_rev, k, single_strand, threads } => { + check_threads(*threads); + + // Build, merge + let rc = !*single_strand; + if *k <= 31 { + log::info!("k={}: using 64-bit representation", *k); + coverage_cutoff::(fastq_fwd, fastq_rev, *k, rc, *threads); + } else { + log::info!("k={}: using 128-bit representation", *k); + coverage_cutoff::(fastq_fwd, fastq_rev, *k, rc, *threads); + } + } } let end = Instant::now(); From 3d01766dcde564c7bf422f3d84e5637f330ef154 Mon Sep 17 00:00:00 2001 From: John Lees Date: Fri, 2 Jun 2023 18:20:49 +0100 Subject: [PATCH 2/8] Working version of optimiser --- Cargo.toml | 3 +- src/cli.rs | 10 +-- src/coverage.rs | 224 +++++++++++++++++++++++++++++++++++++++++------- src/lib.rs | 21 +++-- 4 files changed, 214 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 00f87ed..ccaf7b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,8 @@ ndarray = { version = "0.15.6", features = ["serde"] } num-traits = "0.2.15" # coverage model libm = "0.2.7" -cached = "0.43.0" +argmin = { version = "0.8.1", features = ["slog-logger"] } +argmin-math = "0.3.0" [dev-dependencies] # testing diff --git a/src/cli.rs b/src/cli.rs index afb46c7..979ecd0 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -252,25 +252,19 @@ pub enum Commands { /// Estimate a coverage cutoff using a k-mer count profile (FASTQ only) Cov { /// FASTQ file (or 
.fastq.gz) with forward reads - #[arg(short)] fastq_fwd: String, /// FASTQ file (or .fastq.gz) with reverse reads - #[arg(short)] fastq_rev: String, /// K-mer size #[arg(short, value_parser = valid_kmer, default_value_t = DEFAULT_KMER)] k: usize, - /// Ignore reverse complement (all contigs are oriented along same strand) + /// Ignore reverse complement (all reads are oriented along same strand) #[arg(long, default_value_t = DEFAULT_STRAND)] single_strand: bool, - - /// Number of CPU threads - #[arg(long, value_parser = valid_cpus, default_value_t = 1)] - threads: usize, - } + }, } /// Function to parse command line args into [`Args`] struct diff --git a/src/coverage.rs b/src/coverage.rs index e669129..8f5fdb2 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -1,35 +1,207 @@ +use core::panic; + +use argmin::core::observers::{ObserverMode, SlogLogger}; +use argmin::core::{ + CostFunction, Error, Executor, Gradient, State, TerminationReason::SolverConverged, +}; +use argmin::solver::linesearch::condition::ArmijoCondition; +use argmin::solver::linesearch::BacktrackingLineSearch; +use argmin::solver::quasinewton::BFGS; + +use libm::lgamma; use hashbrown::HashMap; extern crate needletail; -use ndarray::Array1; use needletail::{parse_fastx_file, parser::Format}; -use libm::lgamma; -use cached::proc_macro::cached; use crate::ska_dict::bit_encoding::UInt; +use crate::ska_dict::split_kmer::SplitKmer; +use crate::QualFilter; + +const MAX_COUNT: usize = 1000; +const MIN_FREQ: u32 = 50; pub struct CoverageHistogram { - /// Dictionary of k-mers - kmer_counts: HashMap, - /// Count histogram - counts: Vec, /// K-mer size k: usize, /// Whether reverse complement split k-mers were used rc: bool, + /// Dictionary of k-mers and their counts + kmer_dict: HashMap, + /// Count histogram + counts: Vec, + /// Estimated error weight + w0: f64, + /// Estimated coverage + c: f64, + /// Coverage cutoff + cutoff: u32, + /// Show logging + verbose: bool, } impl CoverageHistogram where IntT: for<'a> UInt<'a>, { - // e.g. 
f = |x: &Array1| log_likelihood(x, self.counts); + // Called by lib.rs + pub fn new(fastq1: &String, fastq2: &String, k: usize, rc: bool, verbose: bool) -> Self { + if !(5..=63).contains(&k) || k % 2 == 0 { + panic!("Invalid k-mer length"); + } + + let mut cov_counts = Self { + k, + rc, + kmer_dict: HashMap::default(), + counts: vec![0; MAX_COUNT], + w0: 0.8, + c: 20.0, + cutoff: 0, + verbose, + }; + + // Check if we're working with reads first + for fastx_file in [fastq1, fastq2] { + let mut reader_peek = parse_fastx_file(fastx_file) + .unwrap_or_else(|_| panic!("Invalid path/file: {}", fastx_file)); + let seq_peek = reader_peek + .next() + .expect("Invalid FASTA/Q record") + .expect("Invalid FASTA/Q record"); + if seq_peek.format() != Format::Fastq { + panic!("{fastx_file} appears to be FASTA.\nCoverage can only be used with FASTQ files, not FASTA."); + } + } + + log::info!("Counting k-mers"); + for fastx_file in [fastq1, fastq2] { + let mut reader = parse_fastx_file(fastx_file) + .unwrap_or_else(|_| panic!("Invalid path/file: {fastx_file}")); + while let Some(record) = reader.next() { + let seqrec = record.expect("Invalid FASTA/Q record"); + let kmer_opt = SplitKmer::new( + seqrec.seq(), + seqrec.num_bases(), + seqrec.qual(), + cov_counts.k, + cov_counts.rc, + 0, + QualFilter::NoFilter, + false, + ); + if let Some(mut kmer_it) = kmer_opt { + let (kmer, _base, _rc) = kmer_it.get_curr_kmer(); + cov_counts + .kmer_dict + .entry(kmer) + .and_modify(|count| *count += 1) + .or_insert(1); + while let Some((kmer, _base, _rc)) = kmer_it.get_next_kmer() { + cov_counts + .kmer_dict + .entry(kmer) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + } + + cov_counts + } + + pub fn fit_histogram(&mut self) -> Result { + // Calculate k-mer histogram + log::info!("Calculating k-mer histogram"); + for kmer_count in self.kmer_dict.values() { + let kc = (*kmer_count - 1) as usize; + if kc < MAX_COUNT { + self.counts[kc] += 1; + } + } + + // Truncate count vec and covert to float + let mut counts_f64: Vec = Vec::new(); + for hist_bin in &self.counts { + if *hist_bin < MIN_FREQ { + break; + } else { + counts_f64.push(*hist_bin as f64); + } + } + + log::info!("Fitting Poisson mixture model using maximum likelihood"); + let mixture_fit = MixPoisson { counts: counts_f64 }; + let init_param: Vec = vec![self.w0, self.c]; + let init_hessian: Vec> = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; + let linesearch = BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?); + let solver = BFGS::new(linesearch); + let mut exec = Executor::new(mixture_fit, solver).configure(|state| { + state + .param(init_param) + .inv_hessian(init_hessian) + .max_iters(100) + }); + if self.verbose { + exec = exec.add_observer(SlogLogger::term(), ObserverMode::Always); + } + let res = exec.run()?; + + // print diagnostics + log::info!("{res}"); + if let Some(termination_reason) = res.state().get_termination_reason() { + if *termination_reason == SolverConverged { + // Best parameter vector + let best = res.state().get_best_param().unwrap(); + self.w0 = best[0]; + self.c = best[1]; + + // TODO calculate the coverage cutoff + Ok(self.cutoff) + } else { + Err(Error::msg(format!( + "Optimiser did not converge: {}", + termination_reason.text() + ))) + } + } else { + Err(Error::msg("Optimiser did not finish running")) + } + } + pub fn plot_hist() { + todo!() + } +} + +struct MixPoisson { + counts: Vec, } -// Called by lib.rs -pub fn coverage_cutoff UInt<'a>>(fastq1: &String, fastq2: &String, k: usize, rc: bool, threads: usize) { 
+impl CostFunction for MixPoisson { + /// Type of the parameter vector + type Param = Vec; + /// Type of the return value computed by the cost function + type Output = f64; + /// Apply the cost function to a parameter `p` + fn cost(&self, p: &Self::Param) -> Result { + Ok(-log_likelihood(p, &self.counts)) + } +} + +impl Gradient for MixPoisson { + /// Type of the parameter vector + type Param = Vec; + /// Type of the gradient + type Gradient = Vec; + + /// Compute the gradient at parameter `p`. + fn gradient(&self, p: &Self::Param) -> Result { + // Compute gradient of 2D Rosenbrock function + Ok(grad_ll(p, &self.counts).iter().map(|x| -*x).collect()) + } } // log-sum-exp @@ -38,56 +210,50 @@ fn lse(a: f64, b: f64) -> f64 { xstar + f64::ln(f64::exp(a - xstar) + f64::exp(b - xstar)) } -// TODO: 64 needs to be cached as bytes - // Natural log of Poisson density -#[cached] -fn ln_dpois(x: u64, lambda: [u8; 8]) -> f64 { - let lambda_f = f64::from_le_bytes(lambda); - x as f64 * (lambda_f) - lgamma(x as f64 + 1.0) - lambda_f +fn ln_dpois(x: f64, lambda: f64) -> f64 { + x * f64::ln(lambda) - lgamma(x + 1.0) - lambda } // error component -#[cached] fn a(w0: f64, i: f64) -> f64 { f64::ln(w0) + ln_dpois(i, 1.0) } // coverage component -#[cached] fn b(w0: f64, c: f64, i: f64) -> f64 { f64::ln(1.0 - w0) + ln_dpois(i, c) } // Mixture likelihood -fn log_likelihood(pars: &Array1, counts: &[f64]) -> f64 { +fn log_likelihood(pars: &[f64], counts: &[f64]) -> f64 { let w0 = pars[0]; let c = pars[1]; let mut ll = 0.0; if w0 > 1.0 || w0 < 0.0 || c < 1.0 { - ll = f64::NEG_INFINITY; + ll = f64::MIN; } else { - for (i, c) in counts.iter().enumerate() { - ll += counts[i] * lse(a(w0, i as f64), b(w0, *c, i as f64)); + for (i, count) in counts.iter().enumerate() { + ll += *count * lse(a(w0, i as f64 + 1.0), b(w0, c, i as f64 + 1.0)); } } ll } -fn grad_ll(pars: &Array1, counts: &[f64]) -> Array1 { +fn grad_ll(pars: &[f64], counts: &[f64]) -> Vec { let w0 = pars[0]; let c = pars[1]; let mut grad_w0 = 0.0; let mut grad_c = 0.0; - for (i, c) in counts.iter().enumerate() { - let i_f64 = i as f64; + for (i, count) in counts.iter().enumerate() { + let i_f64 = i as f64 + 1.0; let a_val = a(w0, i_f64); - let b_val = b(w0, *c, i_f64); + let b_val = b(w0, c, i_f64); let dlda = 1.0 / (1.0 + f64::exp(b_val - a_val)); let dldb = 1.0 / (1.0 + f64::exp(a_val - b_val)); - grad_w0 += counts[i] * (dlda/w0 - dldb/(1.0 - w0)); - grad_c += counts[i] * (dldb*(i_f64/c - 1.0)); + grad_w0 += *count as f64 * (dlda / w0 - dldb / (1.0 - w0)); + grad_c += *count as f64 * (dldb * (i_f64 / c - 1.0)); } - Array1::from_vec(vec![grad_w0, grad_c]) + vec![grad_w0, grad_c] } diff --git a/src/lib.rs b/src/lib.rs index 76a26ea..0827149 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -322,7 +322,7 @@ pub mod io_utils; use crate::io_utils::*; pub mod coverage; -use crate::coverage::*; +use crate::coverage::CoverageHistogram; /// Possible quality score filters when building with reads #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -548,18 +548,27 @@ pub fn main() { panic!("Could not read input file: {skf_file}"); } } - Commands::Cov { fastq_fwd, fastq_rev, k, single_strand, threads } => { - check_threads(*threads); - + Commands::Cov { + fastq_fwd, + fastq_rev, + k, + single_strand, + } => { // Build, merge let rc = !*single_strand; + let cutoff; if *k <= 31 { log::info!("k={}: using 64-bit representation", *k); - coverage_cutoff::(fastq_fwd, fastq_rev, *k, rc, *threads); + let mut cov = + CoverageHistogram::::new(fastq_fwd, fastq_rev, 
*k, rc, args.verbose); + cutoff = cov.fit_histogram().expect("Couldn't fit coverage model"); } else { log::info!("k={}: using 128-bit representation", *k); - coverage_cutoff::(fastq_fwd, fastq_rev, *k, rc, *threads); + let mut cov = + CoverageHistogram::::new(fastq_fwd, fastq_rev, *k, rc, args.verbose); + cutoff = cov.fit_histogram().expect("Couldn't fit coverage model"); } + println!("Estimated cutoff\t{cutoff}"); } } let end = Instant::now(); From 4d5a1fe32f4f876ae1cd81ea5127b17af053ef9e Mon Sep 17 00:00:00 2001 From: John Lees Date: Sat, 3 Jun 2023 13:22:02 +0100 Subject: [PATCH 3/8] Calculate a cutoff --- src/coverage.rs | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/coverage.rs b/src/coverage.rs index 8f5fdb2..49a63b6 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -20,6 +20,8 @@ use crate::QualFilter; const MAX_COUNT: usize = 1000; const MIN_FREQ: u32 = 50; +const INIT_W0: f64 = 0.8f64; +const INIT_C: f64 = 20.0f64; pub struct CoverageHistogram { /// K-mer size @@ -35,7 +37,7 @@ pub struct CoverageHistogram { /// Estimated coverage c: f64, /// Coverage cutoff - cutoff: u32, + cutoff: usize, /// Show logging verbose: bool, } @@ -55,8 +57,8 @@ where rc, kmer_dict: HashMap::default(), counts: vec![0; MAX_COUNT], - w0: 0.8, - c: 20.0, + w0: INIT_W0, + c: INIT_C, cutoff: 0, verbose, }; @@ -111,7 +113,7 @@ where cov_counts } - pub fn fit_histogram(&mut self) -> Result { + pub fn fit_histogram(&mut self) -> Result { // Calculate k-mer histogram log::info!("Calculating k-mer histogram"); for kmer_count in self.kmer_dict.values() { @@ -130,6 +132,7 @@ where counts_f64.push(*hist_bin as f64); } } + let count_len = counts_f64.len(); log::info!("Fitting Poisson mixture model using maximum likelihood"); let mixture_fit = MixPoisson { counts: counts_f64 }; @@ -157,7 +160,8 @@ where self.w0 = best[0]; self.c = best[1]; - // TODO calculate the coverage cutoff + // calculate the coverage cutoff + self.cutoff = find_cutoff(best, count_len); Ok(self.cutoff) } else { Err(Error::msg(format!( @@ -257,3 +261,19 @@ fn grad_ll(pars: &[f64], counts: &[f64]) -> Vec { } vec![grad_w0, grad_c] } + +fn find_cutoff(pars: &[f64], max_cutoff: usize) -> usize { + let w0 = pars[0]; + let c = pars[1]; + + let mut cutoff = 1; + while cutoff < max_cutoff { + let cutoff_f64 = cutoff as f64; + let root = a(w0, cutoff_f64) - b(w0, c, cutoff_f64); + if root < 0.0 { + break; + } + cutoff += 1; + } + cutoff + } From 0cbe56882be9f714781eba86959177b2616f0db8 Mon Sep 17 00:00:00 2001 From: John Lees Date: Mon, 5 Jun 2023 18:02:05 +0100 Subject: [PATCH 4/8] Add docs and plotting output to coverage module --- src/coverage.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++----- src/lib.rs | 4 ++- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/coverage.rs b/src/coverage.rs index 49a63b6..6e984e4 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -1,3 +1,12 @@ +//! Tools for estimating a count cutoff with FASTQ input. +//! +//! This module has a basic k-mer counter using a dictionary, and then uses +//! maximum likelihood with some basic numerical optimisation to fit a two-component +//! mixture of Poissons to determine a coverage model. This can be used to classify +//! a count cutoff with noisy data. +//! +//! [`CoverageHistogram`] is the main interface. 
+ use core::panic; use argmin::core::observers::{ObserverMode, SlogLogger}; @@ -23,6 +32,11 @@ const MIN_FREQ: u32 = 50; const INIT_W0: f64 = 0.8f64; const INIT_C: f64 = 20.0f64; +/// K-mer counts and a coverage model for a single sample, using a pair of FASTQ files as input +/// +/// Call [`CoverageHistogram::new()`] to count k-mers, then [`CoverageHistogram::fit_histogram()`] +/// to fit the model and find a cutoff. [`CoverageHistogram::plot_hist()`] can be used to +/// extract a table of the output for plotting purposes. pub struct CoverageHistogram { /// K-mer size k: usize, @@ -40,13 +54,18 @@ pub struct CoverageHistogram { cutoff: usize, /// Show logging verbose: bool, + /// Has the fit been run + fitted: bool, } impl CoverageHistogram where IntT: for<'a> UInt<'a>, { - // Called by lib.rs + /// Count split k-mers from a pair of input FASTQ files. + /// + /// Parameters the same as for [`crate::ska_dict::SkaDict`]. `verbose` will + /// also print to stderr on each iteration of the optiser. pub fn new(fastq1: &String, fastq2: &String, k: usize, rc: bool, verbose: bool) -> Self { if !(5..=63).contains(&k) || k % 2 == 0 { panic!("Invalid k-mer length"); @@ -61,6 +80,7 @@ where c: INIT_C, cutoff: 0, verbose, + fitted: false, }; // Check if we're working with reads first @@ -113,7 +133,22 @@ where cov_counts } + /// Fit the coverage model to the histogram of counts + /// + /// Returns the fitted cutoff if successful. + /// + /// # Errors + /// - If the optimiser didn't finish (reached 100 iterations or another problem). + /// - If the linesearch cannot be constructed (may be a bounds issue, or poor data). + /// - If the optimiser is still running (this shouldn't happen). + /// + /// # Panics + /// - If the fit has already been run pub fn fit_histogram(&mut self) -> Result { + if self.fitted { + panic!("Model already fitted"); + } + // Calculate k-mer histogram log::info!("Calculating k-mer histogram"); for kmer_count in self.kmer_dict.values() { @@ -162,6 +197,7 @@ where // calculate the coverage cutoff self.cutoff = find_cutoff(best, count_len); + self.fitted = true; Ok(self.cutoff) } else { Err(Error::msg(format!( @@ -174,8 +210,37 @@ where } } - pub fn plot_hist() { - todo!() + /// Prints the counts and model to stdout, for use in plotting. + /// + /// Creates a table with count, number of k-mers at that count, mixture + /// density, and most likely component. + /// Plot with the `plot_hist.py` helper script. 
+ pub fn plot_hist(&self) { + if !self.fitted { + panic!("Model has not yet been fitted"); + } + + log::info!("Calculating and printing count series"); + println!("Count\tK_mers\tMixture_density\tComponent"); + for (idx, count) in self.counts.iter().enumerate() { + if *count < MIN_FREQ { + break; + } + println!( + "{}\t{}\t{:e}\t{}", + idx + 1, + *count, + f64::exp(lse( + a(self.w0, idx as f64 + 1.0), + b(self.w0, self.c, idx as f64 + 1.0) + )), + if (idx + 1) < self.cutoff { + "Error" + } else { + "Coverage" + } + ) + } } } @@ -234,11 +299,12 @@ fn log_likelihood(pars: &[f64], counts: &[f64]) -> f64 { let w0 = pars[0]; let c = pars[1]; let mut ll = 0.0; - if w0 > 1.0 || w0 < 0.0 || c < 1.0 { + if !(0.0..=1.0).contains(&w0) || c < 1.0 { ll = f64::MIN; } else { for (i, count) in counts.iter().enumerate() { - ll += *count * lse(a(w0, i as f64 + 1.0), b(w0, c, i as f64 + 1.0)); + let i_f64 = i as f64 + 1.0; + ll += *count * lse(a(w0, i_f64), b(w0, c, i_f64)); } } ll @@ -256,8 +322,8 @@ fn grad_ll(pars: &[f64], counts: &[f64]) -> Vec { let b_val = b(w0, c, i_f64); let dlda = 1.0 / (1.0 + f64::exp(b_val - a_val)); let dldb = 1.0 / (1.0 + f64::exp(a_val - b_val)); - grad_w0 += *count as f64 * (dlda / w0 - dldb / (1.0 - w0)); - grad_c += *count as f64 * (dldb * (i_f64 / c - 1.0)); + grad_w0 += *count * (dlda / w0 - dldb / (1.0 - w0)); + grad_c += *count * (dldb * (i_f64 / c - 1.0)); } vec![grad_w0, grad_c] } @@ -276,4 +342,4 @@ fn find_cutoff(pars: &[f64], max_cutoff: usize) -> usize { cutoff += 1; } cutoff - } +} diff --git a/src/lib.rs b/src/lib.rs index 0827149..3fcbf4d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -562,13 +562,15 @@ pub fn main() { let mut cov = CoverageHistogram::::new(fastq_fwd, fastq_rev, *k, rc, args.verbose); cutoff = cov.fit_histogram().expect("Couldn't fit coverage model"); + cov.plot_hist(); } else { log::info!("k={}: using 128-bit representation", *k); let mut cov = CoverageHistogram::::new(fastq_fwd, fastq_rev, *k, rc, args.verbose); cutoff = cov.fit_histogram().expect("Couldn't fit coverage model"); + cov.plot_hist(); } - println!("Estimated cutoff\t{cutoff}"); + eprintln!("Estimated cutoff\t{cutoff}"); } } let end = Instant::now(); From 6d4e700d33792cf1c8f492a24c5eb3626825d6a3 Mon Sep 17 00:00:00 2001 From: John Lees Date: Tue, 6 Jun 2023 10:32:28 +0100 Subject: [PATCH 5/8] Comment the code better --- src/coverage.rs | 36 ++++++++++++++++++++++++++++-------- src/lib.rs | 21 +++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/coverage.rs b/src/coverage.rs index 6e984e4..f6f0f8c 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -96,6 +96,9 @@ where } } + // Count the k-mers using the same iterator as in SkaDict, and just use + // a dictionary rather than a special filter (~10-20s per sample) + // NB: Quality scores totally ignored here, could change this in future log::info!("Counting k-mers"); for fastx_file in [fastq1, fastq2] { let mut reader = parse_fastx_file(fastx_file) @@ -149,7 +152,7 @@ where panic!("Model already fitted"); } - // Calculate k-mer histogram + // Calculate k-mer histogram from k-mer counts log::info!("Calculating k-mer histogram"); for kmer_count in self.kmer_dict.values() { let kc = (*kmer_count - 1) as usize; @@ -169,12 +172,17 @@ where } let count_len = counts_f64.len(); + // Fit with maximum likelihood. 
Using BFGS optimiser and simple line search
+        // seems to work fine
         log::info!("Fitting Poisson mixture model using maximum likelihood");
         let mixture_fit = MixPoisson { counts: counts_f64 };
         let init_param: Vec<f64> = vec![self.w0, self.c];
+        // This is required: a numerical Hessian was tried, but the scale was wrong
+        // and it gave very poor results for the c optimisation
         let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
         let linesearch = BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?);
         let solver = BFGS::new(linesearch);
+        // Usually around 10 iterations should be enough
         let mut exec = Executor::new(mixture_fit, solver).configure(|state| {
             state
                 .param(init_param)
@@ -186,7 +194,7 @@ where
         }
         let res = exec.run()?;
 
-        // print diagnostics
+        // Print diagnostics
         log::info!("{res}");
         if let Some(termination_reason) = res.state().get_termination_reason() {
             if *termination_reason == SolverConverged {
@@ -244,22 +252,27 @@ where
     }
 }
 
+// Helper struct for the optimisation, which keeps the counts as state
 struct MixPoisson {
     counts: Vec<f64>,
 }
 
+// These just use Vec rather than ndarray: simpler packaging, and with only two
+// parameters there is unlikely to be any performance difference
+
+// negative log-likelihood
 impl CostFunction for MixPoisson {
     /// Type of the parameter vector
     type Param = Vec<f64>;
     /// Type of the return value computed by the cost function
     type Output = f64;
-    /// Apply the cost function to a parameter `p`
+    /// Apply the cost function to the parameter vector `p`
     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
         Ok(-log_likelihood(p, &self.counts))
     }
 }
 
+// negative grad(ll)
 impl Gradient for MixPoisson {
     /// Type of the parameter vector
     type Param = Vec<f64>;
@@ -268,12 +281,13 @@ impl Gradient for MixPoisson {
 
     /// Compute the gradient at parameter `p`.
     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
-        // Compute gradient of 2D Rosenbrock function
+        // As we are minimising, the sign of the gradients needs to be inverted
        Ok(grad_ll(p, &self.counts).iter().map(|x| -*x).collect())
     }
 }
 
-// log-sum-exp
+// log-sum-exp, needed to combine the component likelihoods
+// (hard-coded as two components here; could of course be generalised to N)
 fn lse(a: f64, b: f64) -> f64 {
     let xstar = f64::max(a, b);
     xstar + f64::ln(f64::exp(a - xstar) + f64::exp(b - xstar))
@@ -284,21 +298,23 @@ fn ln_dpois(x: f64, lambda: f64) -> f64 {
     x * f64::ln(lambda) - lgamma(x + 1.0) - lambda
 }
 
-// error component
+// error component (mean of 1)
 fn a(w0: f64, i: f64) -> f64 {
     f64::ln(w0) + ln_dpois(i, 1.0)
 }
 
-// coverage component
+// coverage component (mean of coverage)
 fn b(w0: f64, c: f64, i: f64) -> f64 {
     f64::ln(1.0 - w0) + ln_dpois(i, c)
 }
 
-// Mixture likelihood
+// Mixture model likelihood
 fn log_likelihood(pars: &[f64], counts: &[f64]) -> f64 {
     let w0 = pars[0];
     let c = pars[1];
     let mut ll = 0.0;
+    // 'soft' bounds. f64::NEG_INFINITY might be mathematically better,
+    // but argmin doesn't cope well with it
     if !(0.0..=1.0).contains(&w0) || c < 1.0 {
         ll = f64::MIN;
     } else {
@@ -310,6 +326,8 @@ fn log_likelihood(pars: &[f64], counts: &[f64]) -> f64 {
     ll
 }
 
+// Analytic gradient. Bounds are not needed as this is only evaluated
+// when the log-likelihood is valid
 fn grad_ll(pars: &[f64], counts: &[f64]) -> Vec<f64> {
     let w0 = pars[0];
     let c = pars[1];
@@ -328,6 +346,8 @@ fn grad_ll(pars: &[f64], counts: &[f64]) -> Vec<f64> {
     vec![grad_w0, grad_c]
 }
 
+// Root finder at integer steps -- finds the first count at which the coverage
+// component (b) has higher responsibility than the error component (a)
 fn find_cutoff(pars: &[f64], max_cutoff: usize) -> usize {
     let w0 = pars[0];
     let c = pars[1];
diff --git a/src/lib.rs b/src/lib.rs
index 3fcbf4d..698d03c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -233,6 +233,27 @@
 //! These are not precisely canonical k-mers, as the encoding order `{A, C, T, G}` is used internally.
 //! But if you can't find a sequence in your input file, you will find its reverse complement.
 //!
+//! ## ska cov
+//!
+//! Estimate a coverage cutoff for use with read data. This will count the split
+//! k-mers in a pair of FASTQ samples, and create a histogram of these counts.
+//! A mixture model is then fitted to this histogram using maximum likelihood,
+//! which can give a suggested cutoff even with noisy data.
+//!
+//! The cutoff will be printed to STDERR. Use `-v` to get additional information on the
+//! optimisation progress and result. A table of counts and the fit will be printed
+//! to STDOUT, which can then be plotted with the companion script
+//! `scripts/plot_cov.py` (requires `matplotlib`):
+//! ```bash
+//! ska cov reads_1.fastq.gz reads_2.fastq.gz -k 31 -v > cov_plot.txt
+//! python scripts/plot_cov.py cov_plot.txt
+//! ```
+//!
+//! The cutoff can be used with the `--min-count` parameter of `ska build`. For
+//! a set of sequencing experiments with similar characteristics it may be sufficient
+//! to use the same cutoff. Alternatively `ska cov` can be run on every sample
+//! independently (`GNU parallel` would be an efficient way to do this).
+//!
 //! # API usage
 //!
 //! See the submodule documentation linked below.
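A note on the cutoff rule used above (illustrative only, not part of any patch in this series): because both mixture components share the `lgamma(i + 1)` term, the count at which `find_cutoff` switches from the error to the coverage component can also be checked in closed form. A minimal standalone sketch, using the `INIT_W0`/`INIT_C` starting values defined in `src/coverage.rs`:

```rust
fn main() {
    // Starting values INIT_W0 = 0.8 and INIT_C = 20.0 from src/coverage.rs
    let (w0, c): (f64, f64) = (0.8, 20.0);
    // a(w0, i) < b(w0, c, i)  <=>  i * ln(c) > c - 1 + ln(w0 / (1 - w0)),
    // since the lgamma(i + 1) term appears in both Poisson components and cancels
    let threshold = c - 1.0 + (w0 / (1.0 - w0)).ln();
    let cutoff = (1u32..).find(|&i| f64::from(i) * c.ln() > threshold).unwrap();
    println!("components cross at count {cutoff}"); // prints 7 for these parameters
}
```

The loop in `find_cutoff` evaluates the same condition via `a()` and `b()` at each integer count up to the histogram length; the closed form is just a compact way to sanity-check a fitted `(w0, c)` pair.
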
From fd3e73f2208e364aec7ab9c64e9f888357df1cd4 Mon Sep 17 00:00:00 2001 From: John Lees Date: Tue, 6 Jun 2023 12:04:14 +0100 Subject: [PATCH 6/8] Add histogram plotting code --- scripts/plot_cov.py | 79 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 scripts/plot_cov.py diff --git a/scripts/plot_cov.py b/scripts/plot_cov.py new file mode 100644 index 0000000..d664ddc --- /dev/null +++ b/scripts/plot_cov.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# vim: set fileencoding= : + +import matplotlib as mpl +mpl.use('Agg') +mpl.rcParams.update({'font.size': 18}) +import matplotlib.pyplot as plt +from math import log + +import argparse + +def norm_series(series: list()): + norm = max(series) + series = [item/norm for item in series] + return series + +def main(): + parser = argparse.ArgumentParser( + prog='plot_cov', + description='Plot the results of `ska cov`', + epilog='Requires matplotlib') + parser.add_argument('histfile', help="Input table (stdout from `ska cov`)") + parser.add_argument('--output', help="Output prefix", default="coverage_histogram") + args = parser.parse_args() + + cutoff = 0 + kmers = list() + density = list() + idx_xseries = list() + with open(args.histfile, 'r') as hfile: + hfile.readline() + for line in hfile: + (idx, count, ll, comp) = line.rstrip().split("\t") + kmers.append(int(count)) + density.append(float(ll)) + idx_xseries.append(int(idx)) + if comp == "Coverage" and cutoff == 0: + cutoff = int(idx) + + k_norm = norm_series(kmers) + #d_norm = norm_series(density) + + fig, (ax1, ax2) = plt.subplots(2) + fig.suptitle('Coverage histogram fit') + fig.set_dpi(160) + fig.set_facecolor('w') + fig.set_edgecolor('k') + fig.set_figwidth(11) + fig.set_figheight(11) + plt.tight_layout() + + ax1.set_xlabel('K-mer count') + ax1.set_ylabel('Frequency') + ax1.set_ylim(0, k_norm[1]) + ax1.plot(idx_xseries, k_norm, color='black', linewidth=2, + label='K-mer count frequency') + ax1.plot(idx_xseries, density, color='red', linewidth=2, linestyle='--', + label='Mixture model fit') + ax1.plot([cutoff, cutoff], [0, 1], color='darkgray', linewidth=1, + linestyle='-.', label=f'Count cutoff ({cutoff})') + ax1.legend(loc='upper right') + + ax2.set_yscale("log") + ax2.set_xlabel('K-mer count') + ax2.set_ylabel('log(Frequency)') + ax2.set_ylim(min(k_norm), k_norm[1]) + ax2.plot(idx_xseries, k_norm, color='black', linewidth=2, + label='K-mer count frequency') + ax2.plot(idx_xseries, density, color='red', linewidth=2, linestyle='--', + label='Mixture model fit') + ax2.plot([cutoff, cutoff], [0, 1], color='darkgray', linewidth=1, + linestyle='-.', label=f'Count cutoff ({cutoff})') + + plt.savefig(args.output + ".png", + bbox_inches='tight') + plt.close() + +if __name__ == '__main__': + main() From 7ae07dc7bc667b5d3a617660358a393ceb345284 Mon Sep 17 00:00:00 2001 From: John Lees Date: Tue, 6 Jun 2023 14:57:07 +0100 Subject: [PATCH 7/8] Add tests for coverage mode --- src/coverage.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ src/ska_dict.rs | 4 ++-- tests/fastq_input.rs | 17 +++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/coverage.rs b/src/coverage.rs index f6f0f8c..01cedd0 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -37,6 +37,7 @@ const INIT_C: f64 = 20.0f64; /// Call [`CoverageHistogram::new()`] to count k-mers, then [`CoverageHistogram::fit_histogram()`] /// to fit the model and find a cutoff. 
[`CoverageHistogram::plot_hist()`] can be used to /// extract a table of the output for plotting purposes. +#[derive(Default, Debug)] pub struct CoverageHistogram { /// K-mer size k: usize, @@ -363,3 +364,45 @@ fn find_cutoff(pars: &[f64], max_cutoff: usize) -> usize { } cutoff } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fit_histogram() { + // Initialize the test object + let mut test_obj = CoverageHistogram:: { + k: 31, + rc: true, + kmer_dict: HashMap::default(), + counts: vec![44633459, 950672, 104410, 44137, 24170, 21232, 21699, + 24145, 30696, 39210, 49878, 63683, 77690, 95147, 112416, + 130307, 146531, 160932, 175130, 185113, 193149, 197468, + 199189, 198235, 192150, 185565, 176362, 165455, 152487, + 139495, 127036, 112803, 103080, 90425, 80637, 70960, 62698, + 54949, 46744, 41240, 35591, 30025, 25856, 22105, + 19405, 16668, 14780, 12620, 11074, 9807, 8517, 7731, 7112, + 6846, 6126, 5696, 5233, 4779, 4288, 3873, 3519, 3406, 2994, + 2859, 2650, 2394, 2376, 2260, 2233, 2050, 1859, 1863, 1792, + 1777, 1773, 1738, 1648], + w0: INIT_W0, + c: INIT_C, + cutoff: 0, + verbose: false, + fitted: false, + }; + + let cutoff = test_obj.fit_histogram(); + assert_eq!(cutoff.is_ok(), true); + assert_eq!(cutoff.unwrap(), 9); + } + + #[test] + #[should_panic] + fn print_before_fit() { + let test_obj = CoverageHistogram::::default(); + test_obj.plot_hist(); + } + +} diff --git a/src/ska_dict.rs b/src/ska_dict.rs index fd8f6a4..f57ea94 100644 --- a/src/ska_dict.rs +++ b/src/ska_dict.rs @@ -263,8 +263,8 @@ mod tests { #[test] fn test_add_palindrome_to_dict() { - // Initialize the test subject - let mut test_obj = SkaDict::::default(); // Replace YourStruct with the actual struct name + // Initialize the test object + let mut test_obj = SkaDict::::default(); // Test case 1: Updating existing entry test_obj.split_kmers.insert(123, b'W'); diff --git a/tests/fastq_input.rs b/tests/fastq_input.rs index 5701290..dadf0de 100644 --- a/tests/fastq_input.rs +++ b/tests/fastq_input.rs @@ -469,3 +469,20 @@ fn error_fastq() { assert_eq!(var_hash(&fastq_align_out_quality), all_hash); } + +// Just checks that counter runs, model fit is in a unit test +#[test] +fn cov_check() { + let sandbox = TestSetup::setup(); + + Command::new(cargo_bin("ska")) + .current_dir(sandbox.get_wd()) + .arg("cov") + .arg(sandbox.file_string("test_1_fwd.fastq.gz", TestDir::Input)) + .arg(sandbox.file_string("test_1_rev.fastq.gz", TestDir::Input)) + .arg("-k") + .arg("9") + .arg("-v") + .assert() + .success(); +} \ No newline at end of file From 899e5fd58c7e22beebe72e8c27c0d0a0eb2ad196 Mon Sep 17 00:00:00 2001 From: John Lees Date: Tue, 6 Jun 2023 15:31:16 +0100 Subject: [PATCH 8/8] Increase test coverage --- src/coverage.rs | 73 ++++++++++++++++++++++++++++++++++++++------ tests/fastq_input.rs | 23 ++++++++++++++ 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/src/coverage.rs b/src/coverage.rs index 01cedd0..ee5aef9 100644 --- a/src/coverage.rs +++ b/src/coverage.rs @@ -369,23 +369,28 @@ fn find_cutoff(pars: &[f64], max_cutoff: usize) -> usize { mod tests { use super::*; + + #[test] fn test_fit_histogram() { + let example_counts = + vec![44633459, 950672, 104410, 44137, 24170, 21232, 21699, + 24145, 30696, 39210, 49878, 63683, 77690, 95147, 112416, + 130307, 146531, 160932, 175130, 185113, 193149, 197468, + 199189, 198235, 192150, 185565, 176362, 165455, 152487, + 139495, 127036, 112803, 103080, 90425, 80637, 70960, 62698, + 54949, 46744, 41240, 35591, 30025, 25856, 22105, + 19405, 16668, 14780, 12620, 
11074, 9807, 8517, 7731, 7112, + 6846, 6126, 5696, 5233, 4779, 4288, 3873, 3519, 3406, 2994, + 2859, 2650, 2394, 2376, 2260, 2233, 2050, 1859, 1863, 1792, + 1777, 1773, 1738, 1648]; + // Initialize the test object let mut test_obj = CoverageHistogram:: { k: 31, rc: true, kmer_dict: HashMap::default(), - counts: vec![44633459, 950672, 104410, 44137, 24170, 21232, 21699, - 24145, 30696, 39210, 49878, 63683, 77690, 95147, 112416, - 130307, 146531, 160932, 175130, 185113, 193149, 197468, - 199189, 198235, 192150, 185565, 176362, 165455, 152487, - 139495, 127036, 112803, 103080, 90425, 80637, 70960, 62698, - 54949, 46744, 41240, 35591, 30025, 25856, 22105, - 19405, 16668, 14780, 12620, 11074, 9807, 8517, 7731, 7112, - 6846, 6126, 5696, 5233, 4779, 4288, 3873, 3519, 3406, 2994, - 2859, 2650, 2394, 2376, 2260, 2233, 2050, 1859, 1863, 1792, - 1777, 1773, 1738, 1648], + counts: example_counts.clone(), w0: INIT_W0, c: INIT_C, cutoff: 0, @@ -396,6 +401,22 @@ mod tests { let cutoff = test_obj.fit_histogram(); assert_eq!(cutoff.is_ok(), true); assert_eq!(cutoff.unwrap(), 9); + + // The other template + let mut test_obj = CoverageHistogram:: { + k: 33, + rc: true, + kmer_dict: HashMap::default(), + counts: example_counts.clone(), + w0: INIT_W0, + c: INIT_C, + cutoff: 0, + verbose: false, + fitted: false, + }; + + test_obj.fit_histogram().unwrap(); + test_obj.plot_hist(); } #[test] @@ -405,4 +426,36 @@ mod tests { test_obj.plot_hist(); } + #[test] + #[should_panic] + fn double_fit() { + let example_counts = + vec![44633459, 950672, 104410, 44137, 24170, 21232, 21699, + 24145, 30696, 39210, 49878, 63683, 77690, 95147, 112416, + 130307, 146531, 160932, 175130, 185113, 193149, 197468, + 199189, 198235, 192150, 185565, 176362, 165455, 152487, + 139495, 127036, 112803, 103080, 90425, 80637, 70960, 62698, + 54949, 46744, 41240, 35591, 30025, 25856, 22105, + 19405, 16668, 14780, 12620, 11074, 9807, 8517, 7731, 7112, + 6846, 6126, 5696, 5233, 4779, 4288, 3873, 3519, 3406, 2994, + 2859, 2650, 2394, 2376, 2260, 2233, 2050, 1859, 1863, 1792, + 1777, 1773, 1738, 1648]; + + // Initialize the test object + let mut test_obj = CoverageHistogram:: { + k: 31, + rc: true, + kmer_dict: HashMap::default(), + counts: example_counts.clone(), + w0: INIT_W0, + c: INIT_C, + cutoff: 0, + verbose: false, + fitted: false, + }; + + test_obj.fit_histogram().unwrap(); + test_obj.fit_histogram().unwrap(); + } + } diff --git a/tests/fastq_input.rs b/tests/fastq_input.rs index dadf0de..085fdb7 100644 --- a/tests/fastq_input.rs +++ b/tests/fastq_input.rs @@ -485,4 +485,27 @@ fn cov_check() { .arg("-v") .assert() .success(); + + Command::new(cargo_bin("ska")) + .current_dir(sandbox.get_wd()) + .arg("cov") + .arg(sandbox.file_string("test_long_1_fwd.fastq.gz", TestDir::Input)) + .arg(sandbox.file_string("test_long_1_rev.fastq.gz", TestDir::Input)) + .arg("-k") + .arg("33") + .arg("-v") + .assert() + .success(); + + // Doesn't run with fasta + Command::new(cargo_bin("ska")) + .current_dir(sandbox.get_wd()) + .arg("cov") + .arg(sandbox.file_string("test_1.fa", TestDir::Input)) + .arg(sandbox.file_string("test_2.fa", TestDir::Input)) + .arg("-k") + .arg("9") + .arg("-v") + .assert() + .failure(); } \ No newline at end of file
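
For completeness, a sketch of how the new coverage API could be called from Rust when the crate is used as a library (this mirrors the `Commands::Cov` arm added to `src/lib.rs`; the file names are placeholders and the `ska` crate name comes from Cargo.toml):

```rust
use ska::coverage::CoverageHistogram;

fn main() {
    // Placeholder paths; k <= 31 fits in the 64-bit split k-mer representation
    let (fwd, rev) = ("reads_1.fastq.gz".to_string(), "reads_2.fastq.gz".to_string());
    let mut cov = CoverageHistogram::<u64>::new(&fwd, &rev, 31, true, false);
    let cutoff = cov.fit_histogram().expect("Couldn't fit coverage model");
    cov.plot_hist(); // writes the count/fit table to stdout
    eprintln!("Estimated cutoff\t{cutoff}");
}
```

On the command line this corresponds to `ska cov reads_1.fastq.gz reads_2.fastq.gz -k 31`, with the reported cutoff intended to be passed to `ska build` via `--min-count`.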