From 9895670e95d7fef9fd53e8d0a638970ab03221f1 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Mon, 26 Jun 2023 18:36:27 +0200
Subject: [PATCH] [`InstructBlip`] Add accelerate support for instructblip
 (#24488)

* add accelerate support for instructblip

* add `_keep_in_fp32_modules`

* dynamically adapt `_no_split_modules`

* better fix

* same logic for `_keep_in_fp32_modules`
---
 .../models/instructblip/modeling_instructblip.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index 029c59aa7cdd..acc0df630991 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
         r"language_model.decoder.embed_tokens.weight",
         r"language_model.lm_head.weight",
     ]
+    _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
+    _keep_in_fp32_modules = []
 
     # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
     def _init_weights(self, module):
@@ -1264,11 +1266,18 @@ def __init__(self, config: InstructBlipConfig):
         self.qformer = InstructBlipQFormerModel(config.qformer_config)
 
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+
         if config.use_decoder_only_language_model:
             language_model = AutoModelForCausalLM.from_config(config.text_config)
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
 
+        if language_model._no_split_modules is not None:
+            self._no_split_modules.extend(language_model._no_split_modules)
+
+        if language_model._keep_in_fp32_modules is not None:
+            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
+
         self.language_model = language_model
 
         # Initialize weights and apply final processing
@@ -1422,7 +1431,7 @@ def forward(
 
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+        attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
 
         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(
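
Note (illustration, not part of the patch): the sketch below shows the kind of usage this change is meant to unlock. With `_no_split_modules` and `_keep_in_fp32_modules` populated (and extended from the wrapped language model), `from_pretrained` can let accelerate place the model across devices via `device_map="auto"`. The checkpoint id is an assumption used only for illustration.

# Minimal sketch, assuming an InstructBLIP checkpoint is available on the Hub.
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

model_id = "Salesforce/instructblip-vicuna-7b"  # assumed checkpoint id, not taken from the patch
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # accelerate dispatch, relies on the `_no_split_modules` added in this patch
)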