From d2a5ccd0e8fcd67313778b624f4c9681912311c6 Mon Sep 17 00:00:00 2001 From: Rob Schonberger Date: Wed, 9 Oct 2024 16:10:26 -0700 Subject: [PATCH] Fix a forgotten change in carfac_bench to carfac_np.design_carfac, which was using the one_cap argument and broken. This updates to use the ihc_style string as elsewhere in the codebase and fixes the benchmark. PiperOrigin-RevId: 684205274 --- python/jax/carfac_bench.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 88268a3..746b5dc 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -52,8 +52,8 @@ def bench_numpy_in_slices(state: google_benchmark.State): state: the benchmark state for this execution run. """ random_seed = 1 - one_cap = False - cfp = carfac_np.design_carfac(one_cap=one_cap) + ihc_style = 'two_cap' + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False @@ -72,7 +72,7 @@ def bench_numpy_in_slices(state: google_benchmark.State): np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR ) run_seg_slices = np.array_split(run_seg_input_full, split_count) - cfp = carfac_np.design_carfac(one_cap=one_cap) + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False state.resume_timing() @@ -109,8 +109,8 @@ def bench_numpy(state: google_benchmark.State): state: the benchmark state for this execution run. """ random_seed = 1 - one_cap = False - cfp = carfac_np.design_carfac(one_cap=one_cap) + ihc_style = 'two_cap' + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False @@ -123,7 +123,7 @@ def bench_numpy(state: google_benchmark.State): run_seg_input = ( np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR ) - cfp = carfac_np.design_carfac(one_cap=one_cap) + cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False state.resume_timing()