add extract hf text and update
Signed-off-by: stevehuang52 <[email protected]>
stevehuang52 committed Oct 31, 2023
1 parent 8704df3 commit baea154
Showing 5 changed files with 146 additions and 6 deletions.
@@ -51,7 +51,8 @@ model:
     sample_rate: ${model.sample_rate}
     batch_size: 16 # you may increase batch_size if your memory allows
     shuffle: true
-    num_workers: 0
+    shuffle_n: 2048
+    num_workers: 8
     pin_memory: true
     use_start_end_token: false

@@ -72,6 +73,7 @@ model:
     sample_rate: ${model.sample_rate}
     batch_size: 8
     shuffle: false
+    shuffle_n: 2048
     num_workers: 8
     pin_memory: true
     use_start_end_token: false
@@ -97,6 +99,7 @@ model:
     sample_rate: ${model.sample_rate}
     batch_size: 8
     shuffle: false
+    shuffle_n: 2048
     num_workers: 8
     pin_memory: true
     use_start_end_token: false
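The `shuffle_n` knob added above sets the size of the in-memory buffer used to shuffle samples from a tarred/streamed dataset, which cannot be shuffled globally. A minimal sketch of the buffering idea (illustrative only, not NeMo's actual implementation):

```python
import random
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int = 2048, seed: Optional[int] = None) -> Iterator[T]:
    """Approximately shuffle a stream: hold up to `buffer_size` samples in
    memory and emit a randomly chosen one each time a new sample arrives."""
    rng = random.Random(seed)
    buffer: list = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # swap a random element to the end, then pop and yield it
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    rng.shuffle(buffer)  # drain whatever is left in random order
    yield from buffer
```

A larger buffer approximates a global shuffle more closely at the cost of memory; 2048 is the default these configs settle on.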
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/huggingface/hf_audio_to_text.py
@@ -357,7 +357,7 @@ def __init__(
         global_rank: int = 0,
         world_size: int = 0,
         shuffle_n: int = 0,
-        shuffle_seed: int = 1234,
+        shuffle_seed: Optional[int] = None,
         normalize_text: bool = False,
         symbols_to_keep: Optional[str] = None,
     ) -> None:
@@ -486,7 +486,7 @@ def __init__(
         global_rank: int = 0,
         world_size: int = 0,
         shuffle_n: int = 0,
-        shuffle_seed: int = 1234,
+        shuffle_seed: Optional[int] = None,
         parser: Union[str, Callable] = 'en',
         blank_index: int = -1,
         unk_index: int = -1,
@@ -557,7 +557,7 @@ def __init__(
         global_rank: int = 0,
         world_size: int = 0,
         shuffle_n: int = 0,
-        shuffle_seed: int = 1234,
+        shuffle_seed: Optional[int] = None,
         use_start_end_token: bool = True,
         normalize_text: bool = False,
         symbols_to_keep: Optional[str] = None,
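All three dataset classes now default `shuffle_seed` to `None` instead of a fixed `1234`. With the HuggingFace `datasets` library, a streaming dataset's `shuffle()` draws a fresh random seed when `seed=None`, so sample order differs between runs. A standalone example of that behavior (independent of NeMo):

```python
from datasets import load_dataset

# Streaming load returns an IterableDataset; its shuffle() is buffered.
ds = load_dataset("librispeech_asr", split="train.clean.100", streaming=True)

# seed=None -> a fresh random seed each run (non-reproducible order);
# pass a fixed int instead to make the order repeatable.
ds = ds.shuffle(seed=None, buffer_size=2048)
```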
@@ -43,7 +43,7 @@ def get_hf_audio_to_text_bpe_dataset(
        global_rank=global_rank,
        world_size=world_size,
        shuffle_n=config.get("shuffle_n", 2048),
-        shuffle_seed=config.get("shuffle_seed", 42),
+        shuffle_seed=config.get("shuffle_seed", None),
        use_start_end_token=config.get('use_start_end_token', True),
        normalize_text=config.get('normalize_text', False),
        symbols_to_keep=config.get('symbols_to_keep', None),
@@ -92,7 +92,7 @@ def get_hf_audio_to_text_char_dataset(
        global_rank=global_rank,
        world_size=world_size,
        shuffle_n=config.get("shuffle_n", 2048),
-        shuffle_seed=config.get("shuffle_seed", 42),
+        shuffle_seed=config.get("shuffle_seed", None),
        parser=config.get("parser", "en"),
        blank_index=config.get("blank_index", -1),
        unk_index=config.get("unk_index", -1),
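Both factory functions read optional keys with `config.get(key, default)`, so the new `None` default only applies when the YAML omits `shuffle_seed`. For illustration, with a hypothetical minimal config:

```python
from omegaconf import OmegaConf

config = OmegaConf.create({"shuffle_n": 2048})  # no shuffle_seed key

# DictConfig.get behaves like dict.get: a missing key yields the default,
# so leaving shuffle_seed out of the YAML now means "use a random seed".
assert config.get("shuffle_seed", None) is None
assert config.get("shuffle_n", 2048) == 2048
```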
34 changes: 34 additions & 0 deletions scripts/tokenizers/conf/huggingface_data_tokenizer.yaml
@@ -0,0 +1,34 @@

# num workers to use for extracting text from datasets.
num_workers: 8

# simple text cleaning, by default converts all chars to lower-case and only keeps alpha-numeric chars.
normalize_text: true
symbols_to_keep: ["'"] # a list of symbols to keep during text cleaning.

# the key for ground-truth transcription; e.g., MCV usually uses "sentence" while some others use "text"
text_key: "text"
num_proc: 4 # num processes to use for downloading HF datasets

data_path: "librispeech_asr"
data_name: null
streaming: true

hf_data_cfg: # can be a ListConfig or DictConfig; params for each dataset are passed into HuggingFace load_dataset(). Add more params if needed
- path: ${data_path}
name: ${data_name}
split: 'train.clean.360'
streaming: ${streaming}
num_proc: ${num_proc}
- path: ${data_path}
name: ${data_name}
split: 'train.clean.100'
streaming: ${streaming}
num_proc: ${num_proc}
- path: ${data_path}
name: ${data_name}
split: 'train.other.500'
streaming: ${streaming}
num_proc: ${num_proc}

output_file: "librispeech_asr_train960.txt"
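Each entry under `hf_data_cfg` is unpacked directly into `datasets.load_dataset()` by the script below. With the defaults above, the first entry resolves to roughly this call (`num_proc` is dropped first, since streaming datasets do not support it):

```python
import datasets as hf_datasets

dataset = hf_datasets.load_dataset(
    path="librispeech_asr",
    name=None,
    split="train.clean.360",
    streaming=True,
)
```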
103 changes: 103 additions & 0 deletions scripts/tokenizers/get_hf_text_data.py
@@ -0,0 +1,103 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to download text corpus from HuggingFace datasets,
where the saved corpus can be further used to train a tokenizer using `process_asr_text_tokenizer.py`.
Usage:
```
python get_hf_text_data.py --config-path="conf" --config-name="huggingface_data_tokenizer"
```
Please refer to "conf/huggingface_data_tokenizer.yaml" for more details.
"""


from itertools import repeat
from multiprocessing import Pool
from pathlib import Path

import datasets as hf_datasets
from omegaconf import OmegaConf, open_dict

from nemo.core.config import hydra_runner
from nemo.utils import logging


def clean_text(text: str, symbols_to_keep=None):
    symbols_to_keep = list(symbols_to_keep) if symbols_to_keep is not None else []
    text = text.lower()
    # only keep alphanumeric characters, spaces and symbols defined in symbols_to_keep
    text = ''.join([c for c in text if c.isalnum() or c.isspace() or c in symbols_to_keep])
    return text
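For instance, with the default config values (`normalize_text: true`, `symbols_to_keep: ["'"]`), punctuation is stripped while the apostrophe survives:

```python
assert clean_text("Hello, World! It's me.", ["'"]) == "hello world it's me"
```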


def get_nested_dict_value(dictionary: dict, key: str):
    """
    The key should be a string of nested keys separated by `.`, e.g. `key1.key2.key3`;
    the returned value will then be `dictionary[key1][key2][key3]`.
    """
    nested_keys = key.split(".")
    result = dictionary
    for k in nested_keys:
        if k not in result:
            raise KeyError(
                f"Key `{key}` not found in [{result.keys()}], target is {nested_keys}, input is {dictionary}"
            )
        result = result[k]
    return result
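For example, a nested field of a HuggingFace sample dict can be addressed with a dotted key:

```python
sample = {"audio": {"path": "a.wav"}, "text": "hello"}
assert get_nested_dict_value(sample, "audio.path") == "a.wav"
assert get_nested_dict_value(sample, "text") == "hello"
```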


def worker(x):
    sample, cfg = x
    text = get_nested_dict_value(sample, cfg.text_key)
    if cfg.normalize_text:
        text = clean_text(text, cfg.symbols_to_keep)
    return text


@hydra_runner(config_path="conf", config_name="huggingface_data_tokenizer")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(OmegaConf.to_yaml(cfg, resolve=True))

    if cfg.output_file is None:
        cfg.output_file = 'huggingface_text_corpus.txt'

    if Path(cfg.output_file).exists():
        logging.info(f"Output file {cfg.output_file} already exists, removing it...")
        Path(cfg.output_file).unlink()

    for data_cfg in cfg.hf_data_cfg:
        if 'num_proc' in data_cfg and data_cfg.get('streaming', False):
            logging.warning("num_proc is not supported for streaming datasets, removing it from config")
            with open_dict(data_cfg):
                data_cfg.pop('num_proc')
        logging.info(
            f"Loading from HuggingFace datasets library with config: {OmegaConf.to_container(data_cfg, resolve=True)}"
        )
        dataset = hf_datasets.load_dataset(**data_cfg)
        logging.info("Start extracting text from dataset...")
        with Pool(cfg.num_workers) as p:
            text_corpus = p.map(worker, zip(dataset, repeat(cfg)))
        with Path(cfg.output_file).open('a') as f:
            for line in text_corpus:
                f.write(f"{line}\n")
        logging.info(f"Finished processing {len(text_corpus)} samples from {data_cfg}")
    logging.info("All Done!")


if __name__ == '__main__':
    main()
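Once the corpus file is written, it can be fed to `process_asr_text_tokenizer.py` as the docstring notes. A typical invocation might look like the following; treat the flag values as placeholders and check that script's help for the exact options:

```
python process_asr_text_tokenizer.py \
    --data_file=librispeech_asr_train960.txt \
    --data_root=./tokenizers \
    --vocab_size=1024 \
    --tokenizer=spe \
    --spe_type=bpe
```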
