Skip to content

Commit

Permalink
Fix docstrings for TF BLIP (#22618)
Browse files Browse the repository at this point in the history
* Fix docstrings for TFBLIP

* Fix missing line in TF port!

* Use values from torch tests now other bugs fixed

* Use values from torch tests now other bugs fixed

* Fix doctest string
  • Loading branch information
Rocketknight1 authored Apr 12, 2023
1 parent ce06e47 commit 50f82e1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
7 changes: 3 additions & 4 deletions src/transformers/models/blip/modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ def get_text_features(
)

pooled_output = text_outputs[1]
text_features = self.text_projection(pooled_output)
text_features = self.blip.text_projection(pooled_output)

return text_features

Expand Down Expand Up @@ -1057,7 +1057,7 @@ def get_image_features(
vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict)

pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
image_features = self.blip.visual_projection(pooled_output)

return image_features

Expand Down Expand Up @@ -1238,7 +1238,7 @@ def generate(
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
two cats are laying on a couch
two cats sleeping on a couch
```
"""

Expand Down Expand Up @@ -1410,7 +1410,6 @@ def call(
>>> inputs["labels"] = labels
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> loss.backward()
>>> # inference
>>> text = "How many cats are in the picture?"
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/blip/modeling_tf_blip_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def call(
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/blip/test_modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def test_inference_image_captioning(self):
# Test output
self.assertEqual(
predictions[0].numpy().tolist(),
[30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102],
[30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102],
)

def test_inference_vqa(self):
Expand All @@ -810,6 +810,6 @@ def test_inference_itm(self):
out_itm = model(**inputs)
out = model(**inputs, use_itm_head=False, training=False)

expected_scores = tf.convert_to_tensor([[0.9798, 0.0202]])
expected_scores = tf.convert_to_tensor([[0.0029, 0.9971]])
self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5053]]), rtol=1e-3, atol=1e-3))
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5162]]), rtol=1e-3, atol=1e-3))

0 comments on commit 50f82e1

Please sign in to comment.