Skip to content

Commit

Permalink
[ASR] Support for transcription of multi-channel audio for AED models (
Browse files Browse the repository at this point in the history
…NVIDIA#9007)

* Propagate channel selector for AED model + add channel selector to get_lhotse_dataloader_from config

Signed-off-by: Ante Jukić <[email protected]>

* Included comments

Signed-off-by: Ante Jukić <[email protected]>

* Added unit test

Signed-off-by: Ante Jukić <[email protected]>

---------

Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju authored Apr 30, 2024
1 parent 2b6bd58 commit 7da9121
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'drop_last': False,
'text_field': config.get('text_field', 'answer'),
'lang_field': config.get('lang_field', 'target_lang'),
'channel_selector': config.get('channel_selector', None),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
Expand Down
28 changes: 28 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class LhotseDataLoadingConfig:
seed: int | str = 0
num_workers: int = 0
pin_memory: bool = False
channel_selector: int | str | None = None

# 4. Optional Lhotse data augmentation.
# a. On-the-fly noise/audio mixing.
Expand Down Expand Up @@ -156,6 +157,11 @@ def get_lhotse_dataloader_from_config(
# 1. Load a manifest as a Lhotse CutSet.
cuts, is_tarred = read_cutset_from_config(config)

# Apply channel selector
if config.channel_selector is not None:
logging.info('Using channel selector %s.', config.channel_selector)
cuts = cuts.map(partial(_select_channel, channel_selector=config.channel_selector))

# Resample as a safeguard; it's a no-op when SR is already OK
cuts = cuts.resample(config.sample_rate)

Expand Down Expand Up @@ -443,3 +449,25 @@ def _flatten_alt_text(cut) -> list:
text_instance.custom = {"text": data.pop("text"), "lang": data.pop("lang"), **data}
ans.append(text_instance)
return ans


def _select_channel(cut, channel_selector: int | str) -> list:
if isinstance(channel_selector, int):
channel_idx = channel_selector
elif isinstance(channel_selector, str):
if channel_selector in cut.custom:
channel_idx = cut.custom[channel_selector]
else:
raise ValueError(f"Channel selector {channel_selector} not found in cut.custom")

if channel_idx >= cut.num_channels:
raise ValueError(
f"Channel index {channel_idx} is larger than the actual number of channels {cut.num_channels}"
)

if cut.num_channels == 1:
# one channel available and channel_idx==0
return cut
else:
# with_channels only defined on MultiCut
return cut.with_channels(channel_idx)
100 changes: 100 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
return p


@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
    """10 two-channel utterances of length 1s as a Lhotse CutSet."""
    from lhotse import CutSet, MultiCut
    from lhotse.testing.dummies import DummyManifest

    num_examples = 10  # number of examples
    num_channels = 2  # number of channels per example

    # Start from a dummy manifest of single-channel cuts; every group of
    # `num_channels` consecutive cuts will be fused into one multi-channel cut.
    sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=num_examples * num_channels, with_data=True)
    template = sc_cuts[0]

    multi_cuts = []
    for example_idx in range(num_examples):
        # Collect one audio source per channel, re-tagging each with its channel index.
        sources = []
        for ch in range(num_channels):
            src = sc_cuts[example_idx * num_channels + ch].recording.sources[0]
            src.channels = [ch]
            sources.append(src)

        # Fuse the per-channel sources into a single multi-channel recording.
        recording = Recording(
            id=f'mc-dummy-recording-{example_idx:02d}',
            sources=sources,
            num_samples=template.num_samples,
            duration=template.duration,
            sampling_rate=template.sampling_rate,
        )

        multi_cuts.append(
            MultiCut(
                id=f'mc-dummy-cut-{example_idx:02d}',
                recording=recording,
                start=0,
                duration=1.0,
                channel=list(range(num_channels)),
            )
        )

    out_dir = tmp_path_factory.mktemp("data")
    manifest_path = out_dir / "mc_cuts.jsonl.gz"
    audio_dir = out_dir / "mc_audio"
    # Persist the audio to disk and write the manifest next to it.
    CutSet.from_cuts(multi_cuts).save_audios(audio_dir).to_file(manifest_path)
    return manifest_path


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
"""10 utterances of length 1s as a NeMo tarred manifest."""
Expand Down Expand Up @@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts


def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
    """Channel selection in the lhotse dataloader: ``None`` keeps both
    channels, an integer index keeps exactly that channel."""
    base_opts = {
        "cuts_path": mc_cutset_path,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 4,
        "seed": 0,
    }

    # Reference run without a channel selector: audio stays two-channel.
    reference_dl = get_lhotse_dataloader_from_config(
        config=OmegaConf.create(base_opts), global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
    )
    reference_batches = list(reference_dl)
    assert len(reference_batches) == 3

    # 1.0s = 16000 samples, two channels, note the constant duration and batch size
    expected_shapes = [(4, 2, 16000), (4, 2, 16000), (2, 2, 16000)]
    for batch, shape in zip(reference_batches, expected_shapes):
        assert batch["audio"].shape == shape
    # exactly 10 cuts were used

    # Apply channel selector
    for selector in [None, 0, 1]:

        selected_dl = get_lhotse_dataloader_from_config(
            config=OmegaConf.create({**base_opts, "channel_selector": selector}),
            global_rank=0,
            world_size=1,
            dataset=UnsupervisedAudioDataset(),
        )

        for reference, batch in zip(reference_batches, selected_dl):
            if selector is None:
                # no channel selector, needs to match the original dataset
                assert torch.equal(batch["audio"], reference["audio"])
            else:
                # channel selector, needs to match the selected channel
                assert torch.equal(batch["audio"], reference["audio"][:, selector, :])


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
config = OmegaConf.create(
Expand Down

0 comments on commit 7da9121

Please sign in to comment.