Skip to content

Commit

Permalink
[Whisper] Fix audio classification with weighted layer sum (#28563)
Browse files Browse the repository at this point in the history
* fix

* tests

* fix test
  • Loading branch information
sanchit-gandhi authored Jan 18, 2024
1 parent 619ecfe commit 186aa6b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
10 changes: 9 additions & 1 deletion src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@

logger = logging.get_logger(__name__)

# Index of the hidden-states tuple inside the encoder's output when
# `output_hidden_states=True` (position 0 is presumably `last_hidden_state`
# — the weighted-layer-sum path below indexes `encoder_outputs` with this).
_HIDDEN_STATES_START_POSITION = 1

# Identifiers substituted into the auto-generated docstrings/examples.
_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"

Expand Down Expand Up @@ -2957,6 +2959,11 @@ def forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if self.config.use_weighted_layer_sum:
output_hidden_states = True
elif output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
Expand All @@ -2969,7 +2976,8 @@ def forward(
)

if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
Expand Down
21 changes: 14 additions & 7 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,16 +2292,15 @@ def get_subsampled_output_lengths(self, input_lengths):
def encoder_seq_length(self):
    # Sequence length as seen by the encoder, i.e. the raw input length
    # after the model's subsampling (delegates to the helper above).
    return self.get_subsampled_output_lengths(self.seq_length)

def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
    """Build a `WhisperForAudioClassification` and run one forward pass.

    Args:
        config: Model config; `use_weighted_layer_sum` is overwritten below so
            both pooling paths can be exercised from the same tester.
        inputs_dict: Prepared inputs; must contain the key "input_features".
        use_weighted_layer_sum: When True, exercises the weighted-layer-sum
            pooling path (which requires the encoder's hidden states).
    """
    config.use_weighted_layer_sum = use_weighted_layer_sum
    model = WhisperForAudioClassification(config=config)
    model.to(torch_device).eval()

    input_features = inputs_dict["input_features"]

    # Single forward pass; no gradients needed for a shape check.
    with torch.no_grad():
        logits = model(input_features).logits

    # BUG FIX: the original `assertTrue(logits.shape, (13, 2))` treats the
    # second argument as the failure *message* and passes for any non-empty
    # shape — the check was vacuous. `assertEqual` actually compares the
    # shape against the expected (batch, num_labels).
    self.parent.assertEqual(logits.shape, (13, 2))
Expand Down Expand Up @@ -2336,6 +2335,14 @@ def test_forward_signature(self):
expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

def test_forward_pass(self):
    # Default configuration path (mean pooling, no weighted layer sum).
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_model_forward(*config_and_inputs)

def test_forward_pass_weighted_layer_sum(self):
    # Exercises the `use_weighted_layer_sum=True` pooling path — the
    # code path this commit fixes.
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)

@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_cpu_offload(self):
    # Intentionally disabled via the decorator above; body is a no-op.
    pass
Expand Down

0 comments on commit 186aa6b

Please sign in to comment.