Skip to content

Commit

Permalink
RopeFusion transformation fix after PagedAttention transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jan 15, 2025
1 parent 22922a2 commit c8ef5ba
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
varsplit->set_output_size(2);
// Both Reshape and Unsqueeze should be supported
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
auto dim0 = ov::gen_pattern::Symbol("dim0");
auto dim1 = ov::gen_pattern::Symbol("dim1");
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
// repeat cos/sin table
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
Expand All @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
auto head_num = ov::gen_pattern::Symbol("head_num");
auto Unsqueeze_28998 =
makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});

auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
auto Unsqueeze_28999 =
makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
auto stack_1182 =
makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
{{"axis", -1}});

auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
Expand All @@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});

auto result = permute_Transpose_1213;
auto result = cat_Concat_1211 | permute_Transpose_1213;

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand All @@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
OutputVector new_args;
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);

config.output_trans0213 = true;
if (pattern_map.count(permute_Transpose_1213))
config.output_trans0213 = true;
config.is_interleaved = true;

// input is [B,L,H,S]
Expand All @@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
pattern_map.at(cat_Concat_1211).get_node_shared_ptr()},
new_node);
ov::replace_node(old_node, new_node);
// shapeof may be moved up from transpose to add,
Expand Down Expand Up @@ -557,9 +564,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_452_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims / 2, 2});
auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2});
reshape_Reshape_453 = makePattern<opset1::Reshape>(
{slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1});
reshape_Reshape_453 =
makePattern<opset1::Reshape>({slice_Slice_437 | var_split_1->output(0),
ListConstruct_452_Concat | const_target_shape_1 | const_target_shape_0});
}

auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
Expand Down Expand Up @@ -588,6 +597,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_379_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({1, -1, 1, ndims / 2, 2});
auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2});

auto slice_Slice_449 = makePattern<ov::opset8::Slice>({cos_sin_cache, {0}, seq_length, {1}, {0}});
Expand All @@ -596,7 +606,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
// [seq_length, 1, batch, half_rotary_dims, 2]
view_Reshape_460 =
makePattern<opset1::Reshape>({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0),
ListConstruct_379_Concat | const_target_shape_2},
ListConstruct_379_Concat | const_target_shape_0 | const_target_shape_2},
{{"special_zero", false}});
}

Expand All @@ -609,12 +619,17 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
auto sub_Subtract_469 = makePattern<opset1::Add>({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}});

auto y_even = makePattern<opset1::Unsqueeze>({sub_Subtract_469, -1});
auto const_y_even_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_even_reshape =
makePattern<opset1::Reshape>({sub_Subtract_469, const_y_even_reshape}, {{"special_zero", false}});
auto x_odd_cos = makePattern<opset1::Multiply>({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}});
auto x_even_sin = makePattern<opset1::Multiply>({x_even, sin_tab}, {{"auto_broadcast", "numpy"}});
auto add_Add_476 = makePattern<opset1::Add>({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}});
auto y_odd = makePattern<opset1::Unsqueeze>({add_Add_476, -1});
auto const_y_odd_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_odd_reshape = makePattern<opset1::Reshape>({add_Add_476, const_y_odd_reshape}, {{"special_zero", false}});

auto stack_481 = makePattern<opset1::Concat>({y_even, y_odd}, {{"axis", -1}});
auto stack_481 = makePattern<opset1::Concat>({y_even | y_even_reshape, y_odd | y_odd_reshape}, {{"axis", -1}});

auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0);
Expand All @@ -629,9 +644,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
{{"special_zero", true}});
} else {
// [length, batch, head_cnt, half_rotary_dims, 2]
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims});
const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims});
flatten_Reshape_501 = makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3},
{{"special_zero", true}});
flatten_Reshape_501 =
makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_0 | const_target_shape_3},
{{"special_zero", true}});
}
auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3);

Expand Down
Loading

0 comments on commit c8ef5ba

Please sign in to comment.