diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 88268a3..746b5dc 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -52,8 +52,8 @@ def bench_numpy_in_slices(state: google_benchmark.State): state: the benchmark state for this execution run. """ random_seed = 1 - one_cap = False - cfp = carfac_np.design_carfac(one_cap=one_cap) + ihc_style = 'two_cap' + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False @@ -72,7 +72,7 @@ def bench_numpy_in_slices(state: google_benchmark.State): np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR ) run_seg_slices = np.array_split(run_seg_input_full, split_count) - cfp = carfac_np.design_carfac(one_cap=one_cap) + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False state.resume_timing() @@ -109,8 +109,8 @@ def bench_numpy(state: google_benchmark.State): state: the benchmark state for this execution run. """ random_seed = 1 - one_cap = False - cfp = carfac_np.design_carfac(one_cap=one_cap) + ihc_style = 'two_cap' + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False @@ -123,7 +123,7 @@ def bench_numpy(state: google_benchmark.State): run_seg_input = ( np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR ) - cfp = carfac_np.design_carfac(one_cap=one_cap) + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False state.resume_timing()