From 1f3551142683f5d40e476826c55e04f646736f3d Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Fri, 31 May 2024 14:00:09 +0000
Subject: [PATCH] Updated LangIDWhisper processor

Signed-off-by: Sasha Meister
---
 .../huggingface/speech_recognition.py | 60 ++++++++++++++-----
 1 file changed, 46 insertions(+), 14 deletions(-)

diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/huggingface/speech_recognition.py
index ae8ea0d7..3f9907aa 100644
--- a/sdp/processors/huggingface/speech_recognition.py
+++ b/sdp/processors/huggingface/speech_recognition.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import json
-import librosa
 from pathlib import Path
 from collections import Counter
 
@@ -40,6 +39,9 @@ def __init__(
         pretrained_model: str,
         output_lang_key: str,
         device: str = None,
+        segment_duration: float = np.inf,
+        num_segments: int = 1,
+        random_seed: int = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -54,6 +56,9 @@
         self.pretrained_model = pretrained_model
         self.device = device
         self.output_lang_key = output_lang_key
+        self.segment_duration = segment_duration
+        self.num_segments = num_segments
+        self.random_seed = random_seed
 
         if self.device is None:
             if torch.cuda.is_available():
@@ -69,35 +74,62 @@ def process(self):
         with Path(self.output_manifest_file).open('w') as f:
             for item in tqdm(json_list):
-                pred_lang = self.segment(item["audio_filepath"], segment_duration=30, num_segments=3, random_seed=None)
+                pred_lang = self.get_label(item["audio_filepath"])
                 item[self.output_lang_key] = pred_lang
                 f.write(json.dumps(item, ensure_ascii=False) + '\n')
 
-    def segment(self, path2audio_file, segment_duration, num_segments, random_seed):
-        audio, sr = sf.read(path2audio_file)
+    def get_label(self, path2audio_file):
+        audio, sample_rate = sf.read(path2audio_file)
         audio = np.float32(audio)
         audio_length = audio.shape[0]
 
-        duration = sr * segment_duration
-        if duration > audio_length:
-            duration = audio_length
+        # Samples per scored segment; the default np.inf selects the whole file.
+        audio_segment_samples = sample_rate * self.segment_duration
+        segments_in_audio = int(audio_length / audio_segment_samples)
+        segment_starts = []
+        segment_ends = []
+
+        np.random.seed(self.random_seed)
+
+        if segments_in_audio <= 1:
+            segment_starts = [0]
+            segment_ends = [audio_length]
+        else:
+            if segments_in_audio > self.num_segments:
+                segments_in_audio = self.num_segments
+
+            long_segment_duration = int(audio_length / segments_in_audio)
+
+            # Pick one random window inside each equal slice of the audio.
+            for segment_no in range(segments_in_audio):
+                long_start_segment = long_segment_duration * segment_no
+                long_end_segment = long_segment_duration * (segment_no + 1)
+                segment_start = np.random.randint(long_start_segment, int(long_end_segment - audio_segment_samples))
+                segment_end = segment_start + int(audio_segment_samples)
+                segment_starts.append(segment_start)
+                segment_ends.append(segment_end)
+
         label_id_list = []
-        np.random.seed(random_seed)
-        starts = np.random.randint(0, audio_length - duration + 1, size=num_segments)
-        for start in starts:
-            audio_segm = audio[start : start + duration]
-            audio_segm = self.whisper.pad_or_trim(audio_segm)
-            mel = self.whisper.log_mel_spectrogram(audio_segm)
+
+        n_mels = 80
+
+        if self.pretrained_model == "large-v3":  # large-v3 expects 128 mel bins
+            n_mels = 128
+
+        for segment_start, segment_end in zip(segment_starts, segment_ends):
+            audio_segment = audio[segment_start:segment_end]
+            audio_segment = self.whisper.pad_or_trim(audio_segment)
+            mel = self.whisper.log_mel_spectrogram(audio_segment, n_mels=n_mels)
             mel = mel.to(self.device)
             _, probs = self.model.detect_language(mel)
             lang = max(probs, key=probs.get)
             label_id_list.append(lang)
-
+
         m_label_id = Counter(label_id_list).most_common(1)[0][0]
         return m_label_id
 
+
 class ASRWhisper(BaseProcessor):
     """
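
For context, a minimal usage sketch of the processor after this change. It is not part of the patch: the class name comes from the subject line, while the model name, manifest paths, and the input_manifest_file/output_manifest_file keyword arguments are assumptions based on common SDP BaseProcessor conventions.

    # Hypothetical usage sketch (assumed API surface; not part of this patch).
    from sdp.processors.huggingface.speech_recognition import LangIDWhisper

    processor = LangIDWhisper(
        pretrained_model="base",        # "large-v3" would switch n_mels to 128
        output_lang_key="lang",         # manifest field receiving the predicted language
        segment_duration=30,            # seconds per scored segment; default np.inf scores the whole file
        num_segments=3,                 # majority vote over up to this many random segments
        random_seed=42,                 # fixes segment placement for reproducible runs
        input_manifest_file="manifest_in.json",    # assumed paths
        output_manifest_file="manifest_out.json",
    )
    processor.process()

Moving segment_duration, num_segments, and random_seed from hardcoded call-site values (30, 3, None) into constructor arguments lets each pipeline config trade language-ID accuracy against runtime, and the new defaults (np.inf, 1) score the whole file in a single pass.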