Fix no space left error #76

Merged · 11 commits · Aug 19, 2024
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
@@ -24,6 +24,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements/docs.txt
python -m pip cache purge
# we are being quite strict here, but hopefully that will not be too inconvenient
- name: Checking that documentation builds with no warnings and all links are working
run: |
@@ -44,6 +45,7 @@ jobs:
pip install -r requirements/main.txt
pip install -r requirements/tests.txt
pip install -r requirements/huggingface.txt
python -m pip cache purge
- name: Checking that SDP can be imported and basic configs can be run without nemo
# in the future this might fail if some runtime tests require nemo
# in that case this test will need to be changed
@@ -69,6 +71,7 @@
sudo apt-get install -y libsndfile1 ffmpeg sox libsox-fmt-mp3
pip install Cython wheel # need to pre-install to avoid error in nemo installation
pip install "nemo_toolkit[asr,nlp]"
python -m pip cache purge
- name: Run all tests
env:
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
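The three added "python -m pip cache purge" steps are the actual fix: each heavy install (docs requirements, main/test requirements, nemo_toolkit) leaves pip's wheel and HTTP caches on the runner's small disk, and purging after each install step reclaims that space. A minimal sketch of the same cleanup driven from Python (the helper name is an assumption, not part of the PR; pip's cache is managed through its CLI, so the sketch shells out):

import subprocess
import sys

def purge_pip_cache() -> None:
    # Same effect as the workflow's `python -m pip cache purge` step.
    # pip may exit nonzero when the cache is already empty, so the
    # sketch warns instead of failing in that case.
    result = subprocess.run([sys.executable, "-m", "pip", "cache", "purge"])
    if result.returncode != 0:
        print("pip cache purge: nothing to remove", file=sys.stderr)

if __name__ == "__main__":
    purge_pip_cache()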
153 changes: 21 additions & 132 deletions sdp/processors/nemo/transcribe_speech.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,52 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v1.23.0/examples/asr/transcribe_speech.py.
# It is currently only compatible with NeMo v1.23.0. To use a different version of NeMo, please modify the file.
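
# A minimal guard sketch (not part of this PR; assumes `nemo.__version__`
# is exposed, as it is in recent NeMo releases): since the file is pinned
# to NeMo v1.23.0, failing fast on a mismatched install avoids confusing
# errors deeper in the imports below.
import nemo

if not nemo.__version__.startswith("1.23"):
    raise RuntimeError(
        f"transcribe_speech.py targets NeMo v1.23.0, found {nemo.__version__}"
    )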

import contextlib
import glob
import json
import os
import time
from dataclasses import dataclass, field, is_dataclass
from tempfile import NamedTemporaryFile
from dataclasses import dataclass, is_dataclass
from typing import List, Optional, Union

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel
from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
read_and_maybe_sort_manifest,
restore_transcription_order,
setup_model,
transcribe_partial_audio,
write_transcription,
)
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.core.config import hydra_runner
from nemo.utils import logging

@@ -99,8 +79,6 @@
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)

calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset.

# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
@@ -140,12 +118,11 @@ class TranscriptionConfig:
pretrained_name: Optional[str] = None # Name of a pretrained model
audio_dir: Optional[str] = None # Path to a directory which contains audio files
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
channel_selector: Optional[Union[int, str]] = (
None # Used to select a single channel from multichannel audio, or use average across channels
)
channel_selector: Optional[
Union[int, str]
] = None # Used to select a single channel from multichannel audio, or use average across channels
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation
presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction

# General configs
output_filename: Optional[str] = None
@@ -170,8 +147,6 @@ class TranscriptionConfig:
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp
compute_dtype: str = "float32"
matmul_precision: str = "highest" # Literal["highest", "high", "medium"]
audio_type: str = "wav"

# Recompute model transcription, even if the output folder exists with scores.
@@ -181,19 +156,10 @@ class TranscriptionConfig:
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()

# Decoding strategy for RNNT models
# enable CUDA graphs for transcription
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

# Decoding strategy for AED models
multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
# Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs:
# Implicit single-turn assuming default role='user' (works with Canary-1B)
# +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes
# Explicit single-turn prompt:
# +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes
# Explicit multi-turn prompt:
# +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]'
prompt: dict = field(default_factory=dict)

# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
decoder_type: Optional[str] = None
@@ -218,15 +184,11 @@ class TranscriptionConfig:

# key for groundtruth text in manifest
gt_text_attr_name: str = "text"
gt_lang_attr_name: str = "lang"

# Use model's transcribe() function instead of transcribe_partial_audio() by default
# Only use transcribe_partial_audio() when the audio is too long to fit in memory
# Your manifest input should have `offset` field to use transcribe_partial_audio()
allow_partial_transcribe: bool = False
extract_nbest: bool = False # Extract n-best hypotheses from the model

calculate_rtfx: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
@@ -255,7 +217,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ")

# setup GPU
torch.set_float32_matmul_precision(cfg.matmul_precision)
if cfg.cuda is None:
if torch.cuda.is_available():
device = [0] # use 0th CUDA device
@@ -286,14 +247,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
asr_model.set_trainer(trainer)
asr_model = asr_model.eval()

if cfg.compute_dtype != "float32" and cfg.amp:
raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32")

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

if cfg.compute_dtype != "float32":
asr_model.to(getattr(torch, cfg.compute_dtype))

# we will adjust this flag if the model does not support it
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs
@@ -319,19 +272,13 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if isinstance(asr_model.decoding, MultiTaskDecoding):
cfg.multitask_decoding.compute_langs = cfg.compute_langs
cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment
if cfg.extract_nbest:
cfg.multitask_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
asr_model.change_decoding_strategy(cfg.multitask_decoding)
elif cfg.decoder_type is not None:
# TODO: Support compute_langs in CTC eventually
if cfg.compute_langs and cfg.decoder_type == 'ctc':
raise ValueError("CTC models do not support `compute_langs` at the moment")

decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
if cfg.extract_nbest:
decoding_cfg.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it
if 'preserve_alignments' in decoding_cfg:
decoding_cfg.preserve_alignments = preserve_alignment
@@ -344,9 +291,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis

# Check if ctc or rnnt model
elif hasattr(asr_model, 'joint'): # RNNT model
if cfg.extract_nbest:
cfg.rnnt_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
cfg.rnnt_decoding.fused_batch_size = -1
cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps
cfg.rnnt_decoding.compute_langs = cfg.compute_langs
@@ -358,9 +302,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if cfg.compute_langs:
raise ValueError("CTC models do not support `compute_langs` at the moment.")
cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps
if cfg.extract_nbest:
cfg.ctc_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True

asr_model.change_decoding_strategy(cfg.ctc_decoding)

@@ -370,27 +311,14 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc"
):
cfg.decoding = cfg.ctc_decoding
elif isinstance(asr_model.decoding, MultiTaskDecoding):
cfg.decoding = cfg.multitask_decoding
else:
cfg.decoding = cfg.rnnt_decoding

remove_path_after_done = None
if isinstance(asr_model, EncDecMultiTaskModel):
# Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function
partial_audio = False
if cfg.audio_dir is not None and not cfg.append_pred:
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
else:
assert cfg.dataset_manifest is not None
if cfg.presort_manifest:
with NamedTemporaryFile("w", suffix=".json", delete=False) as f:
for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=True):
item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest)
print(json.dumps(item), file=f)
cfg.dataset_manifest = f.name
remove_path_after_done = f.name
filepaths = cfg.dataset_manifest
filepaths = cfg.dataset_manifest
assert cfg.dataset_manifest is not None
else:
# prepare audio filepaths and decide whether it's partial audio
filepaths, partial_audio = prepare_audio_data(cfg)
@@ -406,7 +334,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
else:

@contextlib.contextmanager
def autocast(dtype=None, enabled=True):
def autocast(dtype=None):
yield

# Compute output filename
@@ -422,22 +350,10 @@ def autocast(dtype=None, enabled=True):

# transcribe audio

if cfg.calculate_rtfx:
total_duration = 0.0

with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
if "duration" not in item:
raise ValueError(
f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field."
)
total_duration += item["duration"]
amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

with autocast(dtype=amp_dtype, enabled=cfg.amp):
with autocast(dtype=amp_dtype):
with torch.no_grad():
if cfg.calculate_rtfx:
start_time = time.time()
if partial_audio:
transcriptions = transcribe_partial_audio(
asr_model=asr_model,
@@ -450,40 +366,21 @@ def autocast(dtype=None, enabled=True):
decoder_type=cfg.decoder_type,
)
else:
override_cfg = asr_model.get_transcribe_config()
override_cfg.batch_size = cfg.batch_size
override_cfg.num_workers = cfg.num_workers
override_cfg.return_hypotheses = cfg.return_hypotheses
override_cfg.channel_selector = cfg.channel_selector
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
if hasattr(override_cfg, "prompt"):
override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))

transcriptions = asr_model.transcribe(
audio=filepaths,
override_config=override_cfg,
paths2audio_files=filepaths,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=cfg.return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
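# Note on the change above (not in the PR; summarizing the diff): the
# removed lines built a config with asr_model.get_transcribe_config() and
# passed it as override_config, while the v1.23.0-compatible call passes
# batch_size, num_workers, return_hypotheses, channel_selector, and
# augmentor directly as keyword arguments to transcribe(). Options with no
# equivalent in that signature (text_field, lang_field, prompt) are dropped.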
if cfg.calculate_rtfx:
transcribe_time = time.time() - start_time

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
if cfg.presort_manifest:
transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions)
else:
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

# if transcriptions form a tuple of (best_hypotheses, all_hypotheses)
# if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
if type(transcriptions) == tuple and len(transcriptions) == 2:
if cfg.extract_nbest:
# extract all hypotheses if exists
transcriptions = transcriptions[1]
else:
# extract just best hypothesis
transcriptions = transcriptions[0]
transcriptions = transcriptions[0]

if cfg.return_transcriptions:
return transcriptions
@@ -499,11 +396,6 @@ def autocast(dtype=None, enabled=True):
)
logging.info(f"Finished writing predictions to {output_filename}!")

# clean-up
if cfg.presort_manifest is not None:
if remove_path_after_done is not None:
os.unlink(remove_path_after_done)

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
@@ -518,9 +410,6 @@ def autocast(dtype=None, enabled=True):
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

if cfg.calculate_rtfx:
logging.info(f"Dataset RTFx {(total_duration/transcribe_time)}")

return cfg
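
For reference, the pinned script is still a Hydra-configured CLI (via the @hydra_runner decorator above), so any TranscriptionConfig field shown in this diff can be supplied as a key=value override. A minimal sketch of driving it from Python; the pretrained model name and the paths are placeholders, not values taken from this PR:

import subprocess
import sys

# Hypothetical invocation of the pinned script. Any TranscriptionConfig
# field (model_path, pretrained_name, audio_dir, dataset_manifest,
# output_filename, batch_size, ...) can be passed as a Hydra override.
subprocess.run(
    [
        sys.executable,
        "sdp/processors/nemo/transcribe_speech.py",
        "pretrained_name=stt_en_conformer_ctc_small",
        "dataset_manifest=data/manifest.json",
        "output_filename=data/manifest_transcribed.json",
        "batch_size=8",
    ],
    check=True,
)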

