diff --git a/python/jax/carfac.py b/python/jax/carfac.py index a3a140d..984fefa 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -517,8 +517,7 @@ def tree_unflatten(cls, _, children): @dataclasses.dataclass class IhcDesignParameters: """Variables needed for the inner hair cell implementation.""" - just_hwr: bool = False - n_caps: int = 2 + ihc_style: str = 'two_cap' tau_lpf: float = 0.000080 # 80 microseconds smoothing twice tau_out: float = 0.0005 # depletion tau is pretty fast tau_in: float = 0.010 # recovery tau is slower @@ -530,24 +529,26 @@ class IhcDesignParameters: # The following 2 functions are boiler code required by pytree. # Reference: https://jax.readthedocs.io/en/latest/pytrees.html def tree_flatten(self): # pylint: disable=missing-function-docstring - children = (self.just_hwr, - self.n_caps, - self.tau_lpf, - self.tau_out, - self.tau_in, - self.tau1_out, - self.tau1_in, - self.tau2_out, - self.tau2_in) - aux_data = ('just_hwr', - 'n_caps', - 'tau_lpf', - 'tau_out', - 'tau_in', - 'tau1_out', - 'tau1_in', - 'tau2_out', - 'tau2_in') + children = ( + self.ihc_style, + self.tau_lpf, + self.tau_out, + self.tau_in, + self.tau1_out, + self.tau1_in, + self.tau2_out, + self.tau2_in, + ) + aux_data = ( + 'ihc_style', + 'tau_lpf', + 'tau_out', + 'tau_in', + 'tau1_out', + 'tau1_in', + 'tau2_out', + 'tau2_in', + ) return (children, aux_data) @classmethod @@ -560,14 +561,14 @@ def tree_unflatten(cls, _, children): class IhcHypers: """Hyperparameters for the inner hair cell. Tagged `static` in `jax.jit`.""" n_ch: int - just_hwr: bool - n_caps: int + # 0 is just_hwr, 1 is one_cap, 2 is two_cap. + ihc_style: int # The following 2 functions are boiler code required by pytree. # Reference: https://jax.readthedocs.io/en/latest/pytrees.html def tree_flatten(self): - children = (self.n_ch, self.just_hwr, self.n_caps) - aux_data = ('n_ch', 'just_hwr', 'n_caps') + children = (self.n_ch, self.ihc_style) + aux_data = ('n_ch', 'ihc_style') return (children, aux_data) @classmethod @@ -1130,13 +1131,20 @@ def design_and_init_ihc( ihc_params = ear_params.ihc n_ch = ear_hypers.n_ch - ihc_hypers = IhcHypers( - n_ch=n_ch, just_hwr=ihc_params.just_hwr, n_caps=ihc_params.n_caps - ) - if ihc_params.just_hwr: + ihc_style_num = 0 + if ihc_params.ihc_style == 'just_hwr': + ihc_style_num = 0 + elif ihc_params.ihc_style == 'one_cap': + ihc_style_num = 1 + elif ihc_params.ihc_style == 'two_cap': + ihc_style_num = 2 + else: + raise NotImplementedError + ihc_hypers = IhcHypers(n_ch=n_ch, ihc_style=ihc_style_num) + if ihc_params.ihc_style == 'just_hwr': ihc_weights = IhcWeights() ihc_state = IhcState(ihc_accum=jnp.zeros((n_ch,))) - elif ihc_params.n_caps == 1: + elif ihc_params.ihc_style == 'one_cap': ro = 1 / ihc_detect(10) # output resistance at a very high level c = ihc_params.tau_out / ro ri = ihc_params.tau_in / c @@ -1159,7 +1167,7 @@ def design_and_init_ihc( lpf1_state=ihc_weights.rest_output * jnp.ones((n_ch,)), lpf2_state=ihc_weights.rest_output * jnp.ones((n_ch,)), ) - elif ihc_params.n_caps == 2: + elif ihc_params.ihc_style == 'two_cap': g1_max = ihc_detect(10) # receptor conductance at high level r1min = 1 / g1_max @@ -1631,13 +1639,13 @@ def ihc_step( ihc_weights = weights.ears[ear].ihc ihc_hypers = hypers.ears[ear].ihc - if ihc_hypers.just_hwr: + if ihc_hypers.ihc_style == 0: ihc_out = jnp.min(2, jnp.max(0, bm_out)) # pytype: disable=wrong-arg-types # jnp-type # limit it for stability else: conductance = ihc_detect(bm_out) # rectifying nonlinearity - if ihc_hypers.n_caps == 1: + if ihc_hypers.ihc_style == 1: ihc_out = conductance * ihc_state.cap_voltage ihc_state.cap_voltage = ( ihc_state.cap_voltage diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 4bf45c0..88268a3 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -143,10 +143,10 @@ def bench_jax_grad(state: google_benchmark.State): Args: state: The Benchmark state for this run. """ - one_cap = False + ihc_style = 'two_cap' random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False random_generator = jax.random.PRNGKey(random_seed) n_samp = state.range(0) @@ -202,10 +202,10 @@ def bench_jit_compile_time(state: google_benchmark.State): Args: state: The benchmark state to execute over. """ - one_cap = False + ihc_style = 'two_cap' random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False random_generator = jax.random.PRNGKey(random_seed) n_samp = 1 @@ -251,10 +251,10 @@ def bench_jax_in_slices(state: google_benchmark.State): state: the benchmark state for this execution run. """ # Inits JAX version - one_cap = False + ihc_style = 'two_cap' random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False # Generate some random inputs. @@ -335,11 +335,11 @@ def bench_jax(state: google_benchmark.State): state: the benchmark state for this execution run. """ # Inits JAX version - one_cap = False + ihc_style = 'two_cap' random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() params_jax.ears[0].car.use_delay_buffer = state.range(2) - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False # Generate some random inputs. @@ -393,10 +393,10 @@ def bench_jax_util_mapped(state: google_benchmark.State): """ if jax.device_count() < state.range(0): state.skip_with_error(f'requires {state.range(0)} devices') - one_cap = False random_seed = state.range(0) + ihc_style = 'two_cap' params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False random_generator = jax.random.PRNGKey(random_seed) hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index e0b213f..1a406e5 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -35,9 +35,11 @@ def _assert_almost_equal_pytrees(self, pytree1, pytree2, delta=None): self.assertSequenceAlmostEqual(elements1, elements2, delta=delta) @parameterized.product( - random_seed=[x for x in range(20)], one_cap=[False, True], n_ears=[1, 2] + random_seed=[x for x in range(20)], + ihc_style=['one_cap', 'two_cap'], + n_ears=[1, 2], ) - def test_backward_pass(self, random_seed, one_cap, n_ears): + def test_backward_pass(self, random_seed, ihc_style, n_ears): # Tests `jax.grad` can give similar gradients computed by numeric method. @functools.partial(jax.jit, static_argnames=('hypers',)) def loss(weights, input_waves, hypers, state): @@ -66,7 +68,7 @@ def loss(weights, input_waves, hypers, state): # Computes gradients by `jax.grad`. gfunc = jax.grad(loss, has_aux=True) params_jax = carfac_jax.CarfacDesignParameters(n_ears=n_ears) - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( params_jax diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py index e85108e..34b552b 100644 --- a/python/jax/carfac_test.py +++ b/python/jax/carfac_test.py @@ -29,9 +29,7 @@ def test_hypers_hash(self): hypers.ears[0].car = carfac_jax.CarHypers() hypers.ears[0].agc = [carfac_jax.AgcHypers(n_ch=1, n_agc_stages=2), carfac_jax.AgcHypers(n_ch=1, n_agc_stages=2)] - hypers.ears[0].ihc = carfac_jax.IhcHypers(n_ch=1, - just_hwr=True, - n_caps=1) + hypers.ears[0].ihc = carfac_jax.IhcHypers(n_ch=1, ihc_style=1) h1 = hash(hypers) hypers.ears[0].car.n_ch += 1 h2 = hash(hypers) @@ -39,7 +37,7 @@ def test_hypers_hash(self): hypers.ears[0].agc[1].reverse_cumulative_decimation += 1 h3 = hash(hypers) self.assertNotEqual(h2, h3) - hypers.ears[0].ihc.just_hwr = not hypers.ears[0].ihc.just_hwr + hypers.ears[0].ihc.ihc_style = 2 h4 = hash(hypers) self.assertNotEqual(h3, h4) @@ -110,15 +108,15 @@ def container_comparison(self, left_side, right_side, exclude_keys=None): msg='failed comparison on key item %s' % (k), ) - @parameterized.parameters([1, 2]) - def test_equal_design(self, n_caps): + @parameterized.parameters(['one_cap', 'two_cap']) + def test_equal_design(self, ihc_style): # Test: the designs are similar. - cfp = carfac_np.design_carfac(one_cap=(n_caps == 1)) + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = n_caps + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( params_jax @@ -169,11 +167,11 @@ def test_equal_design(self, n_caps): self.container_comparison( hypers_jax.ears[ear_idx].ihc, ear_params_np.ihc_coeffs, - exclude_keys={'n_caps'}, + exclude_keys={'ihc_style'}, ) self.assertEqual( - ear_params_np.ihc_coeffs.one_cap, - hypers_jax.ears[ear_idx].ihc.n_caps == 1, + ear_params_np.ihc_coeffs.ihc_style, + hypers_jax.ears[ear_idx].ihc.ihc_style, ) self.container_comparison( @@ -182,7 +180,7 @@ def test_equal_design(self, n_caps): exclude_keys='lpf2_state', ) - if ear_params_np.ihc_coeffs.one_cap: + if ear_params_np.ihc_coeffs.ihc_style == 1: self.assertSequenceAlmostEqual( state_jax.ears[ear_idx].ihc.lpf2_state, ear_params_np.ihc_state.lpf2_state, @@ -195,11 +193,7 @@ def test_equal_design(self, n_caps): # now we only check these one by one. We could add tests for 2 cap # similarly. self.assertEqual( - cfp.ihc_params.one_cap, params_jax.ears[ear_idx].ihc.n_caps == 1 - ) - - self.assertEqual( - cfp.ihc_params.just_hwr, params_jax.ears[ear_idx].ihc.just_hwr + cfp.ihc_params.ihc_style, params_jax.ears[ear_idx].ihc.ihc_style ) self.assertEqual( @@ -250,13 +244,14 @@ def test_equal_design(self, n_caps): ) @parameterized.product( - random_seed=[x for x in range(5)], one_cap=[False, True] + random_seed=[x for x in range(5)], + ihc_style=['one_cap', 'two_cap'], ) - def test_chunked_naps_same_as_jit(self, random_seed, one_cap): + def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version params_jax = carfac_jax.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( params_jax @@ -294,11 +289,13 @@ def test_chunked_naps_same_as_jit(self, random_seed, one_cap): @parameterized.product( random_seed=[x for x in range(20)], - one_cap=[False, True], + ihc_style=['one_cap', 'two_cap'], n_ears=[1, 2], delay_buffer=[False, True], ) - def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer): + def test_equal_forward_pass( + self, random_seed, ihc_style, n_ears, delay_buffer + ): """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version params_jax = carfac_jax.CarfacDesignParameters( @@ -306,14 +303,14 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer): ) params_jax.n_ears = n_ears for ear in range(n_ears): - params_jax.ears[ear].ihc.n_caps = 1 if one_cap else 2 + params_jax.ears[ear].ihc.ihc_style = ihc_style params_jax.ears[ear].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( params_jax ) # Inits numpy version cfp = carfac_np.design_carfac( - one_cap=one_cap, n_ears=n_ears, use_delay_buffer=delay_buffer + ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer ) carfac_np.carfac_init(cfp) @@ -419,7 +416,7 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer): state_np.ears[ear].ihc_state.lpf1_state, delta=1e-3, # Low Precision ) - if cfp.ears[ear].ihc_coeffs.one_cap: + if cfp.ears[ear].ihc_coeffs.ihc_style == 1: self.assertSequenceAlmostEqual( state_jax.ears[ear].ihc.lpf2_state, state_np.ears[ear].ihc_state.lpf2_state, @@ -430,7 +427,7 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer): state_np.ears[ear].ihc_state.cap_voltage, delta=2e-5, # Low Precision ) - else: + elif cfp.ears[ear].ihc_coeffs.ihc_style == 2: # `state_np` won't have `cap1_voltage` or `cap2_voltage` if # `one_cap==True`. self.assertSequenceAlmostEqual( @@ -443,6 +440,8 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer): state_np.ears[ear].ihc_state.cap2_voltage, delta=1e-5, # Low Precision ) + else: + self.fail('Unsupported IHC style.') # Comapares agc state for stage in range(hypers_jax.ears[ear].agc[0].n_agc_stages): self.assertSequenceAlmostEqual( diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index 3f31cd3..753238b 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -25,11 +25,11 @@ class CarfacUtilTest(absltest.TestCase): def setUp(self): super().setUp() - self.one_cap = False + self.ihc_style = 'two_cap' self.random_seed = 17234 self.open_loop = False params_jax = carfac.CarfacDesignParameters() - params_jax.ears[0].ihc.n_caps = 1 if self.one_cap else 2 + params_jax.ears[0].ihc.ihc_style = self.ihc_style params_jax.ears[0].car.linear_car = False self.random_generator = jax.random.PRNGKey(self.random_seed) self.hypers, self.weights, self.init_state = carfac.design_and_init_carfac( diff --git a/python/np/carfac.py b/python/np/carfac.py index 750ead0..26331a9 100644 --- a/python/np/carfac.py +++ b/python/np/carfac.py @@ -375,13 +375,12 @@ def car_step(x_in: float, # TODO(malcolmslaney) Perhaps make one superclass? @dataclasses.dataclass class IhcJustHwrParams: - just_hwr: bool = True # just a simple HWR + ihc_style: str = 'just_hwr' @dataclasses.dataclass class IhcOneCapParams(IhcJustHwrParams): - just_hwr: bool = False # not just a simple HWR - one_cap: bool = True # bool; False for new two-cap hack + ihc_style: str = 'one_cap' tau_lpf: float = 0.000080 # 80 microseconds smoothing twice tau_out: float = 0.0005 # depletion tau is pretty fast tau_in: float = 0.010 # recovery tau is slower @@ -389,8 +388,7 @@ class IhcOneCapParams(IhcJustHwrParams): @dataclasses.dataclass class IhcTwoCapParams(IhcJustHwrParams): - just_hwr: bool = False # not just a simple HWR - one_cap: bool = False # bool; False for new two-cap hack + ihc_style: str = 'two_cap' tau_out: float = 0.0005 # depletion tau is pretty fast tau_in: float = 0.010 # recovery tau is slower tau_lpf: float = 0.000080 # 80 microseconds smoothing twice @@ -434,24 +432,27 @@ def ihc_detect(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: class IhcCoeffs: """Variables needed for the inner hair cell implementation.""" n_ch: int - just_hwr: bool lpf_coeff: float = 0 out1_rate: float = 0 in1_rate: float = 0 out2_rate: float = 0 in2_rate: float = 0 - one_cap: float = 0 output_gain: float = 0 rest_output: float = 0 rest_cap2: float = 0 rest_cap1: float = 0 - rest_cap: float = 0 out_rate: float = 0 in_rate: float = 0 + # 0 is just_hwr, 1 is one_cap, 2 is two_cap + ihc_style: int = 0 -def design_ihc(ihc_params: IhcJustHwrParams, fs: float, n_ch: int) -> IhcCoeffs: +def design_ihc( + ihc_params: IhcJustHwrParams | IhcOneCapParams | IhcTwoCapParams, + fs: float, + n_ch: int, +) -> IhcCoeffs: """Design the inner hair cell implementation from parameters. Args: @@ -462,9 +463,7 @@ def design_ihc(ihc_params: IhcJustHwrParams, fs: float, n_ch: int) -> IhcCoeffs: Returns: A IHC coefficient class. """ - if ihc_params.just_hwr: - ihc_coeffs = IhcCoeffs(n_ch=n_ch, just_hwr=True) - elif isinstance(ihc_params, IhcOneCapParams): + if isinstance(ihc_params, IhcOneCapParams): ro = 1 / ihc_detect(10) # output resistance at a very high level c = ihc_params.tau_out / ro ri = ihc_params.tau_in / c @@ -476,14 +475,14 @@ def design_ihc(ihc_params: IhcJustHwrParams, fs: float, n_ch: int) -> IhcCoeffs: cap_voltage = 1 - current * ri ihc_coeffs = IhcCoeffs( n_ch=n_ch, - just_hwr=False, + ihc_style=1, lpf_coeff=1 - math.exp(-1 / (ihc_params.tau_lpf * fs)), out_rate=ro / (ihc_params.tau_out * fs), in_rate=1 / (ihc_params.tau_in * fs), - one_cap=ihc_params.one_cap, output_gain=1 / (saturation_output - current), rest_output=current / (saturation_output - current), - rest_cap=cap_voltage) + rest_cap=cap_voltage, + ) elif isinstance(ihc_params, IhcTwoCapParams): g1_max = ihc_detect(10) # receptor conductance at high level @@ -514,18 +513,19 @@ def design_ihc(ihc_params: IhcJustHwrParams, fs: float, n_ch: int) -> IhcCoeffs: ihc_coeffs = IhcCoeffs( n_ch=n_ch, - just_hwr=False, + ihc_style=2, lpf_coeff=1 - math.exp(-1 / (ihc_params.tau_lpf * fs)), out1_rate=r1min / (ihc_params.tau1_out * fs), in1_rate=1 / (ihc_params.tau1_in * fs), out2_rate=r2min / (ihc_params.tau2_out * fs), in2_rate=1 / (ihc_params.tau2_in * fs), - one_cap=ihc_params.one_cap, output_gain=1 / (saturation_current2 - rest_current2), rest_output=rest_current2 / (saturation_current2 - rest_current2), rest_cap2=cap2_voltage, rest_cap1=cap1_voltage, ) + elif isinstance(ihc_params, IhcJustHwrParams): + ihc_coeffs = IhcCoeffs(n_ch=n_ch, ihc_style=0) else: raise NotImplementedError return ihc_coeffs @@ -556,19 +556,18 @@ class IhcState: def __init__(self, coeffs, dtype=np.float32): n_ch = coeffs.n_ch - if coeffs.just_hwr: + if coeffs.ihc_style == 0: self.ihc_accum = np.zeros((n_ch,), dtype=dtype) + elif coeffs.ihc_style == 1: + self.ihc_accum = np.zeros((n_ch,), dtype=dtype) + self.cap_voltage = coeffs.rest_cap * np.ones((n_ch,), dtype=dtype) + self.lpf1_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) + self.lpf2_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) else: - if coeffs.one_cap: - self.ihc_accum = np.zeros((n_ch,), dtype=dtype) - self.cap_voltage = coeffs.rest_cap * np.ones((n_ch,), dtype=dtype) - self.lpf1_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) - self.lpf2_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) - else: - self.ihc_accum = np.zeros((n_ch,), dtype=dtype) - self.cap1_voltage = coeffs.rest_cap1 * np.ones((n_ch,), dtype=dtype) - self.cap2_voltage = coeffs.rest_cap2 * np.ones((n_ch,), dtype=dtype) - self.lpf1_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) + self.ihc_accum = np.zeros((n_ch,), dtype=dtype) + self.cap1_voltage = coeffs.rest_cap1 * np.ones((n_ch,), dtype=dtype) + self.cap2_voltage = coeffs.rest_cap2 * np.ones((n_ch,), dtype=dtype) + self.lpf1_state = coeffs.rest_output * np.ones((n_ch,), dtype=dtype) def ihc_init_state(coeffs): @@ -614,13 +613,13 @@ def ihc_step(bm_out: np.ndarray, ihc_coeffs: IhcCoeffs, and the new state. """ - if ihc_coeffs.just_hwr: + if ihc_coeffs.ihc_style == 0: ihc_out = np.min(2, np.max(0, bm_out)) # limit it for stability else: conductance = ihc_detect(bm_out) # rectifying nonlinearity - if ihc_coeffs.one_cap: + if ihc_coeffs.ihc_style == 1: ihc_out = conductance * ihc_state.cap_voltage ihc_state.cap_voltage = ( ihc_state.cap_voltage - ihc_out * ihc_coeffs.out_rate + @@ -1025,8 +1024,7 @@ def design_carfac( ihc_params: Optional[ Union[IhcJustHwrParams, IhcOneCapParams, IhcTwoCapParams] ] = None, - one_cap: bool = False, - just_hwr: bool = False, + ihc_style: str = 'two_cap', use_delay_buffer: bool = False, ) -> CarfacParams: """This function designs the CARFAC filterbank. @@ -1051,8 +1049,9 @@ def design_carfac( car_params: bundles all the pole-zero filter cascade parameters agc_params: bundles all the automatic gain control parameters ihc_params: bundles all the inner hair cell parameters - one_cap: True for Allen model, as Lyon's book describes - just_hwr: False for normal/fancy IHC; True for HWR. + ihc_style: Type of IHC Model to use. Valid avlues are 'one_cap' for Allen + model, 'two_cap' for the v2 model, and 'just_hwr' for the simpler HWR + model. use_delay_buffer: Whether to use the delay buffer implementation in the CAR step. @@ -1064,14 +1063,15 @@ def design_carfac( agc_params = agc_params or AgcParams() if not ihc_params: - if just_hwr: - ihc_params = IhcJustHwrParams() - else: - if one_cap: - ihc_params = IhcOneCapParams() - else: + match ihc_style: + case 'two_cap': ihc_params = IhcTwoCapParams() - + case 'one_cap': + ihc_params = IhcOneCapParams() + case 'just_hwr': + ihc_params = IhcJustHwrParams() + case _: + raise ValueError(f'Unknown IHC style: {ihc_style}') # first figure out how many filter stages (PZFC/CARFAC channels): pole_hz = car_params.first_pole_theta * fs / (2 * math.pi) n_ch = 0 diff --git a/python/np/carfac_test.py b/python/np/carfac_test.py index 93d244c..4fc5521 100644 --- a/python/np/carfac_test.py +++ b/python/np/carfac_test.py @@ -220,7 +220,7 @@ def test_car_freq_response(self): self.assertAlmostEqual(bw, correct_bw, delta=0.1) self.assertAlmostEqual(q, correct_q) - def run_ihc(self, test_freq=300, one_cap=True): + def run_ihc(self, test_freq=300, ihc_style='one_cap'): fs = 40000 sampling_interval = 1 / fs tmax = 0.28 # a half second @@ -235,7 +235,7 @@ def run_ihc(self, test_freq=300, one_cap=True): amplitude = 0.09 * 2**stim_num omega = 2 * np.pi * test_freq - cfp = carfac.design_carfac(fs=fs, one_cap=one_cap) + cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style) cfp = carfac.carfac_init(cfp) quad_sin = np.sin(omega * t) * present @@ -253,7 +253,7 @@ def run_ihc(self, test_freq=300, one_cap=True): plt.plot(t, neuro_output) plt.xlabel('Seconds') plt.title(f'IHC Response for tone blips at {test_freq}Hz') - plt.savefig(f'/tmp/ihc_response_cap_{one_cap}cap_{test_freq}Hz.png') + plt.savefig(f'/tmp/ihc_response_cap_{ihc_style}_{test_freq}Hz.png') blip_maxes = [] blip_ac = [] for i in range(1, 7): @@ -270,7 +270,7 @@ def run_ihc(self, test_freq=300, one_cap=True): @parameterized.named_parameters( ( 'two_cap_300', - False, + 'two_cap', 300, [ [2.026682, 544.901381], @@ -283,7 +283,7 @@ def run_ihc(self, test_freq=300, one_cap=True): ), ( 'two_cap_3000', - False, + 'two_cap', 3000, [ [0.698303, 93.388172], @@ -296,7 +296,7 @@ def run_ihc(self, test_freq=300, one_cap=True): ), ( 'one_cap_300', - True, + 'one_cap', 300, [ [2.752913, 721.001685], @@ -309,7 +309,7 @@ def run_ihc(self, test_freq=300, one_cap=True): ), ( 'one_cap_3000', - True, + 'one_cap', 3000, [ [1.417657, 234.098558], @@ -322,7 +322,7 @@ def run_ihc(self, test_freq=300, one_cap=True): ), ) def test_ihc_param(self, cap, freq, test_results): - blip_maxes, blip_ac = self.run_ihc(freq, one_cap=cap) + blip_maxes, blip_ac = self.run_ihc(freq, ihc_style=cap) for i, (max_val, ac) in enumerate(test_results): self.assertAlmostEqual(blip_maxes[i], max_val, delta=max_val / 10000) self.assertAlmostEqual(blip_ac[i], ac, delta=ac / 10000) @@ -428,8 +428,10 @@ def test_stage_g_calculation(self): f'Failed at channel {ch} for undamping {undamping}.', ) - @parameterized.named_parameters(('two_cap', False), ('one_cap', True)) - def test_whole_carfac(self, cap): + @parameterized.named_parameters( + ('two_cap', 'two_cap'), ('one_cap', 'one_cap') + ) + def test_whole_carfac(self, ihc_style): # Test: Make sure that the AGC adapts to a tone. Test with open-loop impulse # response. @@ -442,7 +444,7 @@ def test_whole_carfac(self, cap): impulse = np.zeros(t.shape) impulse[0] = 1e-4 - cfp = carfac.design_carfac(fs=fs, one_cap=cap) + cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style) cfp = carfac.carfac_init(cfp) _, cfp, bm_initial, _, _ = carfac.run_segment( @@ -502,7 +504,7 @@ def test_whole_carfac(self, cap): 65: (7.765176633256488e-06, 7.573388744425412e-06), 70: (5.994126581754244e-07, 5.919053135128626e-07), } - if not cap: + if ihc_style == 'two_cap': # The following data comes from the Numpy implementation max_expected_responses = { # By channel, pre and post adaptation 0: (9.487948409514502e-05, 9.489925609401865e-05), @@ -588,7 +590,7 @@ def find_closest_channel(cfs: List[float], desired: float) -> np.ndarray: return np.argmin((np.asarray(cfs) - desired)**2) results = {} - if cap: + if ihc_style == 'one_cap': results = { # The Matlab test prints this data block: 125: [64, 119.007, 0.264], 250: [58, 239.791, 0.986], @@ -745,7 +747,7 @@ def test_multiaural_carfac(self): two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = noise two_chan_noise[:, 1] = noise - cfp = carfac.design_carfac(fs=fs, n_ears=2, one_cap=True) + cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) naps, _, _, _, _ = carfac.run_segment(cfp, two_chan_noise) max_abs_diff = np.amax(np.abs(naps[:, :, 0] - naps[:, :, 1])) @@ -782,9 +784,9 @@ def test_multiaural_carfac_with_silent_channel(self): two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = c_major_chord # Leave the audio in channel 1 as silence. - cfp = carfac.design_carfac(fs=fs, n_ears=2, one_cap=True) + cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) - mono_cfp = carfac.design_carfac(fs=fs, n_ears=1, one_cap=True) + mono_cfp = carfac.design_carfac(fs=fs, n_ears=1, ihc_style='one_cap') mono_cfp = carfac.carfac_init(mono_cfp) _, _, bm_binaural, _, _ = carfac.run_segment(cfp, two_chan_noise)