diff --git a/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_main.py b/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_main.py
index 50af364539cd..eda2d0f932aa 100644
--- a/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_main.py
+++ b/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_main.py
@@ -83,6 +83,9 @@ flags.DEFINE_bool(
     'split_embeddings_into_separate_tables', False,
     'If true, write each embedding to a separate table.')
+flags.DEFINE_bool(
+    'use_frontend_fn', False,
+    'If `true`, call frontend fn on audio before passing to the model.')
 flags.DEFINE_bool('debug', False, 'If True, run in debug model.')
 
 FLAGS = flags.FLAGS
@@ -129,6 +132,7 @@ def main(unused_argv):
         FLAGS.delete_audio_from_output,
         output_filename,
         split_embeddings_into_separate_tables=FLAGS.split_embeddings_into_separate_tables,  # pylint:disable=line-too-long
+        use_frontend_fn=FLAGS.use_frontend_fn,
         input_format=input_format,
         output_format=output_format,
         suffix=i)
diff --git a/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_utils.py b/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_utils.py
index 1770807ca717..1792e0480fe5 100644
--- a/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_utils.py
+++ b/non_semantic_speech_benchmark/data_prep/audio_to_embeddings_beam_utils.py
@@ -37,6 +37,7 @@ import tensorflow_datasets as tfds
 import tensorflow_hub as hub
 
 from non_semantic_speech_benchmark import file_utils
+from non_semantic_speech_benchmark.export_model import tf_frontend
 
 
 def _tfexample_audio_to_npfloat32(ex, audio_key):
@@ -71,6 +72,12 @@ def _build_tflite_interpreter(tflite_model_path):
   return interpreter
 
 
+def _default_feature_fn(x, s):
+  return tf.expand_dims(
+      tf_frontend.compute_frontend_features(x, s, overlap_seconds=79),
+      axis=-1).numpy().astype(np.float32)
+
+
 def _samples_to_embedding_tflite(model_input, sample_rate, interpreter,
                                  output_key):
   """Run TFLite inference to map audio samples to an embedding."""
@@ -391,7 +398,8 @@ def make_beam_pipeline(
     embedding_modules, module_output_keys, audio_key, sample_rate_key,
     label_key, speaker_id_key, average_over_time, delete_audio_from_output,
     output_filename, split_embeddings_into_separate_tables=False,
-    input_format='tfrecord', output_format='tfrecord', suffix='Main'):
+    use_frontend_fn=False, input_format='tfrecord', output_format='tfrecord',
+    suffix='Main'):
   """Construct beam pipeline for mapping from audio to embeddings.
 
   Args:
@@ -414,6 +422,8 @@ def make_beam_pipeline(
     output_filename: Python string. Output filename.
     split_embeddings_into_separate_tables: Python bool. If true, write each
       embedding to a separate table.
+    use_frontend_fn: If `true`, call frontend fn on audio before passing to the
+      model.
     input_format: Python string. Must correspond to a function in
       `reader_functions`.
     output_format: Python string. Must correspond to a function
@@ -448,7 +458,8 @@ def make_beam_pipeline(
             audio_key=audio_key,
             sample_rate_key=sample_rate_key,
             sample_rate=sample_rate,
-            average_over_time=average_over_time))
+            average_over_time=average_over_time,
+            feature_fn=_default_feature_fn if use_frontend_fn else None))
     embedding_tables[name] = tbl
   assert tf_examples_key_ not in embedding_tables
   embedding_tables[tf_examples_key_] = input_examples
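
Note (not part of the patch): a minimal sketch of how the new _default_feature_fn could be exercised on its own, assuming TF2 eager execution. The 16 kHz sample rate, the random input, and the expected trailing channel axis are illustrative assumptions; tf_frontend.compute_frontend_features itself is not shown in this diff.

    import numpy as np

    from non_semantic_speech_benchmark.data_prep import audio_to_embeddings_beam_utils

    # One second of fake audio at an assumed 16 kHz sample rate.
    fake_audio = np.random.uniform(-1.0, 1.0, size=16000).astype(np.float32)

    # Per the diff, _default_feature_fn wraps tf_frontend.compute_frontend_features
    # and returns float32 features with a trailing channel axis added.
    features = audio_to_embeddings_beam_utils._default_feature_fn(fake_audio, 16000)
    print(features.dtype, features.shape)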