Skip to content

Commit

Permalink
Support custom checkpoint to disk.
Browse files Browse the repository at this point in the history
Also, pipe through option to use default frontend fn for data prep.

PiperOrigin-RevId: 370953785
  • Loading branch information
joel-shor authored and copybara-github committed Apr 28, 2021
1 parent fa8702c commit 29148dc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
flags.DEFINE_bool(
'split_embeddings_into_separate_tables', False,
'If true, write each embedding to a separate table.')
flags.DEFINE_bool(
'use_frontend_fn', False,
'If `true`, call frontend fn on audio before passing to the model.')
flags.DEFINE_bool('debug', False, 'If True, run in debug model.')

FLAGS = flags.FLAGS
Expand Down Expand Up @@ -129,6 +132,7 @@ def main(unused_argv):
FLAGS.delete_audio_from_output,
output_filename,
split_embeddings_into_separate_tables=FLAGS.split_embeddings_into_separate_tables, # pylint:disable=line-too-long
use_frontend_fn=FLAGS.use_frontend_fn,
input_format=input_format,
output_format=output_format,
suffix=i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from non_semantic_speech_benchmark import file_utils
from non_semantic_speech_benchmark.export_model import tf_frontend


def _tfexample_audio_to_npfloat32(ex, audio_key):
Expand Down Expand Up @@ -71,6 +72,12 @@ def _build_tflite_interpreter(tflite_model_path):
return interpreter


def _default_feature_fn(x, s):
return tf.expand_dims(
tf_frontend.compute_frontend_features(x, s, overlap_seconds=79),
axis=-1).numpy().astype(np.float32)


def _samples_to_embedding_tflite(model_input, sample_rate, interpreter,
output_key):
"""Run TFLite inference to map audio samples to an embedding."""
Expand Down Expand Up @@ -391,7 +398,8 @@ def make_beam_pipeline(
embedding_modules, module_output_keys, audio_key, sample_rate_key,
label_key, speaker_id_key, average_over_time, delete_audio_from_output,
output_filename, split_embeddings_into_separate_tables=False,
input_format='tfrecord', output_format='tfrecord', suffix='Main'):
use_frontend_fn=False, input_format='tfrecord', output_format='tfrecord',
suffix='Main'):
"""Construct beam pipeline for mapping from audio to embeddings.
Args:
Expand All @@ -414,6 +422,8 @@ def make_beam_pipeline(
output_filename: Python string. Output filename.
split_embeddings_into_separate_tables: Python bool. If true, write each
embedding to a separate table.
use_frontend_fn: If `true`, call frontend fn on audio before passing to the
model.
input_format: Python string. Must correspond to a function in
`reader_functions`.
output_format: Python string. Must correspond to a function
Expand Down Expand Up @@ -448,7 +458,8 @@ def make_beam_pipeline(
audio_key=audio_key,
sample_rate_key=sample_rate_key,
sample_rate=sample_rate,
average_over_time=average_over_time))
average_over_time=average_over_time,
feature_fn=_default_feature_fn if use_frontend_fn else None))
embedding_tables[name] = tbl
assert tf_examples_key_ not in embedding_tables
embedding_tables[tf_examples_key_] = input_examples
Expand Down

0 comments on commit 29148dc

Please sign in to comment.