Skip to content

Commit

Permalink
[Tests] Correct Wav2Vec2 & WavLM tests (#15015)
Browse files Browse the repository at this point in the history
* up

* up

* up
  • Loading branch information
patrickvonplaten authored Jan 3, 2022
1 parent 0b4c3a1 commit dbac889
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/self-scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ jobs:
- name: Install dependencies
run: |
apt -y update && apt install -y libsndfile1-dev git
apt -y update && apt install -y libsndfile1-dev git espeak-ng
pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
Expand Down
15 changes: 7 additions & 8 deletions tests/test_modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import copy
import glob
import inspect
import math
import unittest
Expand All @@ -23,6 +24,7 @@
import pytest
from datasets import load_dataset

from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
Expand Down Expand Up @@ -485,8 +487,6 @@ def test_compute_mask_indices_overlap(self):
@slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(
Expand Down Expand Up @@ -556,18 +556,17 @@ def test_inference_ctc_robust_batched(self):
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))

resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]

model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

input_values = processor(resampled_audio, return_tensors="tf").input_values
input_values = processor(sample, return_tensors="tf").input_values

logits = model(input_values).logits

transcription = processor.batch_decode(logits.numpy()).text

self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
31 changes: 5 additions & 26 deletions tests/test_modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch WavLM model. """

import copy
import math
import unittest

Expand Down Expand Up @@ -452,30 +451,9 @@ def _mock_init_weights(self, module):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)

# overwrite from test_modeling_common
# as WavLM is not very precise
@unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
def test_feed_forward_chunking(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
pass

@slow
def test_model_from_pretrained(self):
Expand Down Expand Up @@ -528,7 +506,7 @@ def test_inference_base(self):
def test_inference_large(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"microsoft/wavlm-base-plus", return_attention_mask=True
"microsoft/wavlm-large", return_attention_mask=True
)

input_speech = self._load_datasamples(2)
Expand All @@ -544,8 +522,9 @@ def test_inference_large(self):
)

EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
)

self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))

def test_inference_diarization(self):
Expand Down

0 comments on commit dbac889

Please sign in to comment.