diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 049537a3e52835..7b0913525c2545 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -115,9 +115,49 @@ struct Group { return op_set; } - // TODO(phlrain) : impliment GetInputNodeDatas GetOutputNodeDatas func - // std::unordered_set<::pir::Value> GetInputNodeDatas() { return {}; } - // std::unordered_set<::pir::Value> GetOutputNodeDatas() { return {}; } + std::unordered_set<::pir::Value> GetInputOpValues() { + std::unordered_set<::pir::Value> group_inputs; + auto ops_set = this->OpSet(); + // count all op's input Value + for (auto op : this->CollectOps()) { + for (auto& value : op->operands_source()) { + if (!value || !value.type()) { + continue; + } + + if (!ops_set.count(value.dyn_cast<::pir::OpResult>().owner())) { + // if the input value owner op is not in OpSet, it's the group's input + group_inputs.insert(value); + continue; + } + + if (std::find(this->input_names.begin(), + this->input_names.end(), + CompatibleInfo::ValueName(value)) != + this->input_names.end()) { + // if the input data in group's input_names + group_inputs.insert(value); + continue; + } + } + } + + return group_inputs; + } + std::unordered_set<::pir::Value> GetOutputOpValues() { + std::unordered_set<::pir::Value> group_outputs; + + for (auto op : this->output_ops) { + for (auto& result : op->results()) { + if (!result || result.type()) { + continue; + } + + group_outputs.insert(result); + } + } + return group_outputs; + } std::string GetFuncName() { return "fn_" + group_id + unique_id; }