diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc index 0b494c29dea86d..8198598fd5954a 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc @@ -257,8 +257,7 @@ void InstructionBase::InitInputsOutputsIds( std::string InstructionBase::DebugStringEx( const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const { + ValueExecutionInfo* value_exe_info) const { std::stringstream ss; ss << "Op(" << Name() << "), inputs:{"; @@ -268,7 +267,7 @@ std::string InstructionBase::DebugStringEx( auto& input = *it; bool is_no_need_buffer_var = (!no_need_buffer_vars.empty() && no_need_buffer_vars.count(input.first) > 0); - auto var_name = value_2_var_name.at(input.first); + auto var_name = value_exe_info->GetVarName(input.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { @@ -296,7 +295,7 @@ std::string InstructionBase::DebugStringEx( ss << "}, outputs:{"; for (auto it = Outputs().begin(); it != Outputs().end();) { auto& output = *it; - auto var_name = value_2_var_name.at(output.first); + auto var_name = value_exe_info->GetVarName(output.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index 60797426119154..5dd7ff3e4d2a5d 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -144,10 +144,8 @@ class InstructionBase { const ValueExecutionInfo& value_exec_info); // if scope is not null, also show dimensions of arguments - virtual std::string DebugStringEx( - const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const; + virtual std::string DebugStringEx(const paddle::framework::Scope* scope, + ValueExecutionInfo* value_exe_info) const; protected: size_t id_; diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 50af034414d6f2..92432a602dd287 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -390,7 +390,7 @@ Scope* NewIRInterpreter::InnerScope() const { } std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const { - return value_exe_info_->GetValue2VarName().at(value); + return value_exe_info_->GetVarName(value); } void NewIRInterpreter::UpdateSyncOpNum() { @@ -627,7 +627,7 @@ std::string NewIRInterpreter::DebugValueInfo() { PADDLE_ENFORCE((bool)kv.first, platform::errors::PreconditionNotMet( "vlaue(%s) should not be nullptr", kv.second)); - PADDLE_ENFORCE(value_exe_info_->GetVarName2Id().count(kv.second) > 0, + PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second), platform::errors::PreconditionNotMet( "var(%s) should exist in var_name_2_id_", kv.second)); auto* var = InnerScope()->FindVar(kv.second); @@ -636,8 +636,7 @@ std::string NewIRInterpreter::DebugValueInfo() { platform::errors::PreconditionNotMet( "var(%s) should exist in scope (%p)", kv.second, InnerScope())); os << kv.first.impl() << " -> " << kv.second << " -> " - << value_exe_info_->GetVarName2Id().at(kv.second) << " -> " << var - << "\n"; + << value_exe_info_->GetVarId(kv.first) << " -> " << var << "\n"; } return os.str(); } @@ -857,6 +856,7 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { } void NewIRInterpreter::CalculateLastLiveOps() { + VLOG(4) << "NewIRInterpreter(): " << this << " start CalculateLastLiveOps"; // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) { InstructionBase* instr = vec_instruction_base_[op_idx].get(); @@ -882,11 +882,16 @@ void NewIRInterpreter::CalculateLastLiveOps() { gc_check_vars.insert(var_id); } } + VLOG(4) << "get gc check vars for: " << instr->Name(); for (auto var_id : gc_check_vars) { Scope* inner_scope = InnerScope(); paddle::framework::Variable* var = inner_scope->FindVar( value_exe_info_->GetNameById(static_cast(var_id))); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound("Var(id=%d) should not be nullptr.", + static_cast(var_id))); if (var->IsType() || var->IsType() || var->IsType() || var->IsType() || @@ -899,6 +904,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { << framework::ToTypeName(var->Type()); } } + VLOG(4) << "update last_live_ops for: " << instr->Name(); } // clear the last_live_ops list for all vars in skip_gc_vars for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) { @@ -908,7 +914,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { VLOG(8) << "Skip gc for var: " << skip_gc_var; } } - VLOG(4) << "calculate last_live_ops_"; + VLOG(4) << "clear the last_live_ops list for all vars in skip_gc_vars"; // shrink, find the downstream op that has no other op in the // downstream list happens before it @@ -949,6 +955,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { last_live_ops_[i] = minumum_last_live_ops; var_ref_count_[i] = static_cast(last_live_ops_[i].size()); } + VLOG(4) << "shrink the last_live_ops list for all vars in skip_gc_vars"; for (auto& dep : *dependecy_count_) { deps_.emplace_back(std::make_shared(dep)); @@ -957,6 +964,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { refs_.emplace_back(std::make_shared( var_ref_count_[i], value_exe_info_->GetVarList()[i])); } + VLOG(4) << "done CalculateLastLiveOps"; } void NewIRInterpreter::ConstructEventForJitInput() { @@ -1410,8 +1418,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); if (!instr_node->IsArtificial()) { instr_node->Run(); @@ -1433,8 +1440,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); CheckGC(instr_node); VLOG(4) << "done CheckGC"; interpreter::LogDeviceMemoryStats(place_); diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 409532aa59560f..3ae75ffd870088 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -50,6 +50,11 @@ std::shared_ptr ValueExecutionInfo::NewChild(Scope* scope) { std::shared_ptr info = std::make_shared(scope); info->parent_ = this; + info->value_2_var_name_ = this->value_2_var_name_; + info->var_2_var_name_ = this->var_2_var_name_; + info->var_name_2_id_ = this->var_name_2_id_; + info->id_2_var_name_ = this->id_2_var_name_; + info->var_list_ = this->var_list_; return info; } @@ -157,54 +162,15 @@ void ValueExecutionInfo::ResetVarList(int id, Variable* var) { var_list_[id] = var; } -bool ValueExecutionInfo::HasValue(::pir::Value value) const { - return HasValueInternal(value); -} - -bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const { - return HasValueLocally(value); -} - -std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { - return GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarName(const Variable* var) const { - return GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const { - return GetVarNameLocally(value); -} - -std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const { - return GetVarNameLocally(var); -} - -int ValueExecutionInfo::GetVarId(::pir::Value value) const { - return GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarId(const Variable* var) const { - return GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const { - return GetVarIdLocally(value); -} - -int ValueExecutionInfo::GetLocalVarId(const Variable* var) const { - return GetVarIdLocally(var); -} - -bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const { - if (HasValueLocally(value)) { +bool ValueExecutionInfo::HasVar(const std::string& var_name) const { + auto it = var_name_2_id_.find(var_name); + if (it != var_name_2_id_.end()) { return true; } - return (parent_ == nullptr) ? false : parent_->HasValueInternal(value); + return false; } -bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { +bool ValueExecutionInfo::HasValue(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return true; @@ -212,15 +178,7 @@ bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { return false; } -std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const { - auto name = GetVarNameLocally(value); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { +std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return it->second; @@ -228,15 +186,7 @@ std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { return ""; } -std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const { - auto name = GetVarNameLocally(var); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { +std::string ValueExecutionInfo::GetVarName(const Variable* var) const { auto it = var_2_var_name_.find(var); if (it != var_2_var_name_.end()) { return it->second; @@ -244,16 +194,8 @@ std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { return ""; } -int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const { - auto id = GetVarIdLocally(value); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { - auto var_name = GetVarNameLocally(value); +int ValueExecutionInfo::GetVarId(::pir::Value value) const { + auto var_name = GetVarName(value); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -261,16 +203,8 @@ int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { return -1; } -int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const { - auto id = GetVarIdLocally(var); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const { - auto var_name = GetVarNameLocally(var); +int ValueExecutionInfo::GetVarId(const Variable* var) const { + auto var_name = GetVarName(var); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -608,8 +542,7 @@ void HandleForInplaceOp(pir::Operation* op, const std::string& inplace_name = yaml_parser.InplaceName(value_name); pir::Value inplace_value = op->operand_source(yaml_parser.InputName2Id().at(inplace_name)); - std::string var_name = - value_exe_info->GetValue2VarName().at(inplace_value); + std::string var_name = value_exe_info->GetVarName(inplace_value); VLOG(4) << "inplace: " << value_name << " -> " << inplace_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); @@ -618,8 +551,7 @@ void HandleForInplaceOp(pir::Operation* op, pir::Value view_value = op->operand_source(yaml_parser.InputName2Id().at(view_name)); // const std::string& var_name = value_2_var_name->at(view_value); - const std::string& var_name = - value_exe_info->GetValue2VarName().at(view_value); + std::string var_name = value_exe_info->GetVarName(view_value); VLOG(4) << "view: " << value_name << " -> " << view_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index e0337313da2600..ce0484567b64f0 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -79,49 +79,19 @@ class ValueExecutionInfo { void ResetVarList(int id, Variable* var); - /// Check a value exist in the ValueExecutionInfo or any of its ancestors. - bool HasValue(::pir::Value value) const; + bool HasVar(const std::string& var_name) const; - /// Check a value exist in the ValueExecutionInfo. - bool HasLocalValue(::pir::Value value) const; + bool HasValue(::pir::Value value) const; std::string GetVarName(::pir::Value value) const; std::string GetVarName(const Variable* var) const; - std::string GetLocalVarName(::pir::Value value) const; - - std::string GetLocalVarName(const Variable* var) const; - int GetVarId(::pir::Value value) const; int GetVarId(const Variable* var) const; - int GetLocalVarId(::pir::Value value) const; - - int GetLocalVarId(const Variable* var) const; - private: - bool HasValueInternal(::pir::Value value) const; - - bool HasValueLocally(::pir::Value value) const; - - std::string GetVarNameInternal(::pir::Value value) const; - - std::string GetVarNameLocally(::pir::Value value) const; - - std::string GetVarNameInternal(const Variable* var) const; - - std::string GetVarNameLocally(const Variable* var) const; - - int GetVarIdInternal(::pir::Value value) const; - - int GetVarIdLocally(::pir::Value value) const; - - int GetVarIdInternal(const Variable* var) const; - - int GetVarIdLocally(const Variable* var) const; - std::shared_ptr NewChild(Scope* scope); ValueExecutionInfo* parent_{nullptr}; // not owned