[InstructBlip] Add accelerate support for instructblip (#24488)
* add accelerate support for instructblip

* add `_keep_in_fp32_modules`

* dynamically adapt `_no_split_modules`

* better fix

* same logic for `_keep_in_fp32_modules`
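
In practice, this commit lets InstructBLIP checkpoints be loaded through accelerate's device-map machinery. A minimal usage sketch (the checkpoint name is only an example, not part of the commit):

```python
# Minimal sketch: load an InstructBLIP checkpoint in fp16 and let accelerate
# place the submodules across the available devices.
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

checkpoint = "Salesforce/instructblip-flan-t5-xl"  # example checkpoint

processor = InstructBlipProcessor.from_pretrained(checkpoint)
model = InstructBlipForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,  # modules in _keep_in_fp32_modules are kept in float32
    device_map="auto",          # accelerate builds the placement, honouring _no_split_modules
)
```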
younesbelkada authored Jun 26, 2023
1 parent 5757923 commit 9895670
Showing 1 changed file with 10 additions and 1 deletion.
src/transformers/models/instructblip/modeling_instructblip.py
@@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
         r"language_model.decoder.embed_tokens.weight",
         r"language_model.lm_head.weight",
     ]
+    _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
+    _keep_in_fp32_modules = []
 
     # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
     def _init_weights(self, module):
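
For context, `_no_split_modules` lists the module classes accelerate must keep on a single device when it infers a device map, and `_keep_in_fp32_modules` names submodules whose weights should stay in float32 even under a half-precision load. A rough sketch of how the no-split list reaches accelerate (the checkpoint name is an illustrative assumption):

```python
# Sketch: instantiate the model on the meta device and hand its (merged) no-split
# list to accelerate when computing a device map. Assumes accelerate is installed.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import InstructBlipConfig, InstructBlipForConditionalGeneration

config = InstructBlipConfig.from_pretrained("Salesforce/instructblip-flan-t5-xl")  # example checkpoint
with init_empty_weights():
    model = InstructBlipForConditionalGeneration(config)  # no real weight tensors allocated

device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,  # includes the language model's own entries
    dtype=torch.float16,
)
print(device_map)
```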
@@ -1264,11 +1266,18 @@ def __init__(self, config: InstructBlipConfig):
         self.qformer = InstructBlipQFormerModel(config.qformer_config)
 
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+
         if config.use_decoder_only_language_model:
             language_model = AutoModelForCausalLM.from_config(config.text_config)
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
 
+        if language_model._no_split_modules is not None:
+            self._no_split_modules.extend(language_model._no_split_modules)
+
+        if language_model._keep_in_fp32_modules is not None:
+            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
+
         self.language_model = language_model
 
         # Initialize weights and apply final processing
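
The two `extend` calls are what make the composite model inherit its backbone's constraints: a Flan-T5 language model, for instance, contributes its own no-split blocks and keeps its `wo` projections in fp32. An illustration of the fp32 side (assuming a Flan-T5 based checkpoint; the module path and expected output are assumptions, not taken from the diff):

```python
# Illustration: after an fp16 load, weights matched by _keep_in_fp32_modules
# (inherited here from the T5 backbone) should remain in float32.
import torch
from transformers import InstructBlipForConditionalGeneration

model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl", torch_dtype=torch.float16  # example checkpoint
)
print(model._keep_in_fp32_modules)  # e.g. ['wo'], contributed by the T5 language model
ff = model.language_model.encoder.block[0].layer[-1].DenseReluDense
print(ff.wo.weight.dtype)           # expected: torch.float32, while most weights are float16
```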
@@ -1422,7 +1431,7 @@ def forward(
 
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+        attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
 
         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(
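
The `.to(attention_mask.device)` cast matters once the model is dispatched with a device map: the mask built for the query tokens can live on a different device than the text `attention_mask`, and `torch.cat` refuses to mix devices. A standalone illustration of that failure mode (assumes one CUDA device; tensor shapes are arbitrary):

```python
# Standalone illustration of the device mismatch the cast avoids.
import torch

query_mask = torch.ones(1, 32, dtype=torch.long, device="cuda")  # stands in for language_model_attention_mask
text_mask = torch.ones(1, 8, dtype=torch.long, device="cpu")     # stands in for attention_mask

# torch.cat([query_mask, text_mask], dim=1) would raise a device-mismatch RuntimeError;
# moving one operand onto the other's device first makes the concatenation valid.
merged = torch.cat([query_mask.to(text_mask.device), text_mask], dim=1)
print(merged.shape, merged.device)  # torch.Size([1, 40]) cpu
```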
