diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index c4ca668721a..a2815459bb0 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1577,14 +1577,16 @@ def forward(
             inputs_embeds = self.model.embed_tokens(input_ids)
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
-                image_mask = input_ids == self.config.image_token_id
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask.unsqueeze(-1), image_embeds)
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-                video_mask = input_ids == self.config.video_token_id
-                inputs_embeds = inputs_embeds.masked_scatter(video_mask.unsqueeze(-1), video_embeds)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
             if attention_mask is not None:
                 attention_mask = attention_mask.to(inputs_embeds.device)
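
The new code does three things before merging the vision-tower output into the text embeddings: it expands the placeholder-token mask from (batch, seq_len) to the full (batch, seq_len, hidden_size) embedding shape, casts the image/video embeddings to the embedding tensor's device and dtype (masked_scatter requires self and source to share a dtype, so scattering a float32 visual output into float16 embeddings would otherwise fail), and assigns the result back, since masked_scatter is out-of-place. Below is a minimal standalone sketch of the merge step, with made-up shapes and a placeholder image_token_id; it is an illustration of the pattern, not part of the patch:

```python
import torch

# Made-up sizes; the model uses (batch, seq_len, hidden_size) text embeddings
# and (num_placeholder_tokens, hidden_size) vision-tower outputs.
batch, seq_len, hidden = 1, 6, 4
image_token_id = 151655  # illustrative id; the real value comes from the config

input_ids = torch.tensor([[10, image_token_id, image_token_id, 11, 12, 13]])
inputs_embeds = torch.zeros(batch, seq_len, hidden, dtype=torch.float16)

# Stand-in for the vision-tower output: one row per image placeholder token,
# possibly produced in a different dtype (e.g. float32) or on another device.
image_embeds = torch.randn(2, hidden, dtype=torch.float32)

# Expand the (batch, seq_len) token mask to the full embedding shape so that
# masked_scatter fills every hidden dimension of each placeholder position.
image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

# Match device *and* dtype: masked_scatter raises on a self/source dtype
# mismatch, which is what the added .to(..., inputs_embeds.dtype) avoids.
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)

# masked_scatter is out-of-place (unlike masked_scatter_), so the result
# must be assigned back; source rows fill masked positions in order.
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

print(inputs_embeds[0, 1])  # first placeholder now holds image_embeds[0]
```

Casting the small vision output rather than the large embedding tensor keeps the device/dtype conversion cheap, and the out-of-place masked_scatter leaves the original embedding tensor unmodified.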