Skip to content

Commit

Permalink
codestyle; fix chatglm pattern and test
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jul 10, 2024
1 parent 84cbff8 commit 8b95ed6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
auto total_size_q = ov::gen_pattern::Symbol("total_size_q");
auto total_size_k = ov::gen_pattern::Symbol("total_size_k");
auto total_size_v = ov::gen_pattern::Symbol("total_size_v");
auto batch = ov::gen_pattern::Symbol("batch");
auto seq_len = ov::gen_pattern::Symbol("seq_len");

auto qkv_proj = makePattern<opset1::VariadicSplit>({qkv_linear, -1, {total_size_q, total_size_k, total_size_v}});
qkv_proj->set_output_size(3);
Expand All @@ -441,11 +443,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
// rotate half
auto ListConstruct_452_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_1 = makeConst({0, 0, head_cnt, ndims / 2, 2});
auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2});

auto ListConstruct_379_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_2 = makeConst({0, 0, 1, ndims / 2, 2});
auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2});

auto reshape_Reshape_453 = makePattern<opset1::Reshape>(
{slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1});
Expand Down Expand Up @@ -480,7 +482,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0);
auto flatten_Concat_500 = makePattern<opset1::Concat>({flatten_Slice_497, {-1}}, {{"axis", 0}});
auto const_target_shape_3 = makeConst({0, 0, head_cnt, ndims});
auto const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims});
// [length, batch, head_cnt, half_rotary_dims, 2]
auto flatten_Reshape_501 =
makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3}, {{"special_zero", true}});
Expand Down Expand Up @@ -566,8 +568,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto neg_Multiply = makePattern<opset1::Multiply>({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}});

auto ScatterUpdate_463814 = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}});
auto slice_Slice_446 =
GenSlice2(rotary_emb_cos, ScatterUpdate_463814 | Gather_377635 | neg_Multiply, {INT_MAX}, {1}, 1, true); // tensor_array<f32[1,..4096,1,128]>
auto slice_Slice_446 = GenSlice2(rotary_emb_cos,
ScatterUpdate_463814 | Gather_377635 | neg_Multiply,
{INT_MAX},
{1},
1,
true); // tensor_array<f32[1,..4096,1,128]>
auto mul_Multiply_552 =
makePattern<opset1::Multiply>({slice_Slice_543, slice_Slice_446},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
Expand Down Expand Up @@ -606,8 +612,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
makePattern<opset1::Squeeze>({ListUnpack_586_Split->output(0), -2}); // tensor_array<f32[?,?,32,64]>
auto cat_Concat_593 = makePattern<opset1::Concat>({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze},
{{"axis", -1}}); // tensor_array<f32[?,?,32,128]>
auto slice_Slice_470 =
GenSlice2(rotary_emb_sin, ScatterUpdate_463814 | Gather_377635 | neg_Multiply, {INT_MAX}, {1}, 1, true); // tensor_array<f32[1,..4096,1,128]>
auto slice_Slice_470 = GenSlice2(rotary_emb_sin,
ScatterUpdate_463814 | Gather_377635 | neg_Multiply,
{INT_MAX},
{1},
1,
true); // tensor_array<f32[1,..4096,1,128]>
auto mul_Multiply_594 =
makePattern<opset1::Multiply>({cat_Concat_593, slice_Slice_470},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,7 @@ std::shared_ptr<ov::Model> RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, in
makeOP<opset1::Reshape>({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
auto select_Gather_381 = makeOP<opset8::Gather>({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
auto slice_Unsqueeze_367 = makeOP<opset1::Unsqueeze>({size_Gather_348, 0});
auto slice_Slice_369 =
makeOP<opset8::Slice>({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}});
auto slice_Slice_369 = makeOP<opset8::Slice>({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}, {0}});
auto size_ShapeOf_374 = makeOP<opset3::ShapeOf>({reshape_Reshape_373}, {{"output_type", "i32"}});
auto size_Gather_376 = makeOP<opset8::Gather>({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}});
auto ListConstruct_379_Concat =
Expand All @@ -872,7 +871,7 @@ std::shared_ptr<ov::Model> RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, in
auto Unsqueeze_62717 = makeOP<opset1::Unsqueeze>({add_Add_396, -1});
auto stack_401 = makeOP<opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
auto flatten_ShapeOf_402 = makeOP<opset3::ShapeOf>({stack_401}, {{"output_type", "i32"}});
auto flatten_Slice_417 = makeOP<opset8::Slice>({flatten_ShapeOf_402, {0}, {3}, {1}});
auto flatten_Slice_417 = makeOP<opset8::Slice>({flatten_ShapeOf_402, {0}, {3}, {1}, {0}});
auto flatten_Concat_420 = makeOP<opset1::Concat>({flatten_Slice_417, {-1}}, {{"axis", 0}});
auto flatten_Reshape_421 = makeOP<opset1::Reshape>({stack_401, flatten_Concat_420}, {{"special_zero", true}});
auto slice_Slice_363 = makeOP<opset8::Slice>({view_Reshape, slice_Unsqueeze_112, {INT_MAX}, {1}, {3}});
Expand Down

0 comments on commit 8b95ed6

Please sign in to comment.