From 186aa6befecc6e6f022fed34019a00d60884d557 Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Thu, 18 Jan 2024 16:41:44 +0000
Subject: [PATCH] [Whisper] Fix audio classification with weighted layer sum
 (#28563)

* fix

* tests

* fix test
---
 .../models/whisper/modeling_whisper.py        | 10 ++++++++-
 tests/models/whisper/test_modeling_whisper.py | 21 ++++++++++++-------
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index a3550c791c7..1e68f4f63e9 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -57,6 +57,8 @@
 
 logger = logging.get_logger(__name__)
 
+_HIDDEN_STATES_START_POSITION = 1
+
 _CONFIG_FOR_DOC = "WhisperConfig"
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
 
@@ -2957,6 +2959,11 @@ def forward(
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+        elif output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if encoder_outputs is None:
@@ -2969,7 +2976,8 @@ def forward(
             )
 
         if self.config.use_weighted_layer_sum:
-            hidden_states = torch.stack(encoder_outputs, dim=1)
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
         else:
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 0192b83f929..e0f369c7a67 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -2292,16 +2292,15 @@ def get_subsampled_output_lengths(self, input_lengths):
     def encoder_seq_length(self):
         return self.get_subsampled_output_lengths(self.seq_length)
 
-    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
-        model = WhisperForAudioClassification(config=config).to(torch_device).eval()
-
-        if freeze_encoder:
-            model.freeze_encoder()
+    def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
+        config.use_weighted_layer_sum = use_weighted_layer_sum
+        model = WhisperForAudioClassification(config=config)
+        model.to(torch_device).eval()
 
         input_features = inputs_dict["input_features"]
 
-        # first forward pass
-        last_hidden_state = model(input_features).logits
+        with torch.no_grad():
+            last_hidden_state = model(input_features).logits
 
         self.parent.assertTrue(last_hidden_state.shape, (13, 2))
 
@@ -2336,6 +2335,14 @@ def test_forward_signature(self):
         expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
         self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
 
+    def test_forward_pass(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
+    def test_forward_pass_weighted_layer_sum(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)
+
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
     def test_cpu_offload(self):
         pass
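
For reference, the code path this patch repairs can be exercised end to end with a short
standalone script. This is a minimal sketch rather than part of the test suite: the random
input and the two-label head below are illustrative stand-ins, while `WhisperConfig`,
`WhisperForAudioClassification`, and `use_weighted_layer_sum` are the real `transformers`
APIs touched by the diff.

import torch
from transformers import WhisperConfig, WhisperForAudioClassification

# Illustrative tiny config; num_labels sizes the classification head, and
# use_weighted_layer_sum enables the branch this patch fixes.
config = WhisperConfig(num_labels=2, use_weighted_layer_sum=True)
model = WhisperForAudioClassification(config).eval()

# Random stand-in for log-mel features, shaped
# (batch, num_mel_bins, 2 * max_source_positions) as the encoder expects.
input_features = torch.randn(1, config.num_mel_bins, 2 * config.max_source_positions)

with torch.no_grad():
    # With the fix, the encoder is forced to return hidden states, and the
    # tuple of per-layer states at encoder_outputs[1] is stacked and
    # softmax-weighted before pooling.
    logits = model(input_features).logits

print(logits.shape)  # torch.Size([1, 2]) with these illustrative values

Before the fix, `torch.stack(encoder_outputs, dim=1)` attempted to stack the encoder output
object itself rather than the tuple of per-layer hidden states at index 1
(`_HIDDEN_STATES_START_POSITION`), and nothing forced the encoder to return hidden states in
the first place, so a script like the one above failed on the weighted-layer-sum path.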