From 4ba531c43f884cea2575f1f2c1b287173a28fb1e Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 17 Sep 2024 08:31:24 +0800
Subject: [PATCH] Fix: Qwen2-VL training on video datasets (#33307)

* fix video finetuning

* Update modeling_qwen2_vl.py

* Update modeling_qwen2_vl.py

* fix

---
 .../models/qwen2_vl/modeling_qwen2_vl.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 5838c460e7c..4e4e04198c0 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1680,16 +1680,18 @@ def forward(
             inputs_embeds = self.model.embed_tokens(input_ids)
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
-                image_mask = input_ids == self.config.image_token_id
-                if self.training:
-                    inputs_embeds = inputs_embeds.clone()
-                inputs_embeds[image_mask] = image_embeds
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-                video_mask = input_ids == self.config.video_token_id
-                inputs_embeds[video_mask] = video_embeds
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
             if attention_mask is not None:
                 attention_mask = attention_mask.to(inputs_embeds.device)