Skip to content

Commit

Permalink
refactor(derive/strandedness): separate out results from compute
Browse files Browse the repository at this point in the history
  • Loading branch information
a-frantz committed Feb 9, 2024
1 parent 3c15c10 commit 7f95d34
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 222 deletions.
42 changes: 20 additions & 22 deletions src/derive/command/strandedness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use noodles::gff;
use rust_lapper::{Interval, Lapper};
use tracing::debug;
use tracing::info;
use tracing::warn;

use crate::derive::strandedness::compute;
use crate::derive::strandedness::compute::ParsedBAMFile;
use crate::derive::strandedness::results;
use crate::utils::formats;

/// Clap arguments for the `ngs derive strandedness` subcommand.
Expand Down Expand Up @@ -95,8 +95,8 @@ pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {

let mut gene_records = Vec::new();
let mut exon_records = Vec::new();
let mut gene_metrics = compute::GeneRecordMetrics::default();
let mut exon_metrics = compute::ExonRecordMetrics::default();
let mut gene_metrics = results::GeneRecordMetrics::default();
let mut exon_metrics = results::ExonRecordMetrics::default();
for result in gff.records() {
let record = result.unwrap();
if record.ty() == args.gene_feature_name {
Expand Down Expand Up @@ -214,17 +214,17 @@ pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {
counts: HashMap::new(),
found_rgs: HashSet::new(),
};
let mut metrics = compute::RecordTracker {
let mut metrics = results::RecordTracker {
genes: gene_metrics,
exons: exon_metrics,
reads: compute::ReadRecordMetrics::default(),
reads: results::ReadRecordMetrics::default(),
};

let mut result: compute::DerivedStrandednessResult;
let mut result: Option<results::DerivedStrandednessResult> = None;
for try_num in 1..=args.max_tries {
info!("Starting try {} of {}", try_num, args.max_tries);

result = compute::predict(
let attempt = compute::predict(
&mut parsed_bam,
&mut gene_records,
&exons,
Expand All @@ -233,25 +233,23 @@ pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {
&mut metrics,
)?;

if result.succeeded {
if attempt.succeeded {
info!("Strandedness test succeeded.");

// (#) Print the output to stdout as JSON (more support for different output
// types may be added in the future, but for now, only JSON).
let output = serde_json::to_string_pretty(&result).unwrap();
print!("{}", output);
break;
} else {
warn!("Strandedness test inconclusive.");

if try_num >= args.max_tries {
info!("Strandedness test failed after {} tries.", args.max_tries);
let output = serde_json::to_string_pretty(&result).unwrap();
print!("{}", output);
break;
}
info!("Strandedness test inconclusive.");
}
result = Some(attempt);
}
let result = result.unwrap();

if !result.succeeded {
info!("Strandedness test failed after {} tries.", args.max_tries);
}

// (4) Print the output to stdout as JSON (more support for different output
// types may be added in the future, but for now, only JSON).
let output = serde_json::to_string_pretty(&result).unwrap();
print!("{}", output);

anyhow::Ok(())
}
1 change: 1 addition & 0 deletions src/derive/strandedness.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Supporting functionality for the `ngs derive strandedness` subcommand.
pub mod compute;
pub mod results;
213 changes: 13 additions & 200 deletions src/derive/strandedness/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,86 +7,16 @@ use noodles::sam;
use noodles::sam::record::data::field::Tag;
use rand::Rng;
use rust_lapper::Lapper;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

use crate::derive::strandedness::results;
use crate::utils::read_groups::{validate_read_group_info, UNKNOWN_READ_GROUP};

const STRANDED_THRESHOLD: f64 = 80.0;
const UNSTRANDED_THRESHOLD: f64 = 40.0;

/// General gene metrics that are tallied as a part of the
/// strandedness subcommand.
///
/// Counts are accumulated while parsing gene records out of the GFF.
#[derive(Clone, Default, Serialize, Debug)]
pub struct GeneRecordMetrics {
    /// The total number of genes found in the GFF.
    pub total: usize,

    /// The number of genes that were found to be protein coding.
    /// If `--all-genes` is set, this is not tallied.
    pub protein_coding: usize,

    /// The number of genes tested.
    pub tested: usize,

    /// The number of genes which were discarded due to having
    /// an unknown/invalid strand OR with exons on both strands.
    pub bad_strands: usize,

    /// The number of genes which were discarded due to not having
    /// enough reads.
    pub not_enough_reads: usize,
}

/// General exon metrics that are tallied as a part of the
/// strandedness subcommand.
///
/// Counts are accumulated while parsing exon records out of the GFF.
#[derive(Clone, Default, Serialize, Debug)]
pub struct ExonRecordMetrics {
    /// The total number of exons found in the GFF.
    pub total: usize,

    /// The number of exons discarded due to having an unknown/invalid strand.
    pub bad_strand: usize,
}

/// General read record metrics that are tallied as a part of the
/// strandedness subcommand.
///
/// Counts are accumulated while filtering and classifying BAM records.
#[derive(Clone, Default, Serialize, Debug)]
pub struct ReadRecordMetrics {
    /// The number of records that have been filtered because of their flags.
    /// (i.e. they were qc_fail, duplicates, secondary, or supplementary)
    /// These conditions can be toggled on/off with CL flags.
    pub filtered_by_flags: usize,

    /// The number of records that have been filtered because
    /// they failed the MAPQ filter.
    pub low_mapq: usize,

    /// The number of records whose MAPQ couldn't be parsed and were thus discarded.
    pub missing_mapq: usize,

    /// The number of records determined to be Paired-End.
    pub paired_end_reads: usize,

    /// The number of records determined to be Single-End.
    pub single_end_reads: usize,
}

/// Struct for managing record tracking.
///
/// Bundles the gene, exon, and read metrics so they can be passed around
/// (and ultimately serialized) as a single unit.
#[derive(Clone, Default, Debug)]
pub struct RecordTracker {
    /// Gene metrics.
    pub genes: GeneRecordMetrics,

    /// Exon metrics.
    pub exons: ExonRecordMetrics,

    /// Read metrics.
    pub reads: ReadRecordMetrics,
}

/// Struct for tracking count results.
#[derive(Clone, Default)]
pub struct Counts {
Expand All @@ -97,126 +27,6 @@ pub struct Counts {
reverse: usize,
}

/// Struct holding the per read group results for an `ngs derive strandedness`
/// subcommand call.
#[derive(Debug, Serialize)]
pub struct ReadGroupDerivedStrandednessResult {
    /// Name of the read group.
    pub read_group: String,

    /// Whether or not strandedness was determined for this read group.
    pub succeeded: bool,

    /// The strandedness of this read group or "Inconclusive".
    pub strandedness: String,

    /// The total number of reads in this read group.
    pub total: usize,

    /// The number of reads that are evidence of Forward Strandedness.
    pub forward: usize,

    /// The number of reads that are evidence of Reverse Strandedness.
    pub reverse: usize,

    /// The percent of evidence for Forward Strandedness.
    pub forward_pct: f64,

    /// The percent of evidence for Reverse Strandedness.
    pub reverse_pct: f64,
}

impl ReadGroupDerivedStrandednessResult {
    /// Creates a new [`ReadGroupDerivedStrandednessResult`].
    ///
    /// `total`, `forward_pct`, and `reverse_pct` are derived from `forward`
    /// and `reverse`. When both counts are zero, the percentages are reported
    /// as `0.0` rather than computed: `0.0 / 0.0` is `NaN`, which
    /// `serde_json` cannot serialize, so the downstream
    /// `serde_json::to_string_pretty(..).unwrap()` would fail.
    fn new(
        read_group: String,
        succeeded: bool,
        strandedness: String,
        forward: usize,
        reverse: usize,
    ) -> Self {
        let total = forward + reverse;
        let (forward_pct, reverse_pct) = if total == 0 {
            (0.0, 0.0)
        } else {
            (
                (forward as f64 / total as f64) * 100.0,
                (reverse as f64 / total as f64) * 100.0,
            )
        };
        ReadGroupDerivedStrandednessResult {
            read_group,
            succeeded,
            strandedness,
            total,
            forward,
            reverse,
            forward_pct,
            reverse_pct,
        }
    }
}

/// Struct holding the final results for an `ngs derive strandedness` subcommand
/// call.
#[derive(Debug, Serialize)]
pub struct DerivedStrandednessResult {
    /// Whether or not the `ngs derive strandedness` subcommand succeeded.
    pub succeeded: bool,

    /// The strandedness of this read group or "Inconclusive".
    pub strandedness: String,

    /// The total number of reads.
    pub total: usize,

    /// The number of reads that are evidence of Forward Strandedness.
    pub forward: usize,

    /// The number of reads that are evidence of Reverse Strandedness.
    pub reverse: usize,

    /// The percent of evidence for Forward Strandedness.
    pub forward_pct: f64,

    /// The percent of evidence for Reverse Strandedness.
    pub reverse_pct: f64,

    /// Vector of [`ReadGroupDerivedStrandednessResult`]s.
    /// One for each read group in the BAM,
    /// and potentially one for any reads with an unknown read group.
    pub read_groups: Vec<ReadGroupDerivedStrandednessResult>,

    /// General read record metrics that are tallied as a part of the
    /// strandedness subcommand.
    pub read_metrics: ReadRecordMetrics,

    /// General gene metrics that are tallied as a part of the
    /// strandedness subcommand.
    pub gene_metrics: GeneRecordMetrics,

    /// General exon metrics that are tallied as a part of the
    /// strandedness subcommand.
    pub exon_metrics: ExonRecordMetrics,
}

impl DerivedStrandednessResult {
    /// Creates a new [`DerivedStrandednessResult`].
    ///
    /// `total`, `forward_pct`, and `reverse_pct` are derived from `forward`
    /// and `reverse`. When both counts are zero (e.g. no reads were counted
    /// on any attempt), the percentages are reported as `0.0` rather than
    /// computed: `0.0 / 0.0` is `NaN`, which `serde_json` cannot serialize,
    /// so the downstream `serde_json::to_string_pretty(..).unwrap()` would
    /// fail.
    fn new(
        succeeded: bool,
        strandedness: String,
        forward: usize,
        reverse: usize,
        read_groups: Vec<ReadGroupDerivedStrandednessResult>,
        metrics: RecordTracker,
    ) -> Self {
        let total = forward + reverse;
        let (forward_pct, reverse_pct) = if total == 0 {
            (0.0, 0.0)
        } else {
            (
                (forward as f64 / total as f64) * 100.0,
                (reverse as f64 / total as f64) * 100.0,
            )
        };
        DerivedStrandednessResult {
            succeeded,
            strandedness,
            total,
            forward,
            reverse,
            forward_pct,
            reverse_pct,
            read_groups,
            read_metrics: metrics.reads,
            gene_metrics: metrics.genes,
            exon_metrics: metrics.exons,
        }
    }
}

#[derive(Clone, Copy, Debug)]
enum Strand {
Forward,
Expand Down Expand Up @@ -349,7 +159,7 @@ fn query_and_filter(
parsed_bam: &mut ParsedBAMFile,
gene: &gff::Record,
params: &StrandednessParams,
read_metrics: &mut ReadRecordMetrics,
read_metrics: &mut results::ReadRecordMetrics,
) -> Vec<sam::alignment::Record> {
let start = gene.start();
let end = gene.end();
Expand Down Expand Up @@ -405,7 +215,7 @@ fn classify_read(
read: &sam::alignment::Record,
gene_strand: &Strand,
all_counts: &mut AllReadGroupsCounts,
read_metrics: &mut ReadRecordMetrics,
read_metrics: &mut results::ReadRecordMetrics,
) {
let read_group = match read.data().get(Tag::ReadGroup) {
Some(rg) => {
Expand Down Expand Up @@ -455,9 +265,12 @@ fn classify_read(
}

/// Method to predict the strandedness of a read group.
fn predict_strandedness(rg_name: &str, counts: &Counts) -> ReadGroupDerivedStrandednessResult {
fn predict_strandedness(
rg_name: &str,
counts: &Counts,
) -> results::ReadGroupDerivedStrandednessResult {
if counts.forward == 0 && counts.reverse == 0 {
return ReadGroupDerivedStrandednessResult {
return results::ReadGroupDerivedStrandednessResult {
read_group: rg_name.to_string(),
succeeded: false,
strandedness: "Inconclusive".to_string(),
Expand All @@ -468,7 +281,7 @@ fn predict_strandedness(rg_name: &str, counts: &Counts) -> ReadGroupDerivedStran
reverse_pct: 0.0,
};
}
let mut result = ReadGroupDerivedStrandednessResult::new(
let mut result = results::ReadGroupDerivedStrandednessResult::new(
rg_name.to_string(),
false,
"Inconclusive".to_string(),
Expand Down Expand Up @@ -500,8 +313,8 @@ pub fn predict(
exons: &HashMap<&str, Lapper<usize, gff::record::Strand>>,
all_counts: &mut AllReadGroupsCounts,
params: &StrandednessParams,
metrics: &mut RecordTracker,
) -> Result<DerivedStrandednessResult, anyhow::Error> {
metrics: &mut results::RecordTracker,
) -> Result<results::DerivedStrandednessResult, anyhow::Error> {
let mut rng = rand::thread_rng();
let mut num_tested_genes: usize = 0; // Local to this attempt
let genes_remaining = gene_records.len();
Expand Down Expand Up @@ -578,7 +391,7 @@ pub fn predict(
}

let overall_result = predict_strandedness("overall", &overall_counts);
let final_result = DerivedStrandednessResult::new(
let final_result = results::DerivedStrandednessResult::new(
overall_result.succeeded,
overall_result.strandedness,
overall_result.forward,
Expand Down Expand Up @@ -650,7 +463,7 @@ mod tests {
counts: HashMap::new(),
found_rgs: HashSet::new(),
};
let mut read_metrics = ReadRecordMetrics::default();
let mut read_metrics = results::ReadRecordMetrics::default();
let counts_key = Arc::new("rg1".to_string());
let rg_tag = sam::record::data::field::Value::String("rg1".to_string());

Expand Down
Loading

0 comments on commit 7f95d34

Please sign in to comment.