
Commit

add code for test_onnx
lihongjie committed Feb 19, 2025
1 parent f9ec39f commit 36d594b
Showing 2 changed files with 102 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@ Type Error: Type 'tensor(bfloat16)' of input parameter (hidden_states) of operat
transformers 4.49.0.dev0
```
You need to modify `modeling_qwen2_5_vl.py` and `modeling_qwen2_vl.py`; since `modeling_qwen2_5_vl.py` is generated from `modular_qwen2_5_vl.py`, `modular_qwen2_5_vl.py` is updated in sync.
These files must replace their counterparts in the transformers library before the ONNX export below can run; one way to do that is sketched after this paragraph.
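
A minimal sketch of one way to drop the patched files into the installed transformers package. The `models/qwen2_5_vl` and `models/qwen2_vl` subpaths follow the standard package layout, and the source files are assumed to sit at the repository root; adjust both to your environment:

```python
# Hedged sketch: overwrite the installed transformers sources with the patched
# copies from this repository (the paths are assumptions, not part of the commit).
import shutil
from pathlib import Path

import transformers

pkg = Path(transformers.__file__).parent
shutil.copy("modeling_qwen2_5_vl.py", pkg / "models" / "qwen2_5_vl" / "modeling_qwen2_5_vl.py")
shutil.copy("modular_qwen2_5_vl.py", pkg / "models" / "qwen2_5_vl" / "modular_qwen2_5_vl.py")
shutil.copy("modeling_qwen2_vl.py", pkg / "models" / "qwen2_vl" / "modeling_qwen2_vl.py")
```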

### Export process

102 changes: 101 additions & 1 deletion modeling_qwen2_5_vl.py
@@ -67,6 +67,7 @@
flash_attn_varlen_func = None

import os
import onnxruntime as ort

logger = logging.get_logger(__name__)

@@ -663,7 +664,8 @@ def __init__(self, config, *inputs, **kwargs) -> None:
        self.blocks = nn.ModuleList(
            [Qwen2_5_VLVisionBlockExport(config, config._attn_implementation) for _ in range(config.depth)]
        )

    def forward_export(self, hidden_states, rotary_pos_emb, attention_mask, attention_mask_window, window_index):
        for layer_num, blk in enumerate(self.blocks):
@@ -684,6 +686,104 @@ def forward_export(self, hidden_states, rotary_pos_emb, attention_mask, attentio

        return hidden_states

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.
        Returns:
            `torch.Tensor`: hidden_states.
        """
        def generate_attnmask(seq_length, cu_seqlens):
            # Build a block-diagonal boolean mask so that each position can only
            # attend within its own cu_seqlens segment.
            attention_mask = torch.zeros([1, seq_length, seq_length], device=cu_seqlens.device, dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
            return attention_mask

print("Qwen2_5_VisionTransformerPretrainedModel grid_thw",grid_thw)
print("Qwen2_5_VisionTransformerPretrainedModel hidden_states",hidden_states.shape) # [14308, 1176]
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
print("rotary_pos_emb.shape",rotary_pos_emb.shape) # [14308, 40]
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

print("hidden_states",hidden_states.shape) # [14308, 1280]
print("window_index.shape",window_index.shape)
print("window_index[0:33]",window_index[0:33])
# window_index[0:33] tensor([ 0, 1, 2, 3, 73, 74, 75, 76, 146, 147, 148, 149, 219, 220,
# 221, 222, 4, 5, 6, 7, 77, 78, 79, 80, 150, 151, 152, 153,
# 223, 224, 225, 226, 8])
seq_len, _ = hidden_states.size()

hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) # 14308//4, 4, 1280 # patch index 2x2 合 1
hidden_states = hidden_states[window_index, :, :] # 安装 window 内patch的顺序编排
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)


        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        print("cu_seqlens", cu_seqlens)

        # (Optional) dump the inputs for offline inspection or a standalone export script:
        # torch.save(hidden_states, "hidden_states.pth")
        # torch.save(rotary_pos_emb, "rotary_pos_emb.pth")
        # torch.save(cu_seqlens, "cu_seqlens.pth")
        # torch.save(cu_window_seqlens, "cu_window_seqlens.pth")
        # torch.save(window_index, "window_index.pth")

        # Original eager block loop, kept for reference; the ONNX session below replaces it.
        # for layer_num, blk in enumerate(self.blocks):
        #     if layer_num in self.fullatt_block_indexes:
        #         cu_seqlens_now = cu_seqlens
        #     else:
        #         cu_seqlens_now = cu_window_seqlens
        #     if self.gradient_checkpointing and self.training:
        #         hidden_states = self._gradient_checkpointing_func(
        #             blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
        #         )
        #     else:
        #         hidden_states = blk(
        #             hidden_states,
        #             cu_seqlens=cu_seqlens_now,
        #             rotary_pos_emb=rotary_pos_emb,
        #         )

        # hidden_states = self.merger(hidden_states)
        # reverse_indices = torch.argsort(window_index)
        # hidden_states = hidden_states[reverse_indices, :]

        # return hidden_states

print("test Vision Encoder Onnx -------------------")
session = ort.InferenceSession("Qwen2.5-VL-3B-Instruct_vision.onnx", providers=["CPUExecutionProvider"])
attention_mask = generate_attnmask(hidden_states.shape[0], cu_seqlens)
attention_mask_window = generate_attnmask(hidden_states.shape[0], cu_window_seqlens)

inputs = {"hidden_states": hidden_states.cpu().numpy(),
"rotary_pos_emb":rotary_pos_emb.cpu().numpy(),
"attention_mask":attention_mask.cpu().numpy(),
"attention_mask_window":attention_mask_window.cpu().numpy(),
"window_index":window_index.cpu().numpy()}
hidden_states = session.run(["hidden_states_out"], inputs)[0]
hidden_states = torch.from_numpy(hidden_states).to(grid_thw.device)
return hidden_states


class Qwen2_5_VLRotaryEmbedding(nn.Module):
    def __init__(self, config: Qwen2_5_VLConfig, device=None):
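
For reference, a plausible shape for the export call that produces `Qwen2.5-VL-3B-Instruct_vision.onnx` with exactly the input and output names that `forward` feeds to `ort.InferenceSession` above. This is a hedged sketch under stated assumptions, not the author's actual export script: `vision` is assumed to be a loaded `Qwen2_5_VisionTransformerPretrainedModel`, and the example tensors are assumed to come from dumps like the commented-out `torch.save` calls in `forward`.

```python
# Hypothetical export sketch (assumptions: a loaded `vision` module and .pth
# dumps of example inputs; the attention-mask dumps are not in the commit).
import torch

vision = ...  # assumption: the loaded Qwen2_5_VisionTransformerPretrainedModel vision tower
hidden_states = torch.load("hidden_states.pth")
rotary_pos_emb = torch.load("rotary_pos_emb.pth")
window_index = torch.load("window_index.pth")
attention_mask = torch.load("attention_mask.pth")                # assumed dump
attention_mask_window = torch.load("attention_mask_window.pth")  # assumed dump

vision.forward = vision.forward_export  # route export through the ONNX-friendly path
torch.onnx.export(
    vision,
    (hidden_states, rotary_pos_emb, attention_mask, attention_mask_window, window_index),
    "Qwen2.5-VL-3B-Instruct_vision.onnx",
    input_names=["hidden_states", "rotary_pos_emb", "attention_mask", "attention_mask_window", "window_index"],
    output_names=["hidden_states_out"],
    opset_version=17,  # assumed; any opset covering the used ops works
)
```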
