Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Correct Wav2Vec2 & WavLM tests #15015

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/self-scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ jobs:

- name: Install dependencies
run: |
apt -y update && apt install -y libsndfile1-dev git
apt -y update && apt install -y libsndfile1-dev git espeak-ng
pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
Expand Down
15 changes: 7 additions & 8 deletions tests/test_modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import copy
import glob
import inspect
import math
import unittest
Expand All @@ -23,6 +24,7 @@
import pytest
from datasets import load_dataset

from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
Expand Down Expand Up @@ -485,8 +487,6 @@ def test_compute_mask_indices_overlap(self):
@slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(
Expand Down Expand Up @@ -556,18 +556,17 @@ def test_inference_ctc_robust_batched(self):
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't nicely load mp3 at the moment with datasets -> let's just use a .wav file for now

ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))

resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]

model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

input_values = processor(resampled_audio, return_tensors="tf").input_values
input_values = processor(sample, return_tensors="tf").input_values

logits = model(input_values).logits

transcription = processor.batch_decode(logits.numpy()).text

self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
31 changes: 5 additions & 26 deletions tests/test_modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch WavLM model. """

import copy
import math
import unittest

Expand Down Expand Up @@ -452,30 +451,9 @@ def _mock_init_weights(self, module):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)

# overwrite from test_modeling_common
# as WavLM is not very precise
@unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feed forward chunking is not even implemented by WavLM

def test_feed_forward_chunking(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
pass

@slow
def test_model_from_pretrained(self):
Expand Down Expand Up @@ -528,7 +506,7 @@ def test_inference_base(self):
def test_inference_large(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"microsoft/wavlm-base-plus", return_attention_mask=True
"microsoft/wavlm-large", return_attention_mask=True
)

input_speech = self._load_datasamples(2)
Expand All @@ -544,8 +522,9 @@ def test_inference_large(self):
)

EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this was incorrect

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did it ever pass? If so, it would be nice to checkout an earlier commit on which it passed and check the difference

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know what was going on here. I'm sure though that the checkpoint works correctly as the model gives great results when fine-tuning -> see: https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-large (tested on dec 17th: https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-large/commit/a1b7ace90561bafd37167ca73c72833ad345963f) and there hasn't been a model change in the checkpoint's repo since december 16th: https://huggingface.co/microsoft/wavlm-large/commit/38b04afdf061607fdccc24c4ca4e8c3ae339012f

So the checkpoint is fine. Really not sure what was/is going on with this test. Will monitor in the coming days

)

self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))

def test_inference_diarization(self):
Expand Down