[BLIP-2] Improve conversion script #24854

Merged 9 commits on Sep 14, 2023
Changes from 7 commits
2 changes: 1 addition & 1 deletion src/transformers/models/blip_2/configuration_blip_2.py
@@ -91,7 +91,7 @@ def __init__(
image_size=224,
patch_size=14,
hidden_act="gelu",
layer_norm_eps=0.00001,
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=1e-10,
qkv_bias=True,
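For context, the only change in this file is the default `layer_norm_eps` of the vision config (1e-5 → 1e-6). A minimal sketch to check the new default, assuming only that `Blip2VisionConfig` is exposed by `transformers`:

```python
from transformers import Blip2VisionConfig

# After this PR the vision config defaults to layer_norm_eps=1e-6 (previously 1e-5).
config = Blip2VisionConfig()
print(config.layer_norm_eps)  # expected: 1e-06
```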
src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py

@@ -24,7 +24,8 @@
import torch

# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install git+https://github.com/nielsrogge/LAVIS.git@fix_lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image

@@ -43,6 +44,8 @@

def load_demo_image():
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
url = "https://user-images.githubusercontent.com/50018861/252267123-a49ec5be-d964-4760-9ef5-3f006a353720.png"
url = "https://user-images.githubusercontent.com/50018861/255126794-269d9dea-2620-454c-9643-63b1155a697e.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

return image
@@ -147,9 +150,11 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_

# load original model
print("Loading original model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
hf_model_device = "cuda:1" if torch.cuda.is_available() else "cpu"
lavis_device = "cuda:2" if torch.cuda.is_available() else "cpu"

original_model, vis_processors, _ = load_model_and_preprocess(
name=name, model_type=type, is_eval=True, device=device
name=name, model_type=type, is_eval=True, device=lavis_device
)
original_model.eval()
print("Done!")
@@ -185,61 +190,56 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
assert unexpected_keys == ["qformer.embeddings.position_ids"]

image = load_demo_image()
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(device)
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(device)
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

# create processor
image_processor = BlipImageProcessor(
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
)
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)

# make sure processor creates exact same pixel values
assert torch.allclose(pixel_values, original_pixel_values)
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))

original_model.to(device)
hf_model.to(device)
original_model.to(lavis_device)
hf_model.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(original_pixel_values, input_ids).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(original_pixel_values, input_ids, labels=labels).logits
logits = hf_model(pixel_values, input_ids, labels=labels).logits

assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])

# assert values
if model_name == "blip2-flan-t5-xl":
expected_slice_logits = torch.tensor(
[[-41.5850, -4.4440, -8.9922], [-47.4322, -5.9143, -1.7340]], device=device
)
assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
elif model_name == "blip2-flan-t5-xl-coco":
expected_slice_logits = torch.tensor(
[[-57.0109, -9.8967, -12.6280], [-68.6578, -12.7191, -10.5065]], device=device
)
else:
# cast to same type
target_dtype = logits.dtype
assert torch.allclose(original_logits.to(target_dtype), logits, atol=1e-2)
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")

print("Generating a caption...")
prompt = ""
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
prompt = "Question: what object is in this image? Answer:"
# prompt = "Question: what is the structure and geometry of this chair?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

from transformers import set_seed

set_seed(42)

original_outputs = original_model.generate({"image": original_pixel_values})
original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
)
outputs = hf_model.generate(
original_pixel_values,
pixel_values,
input_ids,
do_sample=False,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
@@ -248,10 +248,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
length_penalty=1.0,
temperature=1,
)
print("Original generation:", original_outputs)
prompt_length = input_ids.shape[1]
output_text = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)

if pytorch_dump_folder_path is not None:
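As a usage note, a hedged sketch of driving the converter directly: the import path assumes the script's usual location and name in the repo, the model name is one that the script's own checks reference (`blip2-flan-t5-xl`), and the dump folder is an arbitrary example. The LAVIS fork noted at the top of the script must be installed so the float32 comparison can run.

```python
# Hypothetical direct call to the converter; arguments follow the signature shown in
# the hunk headers above. Requires the modified LAVIS install mentioned in the script.
from transformers.models.blip_2.convert_blip_2_original_to_pytorch import (
    convert_blip2_checkpoint,
)

convert_blip2_checkpoint(
    model_name="blip2-flan-t5-xl",                  # one of the names the script checks
    pytorch_dump_folder_path="./blip2-flan-t5-xl",  # hypothetical output directory
)
```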
32 changes: 32 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1556,6 +1556,12 @@ def forward(

One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.

<Tip>

Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16.

</Tip>
""",
BLIP_2_START_DOCSTRING,
)
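Following the tip added above, a minimal loading sketch for a Flan-T5-based checkpoint (the same `Salesforce/blip2-flan-t5-xl` checkpoint used in the docstring examples below), kept in its pre-training dtype:

```python
import torch
from transformers import Blip2ForConditionalGeneration

# Flan-T5 BLIP-2 checkpoints were pre-trained in bfloat16; load them with
# torch_dtype=torch.bfloat16 instead of casting to float16.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.bfloat16
)
```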
@@ -1697,6 +1703,32 @@ def forward(
>>> prompt = "Question: how many cats are there? Answer:"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

>>> generated_ids = model.generate(**inputs)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
two
```

Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
This greatly reduces the amount of memory used by the model while maintaining the same performance.

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
>>> import torch

>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
>>> model = Blip2ForConditionalGeneration.from_pretrained(
... "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map="auto", torch_dtype=torch.bfloat16
... ) # doctest: +IGNORE_RESULT

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> prompt = "Question: how many cats are there? Answer:"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)

>>> generated_ids = model.generate(**inputs)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)