From 1e7a297aecb3f3157e321392dfa72c39ace2d129 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ante=20Jukic=CC=81?=
Date: Mon, 22 Apr 2024 17:05:06 -0700
Subject: [PATCH] Propagate channel selector for AED model + add channel selector to get_lhotse_dataloader_from config
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Ante Jukić
---
Review notes (this section is ignored by git-am):
 * aed_multitask_models.py: the default for 'channel_selector' is now the
   None object instead of the string 'None'. The string is not None, so it
   switched channel selection on and was looked up as a key in cut.custom.
 * dataloader.py: _select_channel now rejects unsupported selector types
   instead of reaching an unbound channel_idx, tolerates an unset cut.custom,
   and drops the "-> list" return annotation (it returns the result of
   cut.with_channels).
 * Hunk sizes and the diffstat are updated for the added lines; the blob ids
   on the "index" lines are still those of the original commit.

 .../asr/models/aed_multitask_models.py        |  1 +
 .../common/data/lhotse/dataloader.py          | 46 +++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index 5cda453db45d6..ce0f59d2a5125 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -875,6 +875,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
             'drop_last': False,
             'text_field': config.get('text_field', 'answer'),
             'lang_field': config.get('lang_field', 'target_lang'),
+            'channel_selector': config.get('channel_selector', None),
         }
 
         temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index b32f067c14a94..19e4dee61824f 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -90,6 +90,7 @@ class LhotseDataLoadingConfig:
     seed: int | str = "randomized"  # int | "randomized" | "trng"; the latter two are lazily resolved by Lhotse in dloading worker processes
     num_workers: int = 0
     pin_memory: bool = False
+    channel_selector: int | str | None = None
 
     # 4. Optional Lhotse data augmentation.
     #   a. On-the-fly noise/audio mixing.
@@ -157,6 +158,11 @@ def get_lhotse_dataloader_from_config(
     # 1. Load a manifest as a Lhotse CutSet.
     cuts, is_tarred = read_cutset_from_config(config)
 
+    # Apply channel selector
+    if config.channel_selector is not None:
+        logging.info('Using channel selector %s.', config.channel_selector)
+        cuts = cuts.map(partial(_select_channel, channel_selector=config.channel_selector), apply_fn=None)
+
     # Resample as a safeguard; it's a no-op when SR is already OK
     cuts = cuts.resample(config.sample_rate)
 
@@ -438,3 +444,43 @@ def _flatten_alt_text(cut) -> list:
         text_instance.custom = {"text": data.pop("text"), "lang": data.pop("lang"), **data}
         ans.append(text_instance)
     return ans
+
+
+def _select_channel(cut, channel_selector: int | str):
+    """
+    Select a single channel of ``cut`` according to ``channel_selector``.
+
+    Args:
+        cut: Cut to pick the channel from. Must expose ``custom``, ``num_channels``
+            and ``with_channels``.
+        channel_selector: Channel index (``int``), or the name of a key in
+            ``cut.custom`` that holds the channel index for this cut (``str``).
+
+    Returns:
+        The value of ``cut.with_channels(channel_idx)``.
+
+    Raises:
+        ValueError: If the selector has an unsupported type, if the key is not
+            present in ``cut.custom``, or if the index is not smaller than
+            ``cut.num_channels``.
+    """
+    if isinstance(channel_selector, int):
+        channel_idx = channel_selector
+    elif isinstance(channel_selector, str):
+        # Treat a missing/empty ``custom`` like a missing key so the caller gets
+        # the ValueError below rather than a TypeError from ``in None``.
+        custom = cut.custom or {}
+        if channel_selector not in custom:
+            raise ValueError(f"Channel selector {channel_selector} not found in cut.custom")
+        channel_idx = custom[channel_selector]
+    else:
+        # Without this branch ``channel_idx`` would be unbound below.
+        raise ValueError(f"Channel selector must be an int or a str, got {type(channel_selector).__name__}")
+
+    if channel_idx >= cut.num_channels:
+        raise ValueError(
+            f"Channel index {channel_idx} is larger than the actual number of channels {cut.num_channels}"
+        )
+
+    # NOTE(review): confirm every cut type that reaches this point implements ``with_channels``.
+    return cut.with_channels(channel_idx)