diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 95d8ce1069a556..6b8ccc07929214 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -426,6 +426,8 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) { auto total_size_q = ov::gen_pattern::Symbol("total_size_q"); auto total_size_k = ov::gen_pattern::Symbol("total_size_k"); auto total_size_v = ov::gen_pattern::Symbol("total_size_v"); + auto batch = ov::gen_pattern::Symbol("batch"); + auto seq_len = ov::gen_pattern::Symbol("seq_len"); auto qkv_proj = makePattern({qkv_linear, -1, {total_size_q, total_size_k, total_size_v}}); qkv_proj->set_output_size(3); @@ -441,11 +443,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) { // rotate half auto ListConstruct_452_Concat = makePattern({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}}); - auto const_target_shape_1 = makeConst({0, 0, head_cnt, ndims / 2, 2}); + auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2}); auto ListConstruct_379_Concat = makePattern({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}}); - auto const_target_shape_2 = makeConst({0, 0, 1, ndims / 2, 2}); + auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2}); auto reshape_Reshape_453 = makePattern( {slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1}); @@ -480,7 +482,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) { auto ShapeOf_135133 = makePattern({stack_481}); auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0); auto flatten_Concat_500 = makePattern({flatten_Slice_497, {-1}}, {{"axis", 0}}); - auto const_target_shape_3 = makeConst({0, 0, head_cnt, ndims}); + auto const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims}); // [length, batch, head_cnt, half_rotary_dims, 2] auto flatten_Reshape_501 = makePattern({stack_481, flatten_Concat_500 | const_target_shape_3}, {{"special_zero", true}}); @@ -566,8 +568,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto neg_Multiply = makePattern({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}}); auto ScatterUpdate_463814 = makePattern({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}}); - auto slice_Slice_446 = - GenSlice2(rotary_emb_cos, ScatterUpdate_463814 | Gather_377635 | neg_Multiply, {INT_MAX}, {1}, 1, true); // tensor_array + auto slice_Slice_446 = GenSlice2(rotary_emb_cos, + ScatterUpdate_463814 | Gather_377635 | neg_Multiply, + {INT_MAX}, + {1}, + 1, + true); // tensor_array auto mul_Multiply_552 = makePattern({slice_Slice_543, slice_Slice_446}, {{"auto_broadcast", "numpy"}}); // tensor_array @@ -606,8 +612,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { makePattern({ListUnpack_586_Split->output(0), -2}); // tensor_array auto cat_Concat_593 = makePattern({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze}, {{"axis", -1}}); // tensor_array - auto slice_Slice_470 = - GenSlice2(rotary_emb_sin, ScatterUpdate_463814 | Gather_377635 | neg_Multiply, {INT_MAX}, {1}, 1, true); // tensor_array + auto slice_Slice_470 = GenSlice2(rotary_emb_sin, + ScatterUpdate_463814 | Gather_377635 | neg_Multiply, + {INT_MAX}, + {1}, + 1, + true); // tensor_array auto mul_Multiply_594 = makePattern({cat_Concat_593, slice_Slice_470}, {{"auto_broadcast", "numpy"}}); // tensor_array diff --git a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp index dab3b6365266cd..46ea730ac32a8c 100644 --- a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp +++ b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp @@ -846,8 +846,7 @@ std::shared_ptr RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, in makeOP({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}}); auto select_Gather_381 = makeOP({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}}); auto slice_Unsqueeze_367 = makeOP({size_Gather_348, 0}); - auto slice_Slice_369 = - makeOP({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}}); + auto slice_Slice_369 = makeOP({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}, {0}}); auto size_ShapeOf_374 = makeOP({reshape_Reshape_373}, {{"output_type", "i32"}}); auto size_Gather_376 = makeOP({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}}); auto ListConstruct_379_Concat = @@ -872,7 +871,7 @@ std::shared_ptr RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, in auto Unsqueeze_62717 = makeOP({add_Add_396, -1}); auto stack_401 = makeOP({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}}); auto flatten_ShapeOf_402 = makeOP({stack_401}, {{"output_type", "i32"}}); - auto flatten_Slice_417 = makeOP({flatten_ShapeOf_402, {0}, {3}, {1}}); + auto flatten_Slice_417 = makeOP({flatten_ShapeOf_402, {0}, {3}, {1}, {0}}); auto flatten_Concat_420 = makeOP({flatten_Slice_417, {-1}}, {{"axis", 0}}); auto flatten_Reshape_421 = makeOP({stack_401, flatten_Concat_420}, {{"special_zero", true}}); auto slice_Slice_363 = makeOP({view_Reshape, slice_Unsqueeze_112, {INT_MAX}, {1}, {3}});