Skip to content

Commit

Permalink
[PIR+CINN]Support GetInput/OutputOpValues for Group (PaddlePaddle#58702)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored and SecretXV committed Nov 28, 2023
1 parent e6c6851 commit 002db3c
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,49 @@ struct Group {
return op_set;
}

// TODO(phlrain) : impliment GetInputNodeDatas GetOutputNodeDatas func
// std::unordered_set<::pir::Value> GetInputNodeDatas() { return {}; }
// std::unordered_set<::pir::Value> GetOutputNodeDatas() { return {}; }
std::unordered_set<::pir::Value> GetInputOpValues() {
std::unordered_set<::pir::Value> group_inputs;
auto ops_set = this->OpSet();
// count all op's input Value
for (auto op : this->CollectOps()) {
for (auto& value : op->operands_source()) {
if (!value || !value.type()) {
continue;
}

if (!ops_set.count(value.dyn_cast<::pir::OpResult>().owner())) {
// if the input value owner op is not in OpSet, it's the group's input
group_inputs.insert(value);
continue;
}

if (std::find(this->input_names.begin(),
this->input_names.end(),
CompatibleInfo::ValueName(value)) !=
this->input_names.end()) {
// if the input data in group's input_names
group_inputs.insert(value);
continue;
}
}
}

return group_inputs;
}
std::unordered_set<::pir::Value> GetOutputOpValues() {
std::unordered_set<::pir::Value> group_outputs;

for (auto op : this->output_ops) {
for (auto& result : op->results()) {
if (!result || result.type()) {
continue;
}

group_outputs.insert(result);
}
}
return group_outputs;
}

std::string GetFuncName() { return "fn_" + group_id + unique_id; }

Expand Down

0 comments on commit 002db3c

Please sign in to comment.