Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR] Support for transcription of multi-channel audio for AED models #9007

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'drop_last': False,
'text_field': config.get('text_field', 'answer'),
'lang_field': config.get('lang_field', 'target_lang'),
'channel_selector': config.get('channel_selector', None),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
Expand Down
28 changes: 28 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class LhotseDataLoadingConfig:
seed: int | str = 0
num_workers: int = 0
pin_memory: bool = False
channel_selector: int | str | None = None

# 4. Optional Lhotse data augmentation.
# a. On-the-fly noise/audio mixing.
Expand Down Expand Up @@ -156,6 +157,11 @@ def get_lhotse_dataloader_from_config(
# 1. Load a manifest as a Lhotse CutSet.
cuts, is_tarred = read_cutset_from_config(config)

# Apply channel selector
if config.channel_selector is not None:
logging.info('Using channel selector %s.', config.channel_selector)
cuts = cuts.map(partial(_select_channel, channel_selector=config.channel_selector))

# Resample as a safeguard; it's a no-op when SR is already OK
cuts = cuts.resample(config.sample_rate)

Expand Down Expand Up @@ -443,3 +449,25 @@ def _flatten_alt_text(cut) -> list:
text_instance.custom = {"text": data.pop("text"), "lang": data.pop("lang"), **data}
ans.append(text_instance)
return ans


def _select_channel(cut, channel_selector: int | str) -> list:
if isinstance(channel_selector, int):
channel_idx = channel_selector
elif isinstance(channel_selector, str):
if channel_selector in cut.custom:
channel_idx = cut.custom[channel_selector]
else:
raise ValueError(f"Channel selector {channel_selector} not found in cut.custom")

if channel_idx >= cut.num_channels:
raise ValueError(
f"Channel index {channel_idx} is larger than the actual number of channels {cut.num_channels}"
)

if cut.num_channels == 1:
# one channel available and channel_idx==0
return cut
else:
# with_channels only defined on MultiCut
return cut.with_channels(channel_idx)
100 changes: 100 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,51 @@ def nemo_manifest_path(cutset_path: Path):
return p


@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
    """10 two-channel utterances of length 1s as a Lhotse CutSet."""
    from lhotse import CutSet, MultiCut
    from lhotse.testing.dummies import DummyManifest

    n_utts = 10  # number of multi-channel examples
    n_ch = 2  # channels per example

    # Dummy single-channel cuts: two consecutive cuts are merged into one example.
    sc_cuts = DummyManifest(CutSet, begin_id=0, end_id=n_utts * n_ch, with_data=True)

    def _build_example(idx):
        # Collect one audio source per channel, tagging each with its channel index.
        srcs = []
        for ch in range(n_ch):
            src = sc_cuts[idx * n_ch + ch].recording.sources[0]
            src.channels = [ch]
            srcs.append(src)

        # Merge the per-channel sources into a single multi-channel recording.
        rec = Recording(
            sources=srcs,
            id=f'mc-dummy-recording-{idx:02d}',
            num_samples=sc_cuts[0].num_samples,
            duration=sc_cuts[0].duration,
            sampling_rate=sc_cuts[0].sampling_rate,
        )

        return MultiCut(
            recording=rec, id=f'mc-dummy-cut-{idx:02d}', start=0, duration=1.0, channel=list(range(n_ch))
        )

    mc_cuts = CutSet.from_cuts([_build_example(i) for i in range(n_utts)])

    out_dir = tmp_path_factory.mktemp("data")
    manifest_path = out_dir / "mc_cuts.jsonl.gz"
    audio_dir = out_dir / "mc_audio"
    mc_cuts.save_audios(audio_dir).to_file(manifest_path)
    return manifest_path


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
"""10 utterances of length 1s as a NeMo tarred manifest."""
Expand Down Expand Up @@ -247,6 +292,61 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts


def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
    """Check that ``channel_selector`` in the dataloader config picks out the expected channel."""
    base_opts = {
        "cuts_path": mc_cutset_path,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 4,
        "seed": 0,
    }

    # Reference run without a channel selector: both channels are kept.
    dl = get_lhotse_dataloader_from_config(
        config=OmegaConf.create(base_opts), global_rank=0, world_size=1, dataset=UnsupervisedAudioDataset()
    )
    ref_batches = list(dl)
    assert len(ref_batches) == 3

    # 1.0s = 16000 samples, two channels, note the constant duration and batch size
    assert ref_batches[0]["audio"].shape == (4, 2, 16000)
    assert ref_batches[1]["audio"].shape == (4, 2, 16000)
    assert ref_batches[2]["audio"].shape == (2, 2, 16000)
    # exactly 10 cuts were used

    # Runs with a channel selector must reproduce the corresponding slice of the reference.
    for selector in [None, 0, 1]:
        dl_cs = get_lhotse_dataloader_from_config(
            config=OmegaConf.create({**base_opts, "channel_selector": selector}),
            global_rank=0,
            world_size=1,
            dataset=UnsupervisedAudioDataset(),
        )

        for ref, b_cs in zip(ref_batches, dl_cs):
            if selector is None:
                # no channel selector, needs to match the original dataset
                assert torch.equal(b_cs["audio"], ref["audio"])
            else:
                # channel selector, needs to match the selected channel
                assert torch.equal(b_cs["audio"], ref["audio"][:, selector, :])


@requires_torchaudio
def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
config = OmegaConf.create(
Expand Down
Loading