Skip to content

Commit

Permalink
fix SDPAToPA, add unit test for qwen7bchat model
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Dec 20, 2024
1 parent 96456f6 commit db7642d
Show file tree
Hide file tree
Showing 6 changed files with 566 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,11 @@ class AttrSetter : public ov::AttributeVisitor {
a->set(m_attr_map[name].as_vector<int64_t>());
} else if (auto a = ov::as_type<ov::AttributeAdapter<ov::element::TypeVector>>(&adapter)) {
a->set(m_attr_map[name].as_T_vector<ov::element::Type>());
} else if (auto a = dynamic_cast<ov::AttributeAdapter<std::shared_ptr<ov::op::util::Variable>>*>(&adapter)) {
ov::op::util::VariableInfo var_info;
var_info.variable_id = m_attr_map[name].as_string();
auto variable = std::make_shared<ov::op::util::Variable>(var_info);
a->set(variable);
} else {
OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name);
}
Expand Down Expand Up @@ -896,6 +901,7 @@ struct PatternNode {
// scalar constant (treated as wildcard for single-element-constant with any rank)
// Each overload below wraps the given scalar in an ov::op::v0::Constant built with an
// empty Shape({}); the element type is derived from the C++ argument type via element::from<T>().
// The constructors are non-explicit, so a bare literal converts to a PatternNode implicitly.
// int literal -> constant with element type element::from<int>().
PatternNode(int v) : node(std::make_shared<ov::op::v0::Constant>(element::from<int>(), Shape({}), v)) {}
// float literal -> constant with element type element::from<float>().
PatternNode(float v) : node(std::make_shared<ov::op::v0::Constant>(element::from<float>(), Shape({}), v)) {}
// long long literal (e.g. 1LL) -> constant with element type element::from<int64_t>().
PatternNode(long long v) : node(std::make_shared<ov::op::v0::Constant>(element::from<int64_t>(), Shape({}), v)) {}

PatternNode(std::initializer_list<int> v, values_info vi = nullptr) {
node = ConstVector(std::vector<int>(v), vi);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "openvino/core/model.hpp"
#include "openvino/core/node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/utils/utils.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
parameters_to_remove.push_back(param);
}

pa_transpose->set_friendly_name(sdpa_node->get_friendly_name());
replace_node(m.get_match_root(), pa_transpose);
return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -112,22 +110,21 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(

auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, any_input()});
auto p_opt_reshape_2 = optional<v1::Reshape>({p_unsqueeze, any_input()});
auto p_opt_convert_2 = optional<v0::Convert>(p_opt_reshape_2);
auto p_kv_shape_current = wrap_type<v3::ShapeOf>({p_opt_convert_2});
auto p_opt_reshape_1 = optional<v1::Reshape>({p_unsqueeze, any_input()});
auto p_opt_convert_1 = optional<v0::Convert>(p_opt_reshape_1);
auto p_kv_shape_current = wrap_type<v3::ShapeOf>({p_opt_convert_1});
auto p_seq_current = wrap_type<v8::Gather>({p_kv_shape_current, any_input(), any_input()});
auto p_opt_convert_2 = optional<v0::Convert>(p_seq_current);

auto p_max_context_len = wrap_type<v0::Parameter>();
auto p_prev_max_seq_len = wrap_type<v1::Subtract>({max_context_len, any_input()});
auto p_opt_convert_1 = optional<v0::Convert>(p_prev_max_seq_len);
auto opt_reshape_1 = optional<v1::Reshape>({p_opt_convert_1, p_seq_current});

auto p_total_seq = wrap_type<v1::Add>({p_seq_current, opt_reshape_1});
auto p_prev_max_seq_len = wrap_type<v1::Subtract>({p_max_context_len, any_input()});
auto p_opt_convert_3 = optional<v0::Convert>(p_prev_max_seq_len);
auto p_opt_reshape_2 = optional<v1::Reshape>({p_opt_convert_3, any_input()});
auto p_total_seq = wrap_type<v1::Add>({p_opt_convert_2, p_opt_reshape_2});

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto total_seq = pattern_map.at(p_total_seq).get_node_shared_ptr();

std::shared_ptr<Node> replacement = max_context_len;

auto target_type = total_seq->get_output_element_type(0);
Expand Down
Loading

0 comments on commit db7642d

Please sign in to comment.