Skip to content

Commit

Permalink
[PIR] Fix bug of BuildScope for IfOp (PaddlePaddle#58109)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
zhangbo9674 authored and wentaoyu committed Oct 24, 2023
1 parent c50df4e commit c2b20dc
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ void InstructionBase::InitInputsOutputsIds(

std::string InstructionBase::DebugStringEx(
const paddle::framework::Scope* scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name)
const {
ValueExecutionInfo* value_exe_info) const {
std::stringstream ss;
ss << "Op(" << Name() << "), inputs:{";

Expand All @@ -268,7 +267,7 @@ std::string InstructionBase::DebugStringEx(
auto& input = *it;
bool is_no_need_buffer_var = (!no_need_buffer_vars.empty() &&
no_need_buffer_vars.count(input.first) > 0);
auto var_name = value_2_var_name.at(input.first);
auto var_name = value_exe_info->GetVarName(input.first);
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
Expand Down Expand Up @@ -296,7 +295,7 @@ std::string InstructionBase::DebugStringEx(
ss << "}, outputs:{";
for (auto it = Outputs().begin(); it != Outputs().end();) {
auto& output = *it;
auto var_name = value_2_var_name.at(output.first);
auto var_name = value_exe_info->GetVarName(output.first);
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,8 @@ class InstructionBase {
const ValueExecutionInfo& value_exec_info);

// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(
const paddle::framework::Scope* scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name)
const;
virtual std::string DebugStringEx(const paddle::framework::Scope* scope,
ValueExecutionInfo* value_exe_info) const;

protected:
size_t id_;
Expand Down
24 changes: 15 additions & 9 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ Scope* NewIRInterpreter::InnerScope() const {
}

std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const {
return value_exe_info_->GetValue2VarName().at(value);
return value_exe_info_->GetVarName(value);
}

void NewIRInterpreter::UpdateSyncOpNum() {
Expand Down Expand Up @@ -627,7 +627,7 @@ std::string NewIRInterpreter::DebugValueInfo() {
PADDLE_ENFORCE((bool)kv.first,
platform::errors::PreconditionNotMet(
"vlaue(%s) should not be nullptr", kv.second));
PADDLE_ENFORCE(value_exe_info_->GetVarName2Id().count(kv.second) > 0,
PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second),
platform::errors::PreconditionNotMet(
"var(%s) should exist in var_name_2_id_", kv.second));
auto* var = InnerScope()->FindVar(kv.second);
Expand All @@ -636,8 +636,7 @@ std::string NewIRInterpreter::DebugValueInfo() {
platform::errors::PreconditionNotMet(
"var(%s) should exist in scope (%p)", kv.second, InnerScope()));
os << kv.first.impl() << " -> " << kv.second << " -> "
<< value_exe_info_->GetVarName2Id().at(kv.second) << " -> " << var
<< "\n";
<< value_exe_info_->GetVarId(kv.first) << " -> " << var << "\n";
}
return os.str();
}
Expand Down Expand Up @@ -857,6 +856,7 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
}

void NewIRInterpreter::CalculateLastLiveOps() {
VLOG(4) << "NewIRInterpreter(): " << this << " start CalculateLastLiveOps";
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) {
InstructionBase* instr = vec_instruction_base_[op_idx].get();
Expand All @@ -882,11 +882,16 @@ void NewIRInterpreter::CalculateLastLiveOps() {
gc_check_vars.insert(var_id);
}
}
VLOG(4) << "get gc check vars for: " << instr->Name();

for (auto var_id : gc_check_vars) {
Scope* inner_scope = InnerScope();
paddle::framework::Variable* var = inner_scope->FindVar(
value_exe_info_->GetNameById(static_cast<int>(var_id)));
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound("Var(id=%d) should not be nullptr.",
static_cast<int>(var_id)));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>() ||
var->IsType<phi::SparseCooTensor>() ||
Expand All @@ -899,6 +904,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
<< framework::ToTypeName(var->Type());
}
}
VLOG(4) << "update last_live_ops for: " << instr->Name();
}
// clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) {
Expand All @@ -908,7 +914,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
VLOG(8) << "Skip gc for var: " << skip_gc_var;
}
}
VLOG(4) << "calculate last_live_ops_";
VLOG(4) << "clear the last_live_ops list for all vars in skip_gc_vars";

// shrink, find the downstream op that has no other op in the
// downstream list happens before it
Expand Down Expand Up @@ -949,6 +955,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
last_live_ops_[i] = minumum_last_live_ops;
var_ref_count_[i] = static_cast<int>(last_live_ops_[i].size());
}
VLOG(4) << "shrink the last_live_ops list for all vars in skip_gc_vars";

for (auto& dep : *dependecy_count_) {
deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
Expand All @@ -957,6 +964,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
var_ref_count_[i], value_exe_info_->GetVarList()[i]));
}
VLOG(4) << "done CalculateLastLiveOps";
}

void NewIRInterpreter::ConstructEventForJitInput() {
Expand Down Expand Up @@ -1410,8 +1418,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) {
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
VLOG(4) << place_ << " "
<< instr_node->DebugStringEx(scope_,
value_exe_info_->GetValue2VarName());
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
if (!instr_node->IsArtificial()) {
instr_node->Run();

Expand All @@ -1433,8 +1440,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) {
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
VLOG(4) << place_ << " "
<< instr_node->DebugStringEx(scope_,
value_exe_info_->GetValue2VarName());
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
CheckGC(instr_node);
VLOG(4) << "done CheckGC";
interpreter::LogDeviceMemoryStats(place_);
Expand Down
104 changes: 18 additions & 86 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ std::shared_ptr<ValueExecutionInfo> ValueExecutionInfo::NewChild(Scope* scope) {
std::shared_ptr<ValueExecutionInfo> info =
std::make_shared<ValueExecutionInfo>(scope);
info->parent_ = this;
info->value_2_var_name_ = this->value_2_var_name_;
info->var_2_var_name_ = this->var_2_var_name_;
info->var_name_2_id_ = this->var_name_2_id_;
info->id_2_var_name_ = this->id_2_var_name_;
info->var_list_ = this->var_list_;
return info;
}

Expand Down Expand Up @@ -157,120 +162,49 @@ void ValueExecutionInfo::ResetVarList(int id, Variable* var) {
var_list_[id] = var;
}

bool ValueExecutionInfo::HasValue(::pir::Value value) const {
return HasValueInternal(value);
}

bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const {
return HasValueLocally(value);
}

std::string ValueExecutionInfo::GetVarName(::pir::Value value) const {
return GetVarNameInternal(value);
}

std::string ValueExecutionInfo::GetVarName(const Variable* var) const {
return GetVarNameInternal(var);
}

std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const {
return GetVarNameLocally(value);
}

std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const {
return GetVarNameLocally(var);
}

int ValueExecutionInfo::GetVarId(::pir::Value value) const {
return GetVarIdInternal(value);
}

int ValueExecutionInfo::GetVarId(const Variable* var) const {
return GetVarIdInternal(var);
}

int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const {
return GetVarIdLocally(value);
}

int ValueExecutionInfo::GetLocalVarId(const Variable* var) const {
return GetVarIdLocally(var);
}

bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const {
if (HasValueLocally(value)) {
bool ValueExecutionInfo::HasVar(const std::string& var_name) const {
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return true;
}
return (parent_ == nullptr) ? false : parent_->HasValueInternal(value);
return false;
}

bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const {
bool ValueExecutionInfo::HasValue(::pir::Value value) const {
auto it = value_2_var_name_.find(value);
if (it != value_2_var_name_.end()) {
return true;
}
return false;
}

std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const {
auto name = GetVarNameLocally(value);
if (name != "") {
return name;
}
return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(value);
}

std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const {
std::string ValueExecutionInfo::GetVarName(::pir::Value value) const {
auto it = value_2_var_name_.find(value);
if (it != value_2_var_name_.end()) {
return it->second;
}
return "";
}

std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const {
auto name = GetVarNameLocally(var);
if (name != "") {
return name;
}
return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var);
}

std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const {
std::string ValueExecutionInfo::GetVarName(const Variable* var) const {
auto it = var_2_var_name_.find(var);
if (it != var_2_var_name_.end()) {
return it->second;
}
return "";
}

int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const {
auto id = GetVarIdLocally(value);
if (id != -1) {
return id;
}
return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value);
}

int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const {
auto var_name = GetVarNameLocally(value);
int ValueExecutionInfo::GetVarId(::pir::Value value) const {
auto var_name = GetVarName(value);
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return it->second;
}
return -1;
}

int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const {
auto id = GetVarIdLocally(var);
if (id != -1) {
return id;
}
return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var);
}

int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const {
auto var_name = GetVarNameLocally(var);
int ValueExecutionInfo::GetVarId(const Variable* var) const {
auto var_name = GetVarName(var);
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return it->second;
Expand Down Expand Up @@ -608,8 +542,7 @@ void HandleForInplaceOp(pir::Operation* op,
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
pir::Value inplace_value =
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name =
value_exe_info->GetValue2VarName().at(inplace_value);
std::string var_name = value_exe_info->GetVarName(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
value_exe_info->AddValue2VarName(value, var_name);
Expand All @@ -618,8 +551,7 @@ void HandleForInplaceOp(pir::Operation* op,
pir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
// const std::string& var_name = value_2_var_name->at(view_value);
const std::string& var_name =
value_exe_info->GetValue2VarName().at(view_value);
std::string var_name = value_exe_info->GetVarName(view_value);
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << var_name << ")";
value_exe_info->AddValue2VarName(value, var_name);
Expand Down
34 changes: 2 additions & 32 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,49 +79,19 @@ class ValueExecutionInfo {

void ResetVarList(int id, Variable* var);

/// Check a value exist in the ValueExecutionInfo or any of its ancestors.
bool HasValue(::pir::Value value) const;
bool HasVar(const std::string& var_name) const;

/// Check a value exist in the ValueExecutionInfo.
bool HasLocalValue(::pir::Value value) const;
bool HasValue(::pir::Value value) const;

std::string GetVarName(::pir::Value value) const;

std::string GetVarName(const Variable* var) const;

std::string GetLocalVarName(::pir::Value value) const;

std::string GetLocalVarName(const Variable* var) const;

int GetVarId(::pir::Value value) const;

int GetVarId(const Variable* var) const;

int GetLocalVarId(::pir::Value value) const;

int GetLocalVarId(const Variable* var) const;

private:
bool HasValueInternal(::pir::Value value) const;

bool HasValueLocally(::pir::Value value) const;

std::string GetVarNameInternal(::pir::Value value) const;

std::string GetVarNameLocally(::pir::Value value) const;

std::string GetVarNameInternal(const Variable* var) const;

std::string GetVarNameLocally(const Variable* var) const;

int GetVarIdInternal(::pir::Value value) const;

int GetVarIdLocally(::pir::Value value) const;

int GetVarIdInternal(const Variable* var) const;

int GetVarIdLocally(const Variable* var) const;

std::shared_ptr<ValueExecutionInfo> NewChild(Scope* scope);

ValueExecutionInfo* parent_{nullptr}; // not owned
Expand Down

0 comments on commit c2b20dc

Please sign in to comment.