Script for estimating Lhotse dynamic duration buckets (#8237)
* Script for estimating Lhotse dynamic duration buckets

Signed-off-by: Piotr Żelasko <[email protected]>

* Improve documentation

Signed-off-by: Piotr Żelasko <[email protected]>

* Apply suggestions from code review

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko authored Feb 13, 2024
1 parent 3a76b9d commit 03a7e4f
Showing 6 changed files with 282 additions and 64 deletions.
22 changes: 22 additions & 0 deletions docs/source/asr/datasets.rst
@@ -633,6 +633,28 @@ Some other Lhotse related arguments we support:

The full and always up-to-date list of supported options can be found in the ``LhotseDataLoadingConfig`` class.

Pre-computing bucket duration bins
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We recommend pre-computing the bucket duration bins to accelerate the start of training; otherwise, the dynamic bucketing sampler will have to spend some time estimating them before training starts.
The following script may be used::

$ python scripts/speech_recognition/estimate_duration_bins.py -b 30 manifest.json

Use the following options in your config:
num_buckets=30
bucket_duration_bins=[1.78,2.34,2.69,...
<other diagnostic information about the dataset>
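
The printed options can then be pasted into your training configuration. A minimal sketch of doing this programmatically (the exact config path, e.g. ``model.train_ds``, depends on your setup, and ``batch_duration`` is an assumed value you tune yourself)::

    from omegaconf import OmegaConf

    train_ds = OmegaConf.create({
        "use_bucketing": True,
        "num_buckets": 30,
        # Paste the full list printed by the script; truncated here for brevity:
        "bucket_duration_bins": [1.78, 2.34, 2.69],
        "batch_duration": 600.0,  # assumed value -- tune for your hardware
    })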

For multi-dataset setups, one may provide multiple manifests and even their weights::

$ python scripts/speech_recognition/estimate_duration_bins.py -b 30 [[manifest.json,0.7],[other.json,0.3]]

Use the following options in your config:
num_buckets=30
bucket_duration_bins=[1.91,3.02,3.56,...
<other diagnostic information about the dataset>

Preparing Text-Only Data for Hybrid ASR-TTS Models
--------------------------------------------------

138 changes: 77 additions & 61 deletions nemo/collections/common/data/lhotse/cutset.py
Expand Up @@ -14,6 +14,7 @@

import logging
import warnings
from itertools import repeat
from pathlib import Path
from typing import Sequence, Tuple

@@ -47,7 +48,6 @@ def read_cutset_from_config(config) -> Tuple[CutSet, bool]:


def read_lhotse_manifest(config, is_tarred: bool) -> CutSet:

if is_tarred:
# Lhotse Shar is the equivalent of NeMo's native "tarred" dataset.
# The combination of shuffle_shards, and repeat causes this to
@@ -95,7 +95,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet:
logging.info(f"- {path=} {weight=}")
cutsets.append(cs.repeat())
weights.append(weight)
cuts = CutSet.mux(*cutsets, weights=weights)
cuts = mux(*cutsets, weights=weights, max_open_streams=config.max_open_streams)
else:
# Regular Lhotse manifest points to individual audio files (like native NeMo manifest).
cuts = CutSet.from_file(config.cuts_path)
@@ -107,11 +107,13 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet:
"text_field": config.text_field,
"lang_field": config.lang_field,
}
if is_tarred:
if isinstance(config.manifest_filepath, (str, Path)):
logging.info(
f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'"
)
# The option below allows iterating a NeMo manifest as a Lhotse CutSet without performing any I/O.
# NeMo manifests typically don't have the sampling_rate information required by Lhotse.
# This is useful for utility scripts that iterate metadata and estimate optimal batching settings.
notar_kwargs = {"missing_sampling_rate_ok": config.missing_sampling_rate_ok}
if isinstance(config.manifest_filepath, (str, Path)):
logging.info(f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'")
if is_tarred:
cuts = CutSet(
LazyNeMoTarredIterator(
config.manifest_filepath,
@@ -121,61 +123,75 @@
)
)
else:
# Format option 1:
# Assume it's [[path1], [path2], ...] (same for tarred_audio_filepaths).
# This is the format for multiple NeMo buckets.
# Note: we set "weights" here to be proportional to the number of utterances in each data source.
# this ensures that we distribute the data from each source uniformly throughout each epoch.
# Setting equal weights would exhaust the shorter data sources closer towards the beginning
# of an epoch (or over-sample it in the case of infinite CutSet iteration with .repeat()).
# Format option 2:
# Assume it's [[path1, weight1], [path2, weight2], ...] (while tarred_audio_filepaths remain unchanged).
# Note: this option allows manually setting the weights for multiple datasets.
logging.info(
f"Initializing Lhotse CutSet from multiple tarred NeMo manifest sources with a weighted multiplexer. "
f"We found the following sources and weights: "
)
cutsets = []
weights = []
for manifest_info, (tar_path,) in zip(config.manifest_filepath, config.tarred_audio_filepaths):
if len(manifest_info) == 1:
(manifest_path,) = manifest_info
nemo_iter = LazyNeMoTarredIterator(
manifest_path=manifest_path, tar_paths=tar_path, shuffle_shards=config.shuffle, **common_kwargs
)
weight = len(nemo_iter)
else:
assert (
isinstance(manifest_info, Sequence)
and len(manifest_info) == 2
and isinstance(manifest_info[1], (int, float))
), (
"Supported inputs types for config.manifest_filepath are: "
"str | list[list[str]] | list[tuple[str, number]] "
"where str is a path and number is a mixing weight (it may exceed 1.0). "
f"We got: '{manifest_info}'"
)
manifest_path, weight = manifest_info
nemo_iter = LazyNeMoTarredIterator(
manifest_path=manifest_path, tar_paths=tar_path, shuffle_shards=config.shuffle, **common_kwargs
)
logging.info(f"- {manifest_path=} {weight=}")
if config.max_open_streams is not None:
for subiter in nemo_iter.to_shards():
cutsets.append(CutSet(subiter))
weights.append(weight)
else:
cutsets.append(CutSet(nemo_iter))
weights.append(weight)
if config.max_open_streams is not None:
cuts = CutSet.infinite_mux(
*cutsets, weights=weights, seed="trng", max_open_streams=config.max_open_streams
)
else:
cuts = CutSet.mux(*[cs.repeat() for cs in cutsets], weights=weights, seed="trng")
cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs))
else:
# Format option 1:
# Assume it's [[path1], [path2], ...] (same for tarred_audio_filepaths).
# This is the format for multiple NeMo buckets.
# Note: we set "weights" here to be proportional to the number of utterances in each data source.
# this ensures that we distribute the data from each source uniformly throughout each epoch.
# Setting equal weights would exhaust the shorter data sources closer towards the beginning
# of an epoch (or over-sample it in the case of infinite CutSet iteration with .repeat()).
# Format option 2:
# Assume it's [[path1, weight1], [path2, weight2], ...] (while tarred_audio_filepaths remain unchanged).
# Note: this option allows manually setting the weights for multiple datasets.
logging.info(
f"Initializing Lhotse CutSet from a single NeMo manifest (non-tarred): '{config.manifest_filepath}'"
f"Initializing Lhotse CutSet from multiple tarred NeMo manifest sources with a weighted multiplexer. "
f"We found the following sources and weights: "
)
cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **common_kwargs))
cutsets = []
weights = []
tar_paths = config.tarred_audio_filepaths if is_tarred else repeat((None,))
# Create a stream for each dataset.
for manifest_info, (tar_path,) in zip(config.manifest_filepath, tar_paths):
# First, convert manifest_path[+tar_path] to an iterator.
manifest_path = manifest_info[0]
if is_tarred:
nemo_iter = LazyNeMoTarredIterator(
manifest_path=manifest_path, tar_paths=tar_path, shuffle_shards=config.shuffle, **common_kwargs
)
else:
nemo_iter = LazyNeMoIterator(manifest_path, **notar_kwargs, **common_kwargs)
# Then, determine the weight or use the one provided
if len(manifest_info) == 1:
weight = len(nemo_iter)
else:
assert (
isinstance(manifest_info, Sequence)
and len(manifest_info) == 2
and isinstance(manifest_info[1], (int, float))
), (
"Supported inputs types for config.manifest_filepath are: "
"str | list[list[str]] | list[tuple[str, number]] "
"where str is a path and number is a mixing weight (it may exceed 1.0). "
f"We got: '{manifest_info}'"
)
weight = manifest_info[1]
logging.info(f"- {manifest_path=} {weight=}")
# [optional] When we have a limit on the number of open streams,
# split the manifest into individual shards if applicable.
# This helps the multiplexing achieve a data distribution
# closer to the desired one in spite of the limit.
if config.max_open_streams is not None:
for subiter in nemo_iter.to_shards():
cutsets.append(CutSet(subiter))
weights.append(weight)
else:
cutsets.append(CutSet(nemo_iter))
weights.append(weight)
# Finally, we multiplex the dataset streams to mix the data.
cuts = mux(*cutsets, weights=weights, max_open_streams=config.max_open_streams)
return cuts


def mux(*cutsets: CutSet, weights: list[int | float], max_open_streams: int | None = None) -> CutSet:
"""
Helper function that selects the right flavour of multiplexing in lhotse.
The result is always an infinitely iterable ``CutSet``; depending on whether ``max_open_streams`` is set,
the appropriate multiplexing strategy is chosen.
"""
if max_open_streams is not None:
cuts = CutSet.infinite_mux(*cutsets, weights=weights, seed="trng", max_open_streams=max_open_streams)
else:
cuts = CutSet.mux(*[cs.repeat() for cs in cutsets], weights=weights, seed="trng")
return cuts
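
For illustration, a hypothetical use of the helper above (the manifest paths and weights are made up)::

    from lhotse import CutSet

    clean = CutSet.from_file("clean_cuts.jsonl.gz")  # hypothetical manifest
    noisy = CutSet.from_file("noisy_cuts.jsonl.gz")  # hypothetical manifest

    # No stream limit: falls back to CutSet.mux over infinitely repeated inputs.
    cuts = mux(clean, noisy, weights=[0.7, 0.3], max_open_streams=None)

    # With a limit: uses CutSet.infinite_mux, keeping at most 2 streams open at a time.
    cuts = mux(clean, noisy, weights=[0.7, 0.3], max_open_streams=2)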
3 changes: 3 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Expand Up @@ -95,6 +95,9 @@ class LhotseDataLoadingConfig:
# 5. Other Lhotse options.
text_field: str = "text" # key to read the transcript from
lang_field: str = "lang" # key to read the language tag from
# Enables iteration of NeMo non-tarred manifests that don't have a "sampling_rate" key without performing any I/O.
# Note that this will not allow actual dataloading; it's only for manifest iteration as Lhotse objects.
missing_sampling_rate_ok: bool = False


def get_lhotse_dataloader_from_config(
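
A minimal sketch of what this flag enables, mirroring the docstring example in ``nemo_adapters.py`` below (the manifest path is hypothetical)::

    from lhotse import CutSet

    from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator

    # Iterate manifest metadata as Lhotse cuts without touching any audio files.
    cuts = CutSet(LazyNeMoIterator("nemo_manifests/train.json", missing_sampling_rate_ok=True))
    total_hours = sum(cut.duration for cut in cuts) / 3600  # duration is pure metadata
    # cut.load_audio() would fail for such cuts: the sampling rate is unknown.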
38 changes: 35 additions & 3 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
Expand Up @@ -40,24 +40,33 @@ class LazyNeMoIterator:
- "text" (overridable with ``text_field`` argument)
Specially supported keys are:
- [recommended] "sampling_rate" allows us to provide a valid Lhotse ``Recording`` object without checking the audio file
- "offset" for partial recording reads
- "lang" is mapped to Lhotse superivsion's language (overridable with ``lang_field`` argument)
Every other key found in the manifest will be attached to Lhotse Cut and accessible via ``cut.custom[key]``.
.. caution:: We will perform some I/O (as much as required by soundfile.info) to discover the sampling rate
of the audio file. If this is not acceptable, convert the manifest to Lhotse format which contains
sampling rate info.
sampling rate info. For pure metadata iteration purposes we also provide a ``missing_sampling_rate_ok`` flag that
will create only partially valid Lhotse objects (with metadata related to sampling rate / num samples missing).
Example::
>>> cuts = lhotse.CutSet(LazyNeMoIterator("nemo_manifests/train.json"))
"""

def __init__(self, path: str | Path, text_field: str = "text", lang_field: str = "lang") -> None:
def __init__(
self,
path: str | Path,
text_field: str = "text",
lang_field: str = "lang",
missing_sampling_rate_ok: bool = False,
) -> None:
self.source = LazyJsonlIterator(path)
self.text_field = text_field
self.lang_field = lang_field
self.missing_sampling_rate_ok = missing_sampling_rate_ok

@property
def path(self) -> str | Path:
Expand All @@ -68,7 +77,7 @@ def __iter__(self) -> Generator[Cut, None, None]:
audio_path = data.pop("audio_filepath")
duration = data.pop("duration")
offset = data.pop("offset", None)
recording = Recording.from_file(audio_path)
recording = self._create_recording(audio_path, duration, data.pop("sampling_rate", None))
cut = recording.to_cut()
if offset is not None:
cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
@@ -94,6 +103,29 @@ def __len__(self) -> int:
def __add__(self, other):
return LazyIteratorChain(self, other)

def _create_recording(self, audio_path: str, duration: float, sampling_rate: int | None = None) -> Recording:
if sampling_rate is not None:
# TODO(pzelasko): It will only work with single-channel audio in the current shape.
return Recording(
id=audio_path,
sources=[AudioSource(type="file", channels=[0], source=audio_path)],
sampling_rate=sampling_rate,
num_samples=compute_num_samples(duration, sampling_rate),
duration=duration,
channel_ids=[0],
)
elif self.missing_sampling_rate_ok:
return Recording(
id=audio_path,
sources=[AudioSource(type="file", channels=[0], source=audio_path)],
sampling_rate=-1,
num_samples=-1,
duration=duration,
channel_ids=[0],
)
else:
return Recording.from_file(audio_path)


class LazyNeMoTarredIterator:
"""
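
To illustrate the metadata-only branch above: with ``missing_sampling_rate_ok=True`` the resulting ``Recording`` carries ``-1`` placeholders, so it supports duration-based metadata iteration but not audio loading. A hand-built equivalent (the file name is hypothetical)::

    from lhotse import Recording
    from lhotse.audio import AudioSource

    rec = Recording(
        id="utt1.wav",
        sources=[AudioSource(type="file", channels=[0], source="utt1.wav")],
        sampling_rate=-1,  # unknown without reading the file
        num_samples=-1,    # unknown without reading the file
        duration=12.37,
        channel_ids=[0],
    )
    print(rec.duration)  # fine for bucket estimation; rec.load_audio() is not supported here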
45 changes: 45 additions & 0 deletions scripts/speech_recognition/convert_to_tarred_audio_dataset.py
Expand Up @@ -168,6 +168,14 @@
"--buckets_num", type=int, default=1, help="Number of buckets to create based on duration.",
)

parser.add_argument(
"--dynamic_buckets_num",
type=int,
default=30,
help="Intended for dynamic (on-the-fly) bucketing; this option will not bucket your dataset during tar conversion. "
"Estimates optimal bucket duration bins for a given number of buckets.",
)

parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.")
parser.add_argument(
'--write_metadata',
@@ -207,6 +215,10 @@ class ASRTarredDatasetConfig:
shard_manifests: bool = True
keep_files_together: bool = False
force_codec: Optional[str] = None
use_lhotse: bool = False
use_bucketing: bool = False
num_buckets: Optional[int] = None
bucket_duration_bins: Optional[list[float]] = None


@dataclass
@@ -376,10 +388,43 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/",
metadata.dataset_config = config
metadata.num_samples_per_shard = len(new_entries) // config.num_shards

if args.buckets_num <= 1:
# Estimate and update dynamic bucketing args
bucketing_kwargs = self.estimate_dynamic_bucketing_duration_bins(
new_manifest_path, num_buckets=args.dynamic_buckets_num
)
for k, v in bucketing_kwargs.items():
setattr(metadata.dataset_config, k, v)

# Write metadata
metadata_yaml = OmegaConf.structured(metadata)
OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True)

def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_buckets: int = 30) -> dict:
from lhotse import CutSet
from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets
from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator

cuts = CutSet(LazyNeMoIterator(manifest_path, missing_sampling_rate_ok=True))
bins = estimate_duration_buckets(cuts, num_buckets=num_buckets)
print(
f"Note: we estimated the optimal bucketing duration bins for {num_buckets} buckets. "
"You can enable dynamic bucketing by setting the following options in your training script:\n"
" use_lhotse=true\n"
" use_bucketing=true\n"
f" num_buckets={num_buckets}\n"
f" bucket_duration_bins=[{','.join(map(str, bins))}]\n"
" batch_duration=<tune-this-value>\n"
"If you'd like to use a different number of buckets, re-estimate this option manually using "
"scripts/speech_recognition/estimate_duration_bins.py"
)
return dict(
use_lhotse=True,
use_bucketing=True,
num_buckets=num_buckets,
bucket_duration_bins=list(map(float, bins)), # np.float -> float for YAML serialization
)

def create_concatenated_dataset(
self,
base_manifest_path: str,
scripts/speech_recognition/estimate_duration_bins.py (new script; diff did not load)
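
Although that diff did not load, the script presumably builds on the same lhotse ``estimate_duration_buckets`` call used in ``convert_to_tarred_audio_dataset.py`` above, which picks bin edges so that each bucket holds roughly the same total audio duration. A rough sketch of that idea (not the actual script; the manifest path is a placeholder)::

    import json

    def estimate_bins(durations: list[float], num_buckets: int) -> list[float]:
        """Choose num_buckets - 1 edges splitting data into buckets of ~equal total duration."""
        durations = sorted(durations)
        per_bucket = sum(durations) / num_buckets
        bins, acc = [], 0.0
        for d in durations:
            # Record an edge each time the cumulative duration crosses the next multiple.
            if acc >= per_bucket * (len(bins) + 1) and len(bins) < num_buckets - 1:
                bins.append(d)
            acc += d
        return bins

    with open("manifest.json") as f:  # NeMo manifest: one JSON entry per line
        durations = [json.loads(line)["duration"] for line in f]
    print(estimate_bins(durations, num_buckets=30))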
