[BLIP-2] Improve conversion script (huggingface#24854)
* Improve conversion script

* Add int8 code example

* Update tip

* Fix code

* Fix code snippet

* Add nucleus sampling

* More improvements

* Address comments

* Address comments
NielsRogge authored and parambharat committed Sep 26, 2023
1 parent 7e0eecb commit 450025d
Showing 4 changed files with 67 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/blip_2/configuration_blip_2.py
@@ -90,7 +90,7 @@ def __init__(
image_size=224,
patch_size=14,
hidden_act="gelu",
- layer_norm_eps=0.00001,
+ layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=1e-10,
qkv_bias=True,
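For reference, 0.00001 is simply 1e-5 written out, so the hunk above tightens the default vision layer-norm epsilon from 1e-5 to 1e-6, presumably to match the original LAVIS vision encoder. A minimal sketch of how one might check the new default (it only assumes the public `Blip2VisionConfig` class):

```python
from transformers import Blip2VisionConfig

# Sketch only: a freshly constructed vision config should now default to the
# tighter epsilon introduced in the diff above.
config = Blip2VisionConfig()
assert config.layer_norm_eps == 1e-6
print(config.layer_norm_eps)
```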
62 changes: 30 additions & 32 deletions src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
@@ -24,7 +24,8 @@
import torch

# pip3 install salesforce-lavis
- # I'm actually installing a slightly modified version: pip3 install git+https://github.com/nielsrogge/LAVIS.git@fix_lavis
+ # I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
+ # to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image

@@ -37,6 +38,7 @@
BlipImageProcessor,
OPTConfig,
T5Config,
+ set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

@@ -145,11 +147,16 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_

name, type = model_name_to_original[model_name]

+ # note: this script is tested on 2 GPUs, as models are compared in float32,
+ # which requires quite some memory. Hence loading both on a
+ # separate device is the easiest to compare
+ hf_model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ lavis_device = "cuda:1" if torch.cuda.is_available() else "cpu"
+
# load original model
print("Loading original model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
original_model, vis_processors, _ = load_model_and_preprocess(
- name=name, model_type=type, is_eval=True, device=device
+ name=name, model_type=type, is_eval=True, device=lavis_device
)
original_model.eval()
print("Done!")
@@ -185,61 +192,53 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
assert unexpected_keys == ["qformer.embeddings.position_ids"]

image = load_demo_image()
- original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(device)
- input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(device)
+ original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
+ input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

# create processor
image_processor = BlipImageProcessor(
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
)
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
- pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)

# make sure processor creates exact same pixel values
- assert torch.allclose(pixel_values, original_pixel_values)
+ assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))

- original_model.to(device)
- hf_model.to(device)
+ original_model.to(lavis_device)
+ hf_model.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
- logits = hf_model(original_pixel_values, input_ids).logits
+ logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
- logits = hf_model(original_pixel_values, input_ids, labels=labels).logits
+ logits = hf_model(pixel_values, input_ids, labels=labels).logits

assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])

# assert values
- if model_name == "blip2-flan-t5-xl":
- expected_slice_logits = torch.tensor(
- [[-41.5850, -4.4440, -8.9922], [-47.4322, -5.9143, -1.7340]], device=device
- )
- assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
- elif model_name == "blip2-flan-t5-xl-coco":
- expected_slice_logits = torch.tensor(
- [[-57.0109, -9.8967, -12.6280], [-68.6578, -12.7191, -10.5065]], device=device
- )
- else:
- # cast to same type
- target_dtype = logits.dtype
- assert torch.allclose(original_logits.to(target_dtype), logits, atol=1e-2)
+ assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")

print("Generating a caption...")
prompt = ""
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

+ set_seed(42)
+
- original_outputs = original_model.generate({"image": original_pixel_values})
+ original_outputs = original_model.generate(
+ {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
+ )
outputs = hf_model.generate(
- original_pixel_values,
+ pixel_values,
input_ids,
- do_sample=False,
+ do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
@@ -248,10 +247,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
length_penalty=1.0,
temperature=1,
)
print("Original generation:", original_outputs)
prompt_length = input_ids.shape[1]
output_text = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)

if pytorch_dump_folder_path is not None:
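The generation check above now samples on both sides: the LAVIS model with `use_nucleus_sampling=True`, the Hugging Face model with `do_sample=True`, after `set_seed(42)` so the sampled outputs are comparable across runs. A standalone, hedged sketch of the Hugging Face half of that idea (checkpoint, prompt, and sampling knobs here are illustrative, not the script's exact values):

```python
import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor, set_seed

# Illustrative sketch only: seeded nucleus sampling with a BLIP-2 checkpoint.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to("cuda")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "Question: what object is in this image? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)

set_seed(42)  # fix the RNG so the sampled caption is reproducible
generated_ids = model.generate(**inputs, do_sample=True, top_p=0.9, max_new_tokens=30)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```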
39 changes: 35 additions & 4 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1556,6 +1556,12 @@ def forward(
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
+
+ <Tip>
+
+ Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16.
+
+ </Tip>
""",
BLIP_2_START_DOCSTRING,
)
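The tip added above notes that Flan-T5 based checkpoints should not be cast to float16 because they were pre-trained in bfloat16. A minimal sketch of the recommended alternative, loading such a checkpoint in bfloat16 (checkpoint name borrowed from the docstring examples below):

```python
import torch
from transformers import Blip2ForConditionalGeneration

# Sketch of the tip above: prefer bfloat16 over float16 for Flan-T5 based
# BLIP-2 checkpoints.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.bfloat16
)
```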
@@ -1687,15 +1693,40 @@ def forward(
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained(
... "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
... )
>>> model.to(device) # doctest: +IGNORE_RESULT
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
... ) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Question: how many cats are there? Answer:"
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
>>> generated_ids = model.generate(**inputs)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
two
```
+ Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
+ This greatly reduces the amount of memory used by the model while maintaining the same performance.
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
+ >>> import torch
+ >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ >>> model = Blip2ForConditionalGeneration.from_pretrained(
+ ... "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
+ ... ) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> prompt = "Question: how many cats are there? Answer:"
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
+ >>> generated_ids = model.generate(**inputs)
+ >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
1 change: 1 addition & 0 deletions utils/documentation_tests.txt
@@ -72,6 +72,7 @@ src/transformers/models/blip/image_processing_blip.py
src/transformers/models/blip/modeling_blip.py
src/transformers/models/blip/modeling_tf_blip.py
src/transformers/models/blip/processing_blip.py
+ src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/blip_2/processing_blip_2.py
src/transformers/models/bloom/configuration_bloom.py
src/transformers/models/bloom/tokenization_bloom_fast.py
