From f52345c429deab7df4b7f79379ddbaa28dca0a74 Mon Sep 17 00:00:00 2001 From: Jiahao Yuan Date: Sat, 1 Jun 2024 19:58:09 +0800 Subject: [PATCH] perf: specialized iir filter (#174) --- Cargo.lock | 18 +-- Cargo.toml | 1 - src/pulse.rs | 32 +---- src/pulse/iir.rs | 315 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 319 insertions(+), 47 deletions(-) create mode 100644 src/pulse/iir.rs diff --git a/Cargo.lock b/Cargo.lock index ea61b00..dc426bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,15 +38,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" -[[package]] -name = "biquad" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "820524f5e3e3add696ddf69f79575772e152c0e78e9f0370b56990a7e808ec3e" -dependencies = [ - "libm 0.1.4", -] - [[package]] name = "bitflags" version = "1.3.2" @@ -58,7 +49,6 @@ name = "bosing" version = "0.0.0-dev" dependencies = [ "anyhow", - "biquad", "bspline", "cached", "float-cmp", @@ -284,12 +274,6 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" -[[package]] -name = "libm" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fc7aa29613bd6a620df431842069224d8bc9011086b1db4c0e0cd47fa03ec9a" - [[package]] name = "libm" version = "0.2.8" @@ -543,7 +527,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "762ad20f6b65f5a33874de3b7506d6cf6631d37fb9c0c6ee0e310d7210a167d9" dependencies = [ "bytemuck", - "libm 0.2.8", + "libm", "num-complex", "reborrow", ] diff --git a/Cargo.toml b/Cargo.toml index 955e390..2097fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0.82" -biquad = "0.4.2" bspline = "1.1.0" cached = "0.50.0" float-cmp = "0.9.0" diff --git a/src/pulse.rs b/src/pulse.rs index db84e13..20f2e24 100644 --- a/src/pulse.rs +++ b/src/pulse.rs @@ -1,10 +1,11 @@ +mod iir; + use std::{ ops::{Add, Mul}, sync::Arc, }; use anyhow::{bail, Context, Result}; -use biquad::Biquad as _; use cached::proc_macro::cached; use float_cmp::approx_eq; use hashbrown::HashMap; @@ -480,34 +481,7 @@ pub(crate) fn apply_offset_inplace(waveform: &mut ArrayViewMut2, offset: Ar } pub(crate) fn apply_iir_inplace(waveform: &mut ArrayViewMut2, sos: ArrayView2) { - let mut biquads: Vec<_> = sos - .axis_iter(Axis(0)) - .map(|row| { - let b0 = row[0]; - let b1 = row[1]; - let b2 = row[2]; - let a1 = row[4]; - let a2 = row[5]; - let coef = biquad::Coefficients { b0, a1, a2, b1, b2 }; - biquad::DirectForm2Transposed::::new(coef) - }) - .collect(); - for mut row in waveform.axis_iter_mut(Axis(0)) { - apply_iir_inplace_1d(row.as_slice_mut().unwrap(), &mut biquads); - } -} - -fn apply_iir_inplace_1d(waveform: &mut [f64], biquads: &mut [biquad::DirectForm2Transposed]) { - for biquad in biquads.iter_mut() { - biquad.reset_state(); - } - for y in waveform.iter_mut() { - let mut x = *y; - for biquad in biquads.iter_mut() { - x = biquad.run(x); - } - *y = x; - } + self::iir::iir_filter_inplace(waveform.view_mut(), sos).unwrap() } pub(crate) fn apply_fir_inplace(waveform: &mut ArrayViewMut2, taps: ArrayView1) { diff --git a/src/pulse/iir.rs b/src/pulse/iir.rs new file mode 100644 index 0000000..71ca2ce --- /dev/null +++ b/src/pulse/iir.rs @@ -0,0 +1,315 @@ +use std::{ + array, + ops::{Add, Mul, Sub}, +}; + +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut2}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub(crate) enum Error { + #[error("Invalid SOS format")] + InvalidSosFormat, +} + +type Result = std::result::Result; + +#[derive(Debug, Clone, Copy)] +struct BiquadCoefficients { + b0: T, + b1: T, + b2: T, + a1: T, + a2: T, +} + +#[derive(Debug)] +struct Biquad { + coefficients: BiquadCoefficients, + s1: T, + s2: T, +} + +#[derive(Debug)] +struct Iir { + biquads: Vec>, +} + +#[derive(Debug)] +struct IirPipeline { + b0: [T; N], + b1: [T; N], + b2: [T; N], + a1: [T; N], + a2: [T; N], + s1: [T; N], + s2: [T; N], + y: [T; N], +} + +impl Biquad { + fn new(coefficients: BiquadCoefficients) -> Self { + Self { + coefficients, + s1: Default::default(), + s2: Default::default(), + } + } + + fn reset(&mut self) { + self.s1 = Default::default(); + self.s2 = Default::default(); + } +} + +impl Biquad +where + T: Add + Mul + Sub + Copy, +{ + fn run(&mut self, x: T) -> T { + let y = self.coefficients.b0 * x + self.s1; + self.s1 = self.coefficients.b1 * x - self.coefficients.a1 * y + self.s2; + self.s2 = self.coefficients.b2 * x - self.coefficients.a2 * y; + y + } +} + +impl Iir { + fn reset(&mut self) { + for biquad in &mut self.biquads { + biquad.reset(); + } + } +} + +impl Iir +where + T: Add + Mul + Sub + Copy, +{ + fn run(&mut self, x: T) -> T { + let mut y = x; + for biquad in &mut self.biquads { + y = biquad.run(y); + } + y + } + + fn filter_inplace(&mut self, x: &mut [T]) { + for x in x.iter_mut() { + *x = self.run(*x); + } + } +} + +impl IirPipeline +where + [T; N]: Default, +{ + fn reset(&mut self) { + self.s1 = Default::default(); + self.s2 = Default::default(); + self.y = Default::default(); + } +} + +impl IirPipeline +where + T: Add + Mul + Sub + Copy + Default, +{ + fn run(&mut self, x: T) -> T { + let res = self.y[N - 1]; + for i in (0..N).rev() { + let x = if i == 0 { x } else { self.y[i - 1] }; + let y = self.b0[i] * x + self.s1[i]; + self.s1[i] = self.b1[i] * x - self.a1[i] * y + self.s2[i]; + self.s2[i] = self.b2[i] * x - self.a2[i] * y; + self.y[i] = y; + } + res + } + + fn filter_inplace(&mut self, signal: &mut [T]) { + for i in 0..signal.len() + N { + let x = if i < signal.len() { + signal[i] + } else { + Default::default() + }; + let y = self.run(x); + if i >= N { + signal[i - N] = y; + } + } + } +} + +impl<'a, T: Copy> TryFrom> for BiquadCoefficients { + type Error = Error; + + fn try_from(value: ArrayView1<'a, T>) -> Result { + if value.dim() != 6 { + return Err(Error::InvalidSosFormat); + } + Ok(Self { + b0: value[0], + b1: value[1], + b2: value[2], + a1: value[4], + a2: value[5], + }) + } +} + +impl<'a, T: Copy + Default> TryFrom> for Biquad { + type Error = Error; + + fn try_from(value: ArrayView1<'a, T>) -> Result { + Ok(Self::new(value.try_into()?)) + } +} + +impl<'a, T: Copy + Default> TryFrom> for Iir { + type Error = Error; + + fn try_from(value: ArrayView2<'a, T>) -> Result { + let biquads = value + .outer_iter() + .map(Biquad::try_from) + .collect::>>()?; + Ok(Self { biquads }) + } +} + +impl<'a, T, const N: usize> TryFrom> for IirPipeline +where + T: Copy + Default, + [T; N]: Default, +{ + type Error = Error; + + fn try_from(value: ArrayView2<'a, T>) -> Result { + if value.dim().0 != N { + panic!("N should be equal to the number of biquads in the pipeline"); + } + if value.dim().1 != 6 { + return Err(Error::InvalidSosFormat); + } + let b0 = array::from_fn(|i| value[(i, 0)]); + let b1 = array::from_fn(|i| value[(i, 1)]); + let b2 = array::from_fn(|i| value[(i, 2)]); + let a1 = array::from_fn(|i| value[(i, 4)]); + let a2 = array::from_fn(|i| value[(i, 5)]); + Ok(Self { + b0, + b1, + b2, + a1, + a2, + s1: Default::default(), + s2: Default::default(), + y: Default::default(), + }) + } +} + +pub(crate) fn iir_filter_inplace(signal: ArrayViewMut2, sos: ArrayView2) -> Result<()> +where + T: Add + Mul + Sub + Copy + Default, +{ + match sos.dim().0 { + 0 => Ok(()), + 1 => specialized_filter::(signal, sos), + 2 => specialized_filter::(signal, sos), + 3 => specialized_filter::(signal, sos), + 4 => specialized_filter::(signal, sos), + _ => fallback_filter(signal, sos), + } +} + +fn specialized_filter( + mut signal: ArrayViewMut2, + sos: ArrayView2, +) -> Result<()> +where + T: Add + Mul + Sub + Copy + Default, + [T; N]: Default, +{ + let mut iir: IirPipeline = sos.try_into()?; + for mut row in signal.outer_iter_mut() { + let row = row.as_slice_mut().expect("Row should be contiguous"); + iir.reset(); + iir.filter_inplace(row); + } + Ok(()) +} + +fn fallback_filter(mut signal: ArrayViewMut2, sos: ArrayView2) -> Result<()> +where + T: Add + Mul + Sub + Copy + Default, +{ + let mut iir: Iir = sos.try_into()?; + for mut row in signal.outer_iter_mut() { + let row = row.as_slice_mut().expect("Row should be contiguous"); + iir.reset(); + iir.filter_inplace(row); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use ndarray::Array2; + use numpy::array; + + use super::*; + + fn get_test_case() -> (Array2, Array2, Array2) { + // Generated using scipy.signal.sosfilt + let signal = Array2::ones((1, 10)); + let sos = array![ + [ + 0.41745636685930126, + -0.8124669318632639, + 0.39511212971824017, + 1.0, + -1.9517167839624654, + 0.9518873960332694 + ], + [ + 1.0, + -1.998900504749875, + 0.9989006046949026, + 1.0, + -1.9990955877237742, + 0.9990956472206675 + ], + ]; + let expected = array![[ + 0.41745636685930126, + 0.419827471396848, + 0.4221187550258305, + 0.4243336840250656, + 0.4264755709173414, + 0.42854758130187326, + 0.43055274038309926, + 0.4324939392093115, + 0.4343739406340189, + 0.4361953850123652 + ]]; + (signal, sos, expected) + } + + #[test] + fn test_specialized_filter() { + let (mut signal, sos, expected) = get_test_case(); + specialized_filter::<_, 2>(signal.view_mut(), sos.view()).unwrap(); + assert_eq!(signal, expected); + } + + #[test] + fn test_fallback_filter() { + let (mut signal, sos, expected) = get_test_case(); + fallback_filter(signal.view_mut(), sos.view()).unwrap(); + assert_eq!(signal, expected); + } +}