diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index c1285f5d3eb93..27f83a266ec9c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -390,7 +390,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_library(generator SRCS generator.cc DEPS enforce place) -cc_library(tcmpt_utils SRCS tcmpt_utils.cc DEPS lod_tensor selected_rows place tcmpt) +cc_library(tcmpt_utils SRCS tcmpt_utils.cc DEPS lod_tensor selected_rows place tcmpt var_type_traits) # Get the current working branch execute_process( @@ -454,3 +454,4 @@ if(WITH_TESTING AND TEST selected_rows_test) endif() cc_test(scope_guard_test SRCS scope_guard_test.cc) +cc_test(tcmpt_utils_test SRCS tcmpt_utils_test.cc DEPS tcmpt_utils) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2ea761944671b..7cadf53cc5299 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -28,7 +28,6 @@ limitations under the License. */ #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/imperative/kernel_args_names_maker.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" @@ -51,7 +50,7 @@ DECLARE_bool(check_nan_inf); DECLARE_bool(enable_unused_var_check); PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0, "number of threads for inner op"); -DECLARE_bool(use_pt_kernel); +DECLARE_bool(run_pt_kernel); namespace paddle { namespace framework { @@ -1077,22 +1076,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, this->InferShape(&infer_shape_ctx); } -OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) { - proto::VarType::Type data_type = pt::TransToProtoVarType(kernel_key.dtype()); - platform::Place place = pt::TransToFluidPlace(kernel_key.backend()); - DataLayout data_layout = pt::TransToFluidDataLayout(kernel_key.layout()); - LibraryType library_type = LibraryType::kPlain; - if (kernel_key.backend() == pt::Backend::kMKLDNN) { - library_type = LibraryType::kMKLDNN; - } else if (kernel_key.backend() == pt::Backend::kCUDNN) { - library_type = LibraryType::kCUDNN; - } else { - // do nothing - } - // TODO(chenweihang): the customized_type_value is lost - return OpKernelType(data_type, place, data_layout, library_type); -} - static std::string RuntimeContextDebugString(const RuntimeContext& ctx) { std::stringstream ss; ss << "RuntimeContext(Inputs: "; @@ -1149,22 +1132,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } #endif + auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx); + // TODO(chenweihang): Now we are still reusing a lot of the original fluid // implementation, this is a gradual replacement process // TODO(chenweihang): in the first phase of project, we only support CPU, CUDA // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second // phase - - if (FLAGS_use_pt_kernel && + if (FLAGS_run_pt_kernel && pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) { - if (pt_kernel_key_.get() == nullptr || pt_kernel_.get() == nullptr) { - ChoosePtKernel(*runtime_ctx, *dev_ctx); + if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) { + ChoosePtKernel(exe_ctx); } run_pt_kernel_ = pt_kernel_->IsValid(); } if 
(!run_pt_kernel_) { if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { - ChooseKernel(*runtime_ctx, scope, place); + ChooseKernel(exe_ctx); } } @@ -1175,10 +1159,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::RecordEvent record_event("prepare_data", platform::EventRole::kInnerOp); if (need_prepare_data_) { - if (run_pt_kernel_) { - kernel_type_.reset( - new OpKernelType(TransPtKernelKeyToOpKernelType(*pt_kernel_key_))); - } transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx); } @@ -1208,8 +1188,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::RecordEvent record_event("compute", platform::EventRole::kInnerOp); if (run_pt_kernel_) { - // TODO(chenweihang): here will intrduce copy - auto op_kernel_ctx = ConstructPtKernelContext(*runtime_ctx, *dev_ctx); + auto op_kernel_ctx = BuildPtKernelContext(*runtime_ctx, *dev_ctx); (*pt_kernel_)(&op_kernel_ctx); } else { (*kernel_func_)( @@ -1262,104 +1241,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } -// TODO(chenweihang): now only check single var input -static bool IsValidVar(const std::string& name, - const VariableValueMap& inputs) { - auto it = inputs.find(name); - if (it == inputs.end()) { - return false; - } - auto* var = it->second.empty() ? nullptr : it->second[0]; - return var != nullptr; -} - -// TODO(chenweihang): enhance rules, not all dispensable inputs -// are host tensor, now only for scale kernel verify -static bool ContainHostTensor(const proto::OpProto& op_proto, - const VariableValueMap& inputs) { - for (int i = 0; i < op_proto.inputs_size(); ++i) { - auto in = op_proto.inputs()[i]; - if (in.has_dispensable() && in.dispensable()) { - return IsValidVar(in.name(), inputs); - } - } - return false; -} - -// TODO(yuanrisheng): enhance rules, for get kernel that contains Intermediate -// Tensor -static bool ContainMidOutputTensor(const proto::OpProto& op_proto, - const VariableValueMap& outputs) { - for (int i = 0; i < op_proto.outputs_size(); ++i) { - auto output = op_proto.outputs()[i]; - if (output.has_intermediate() && output.intermediate()) { - return IsValidVar(output.name(), outputs); - } - } - return false; -} - -static pt::KernelName ConstructPtKernelName(const std::string& op_type, - const proto::OpProto& op_proto, - const VariableValueMap& inputs, - const VariableValueMap& outputs) { - std::string overload_name; - // TODO(chenweihang): adapt SelectedRows by xiaowei's design - if (ContainHostTensor(op_proto, inputs)) { - if (overload_name != "") { - overload_name += "."; - } - overload_name += pt::kContainHostTensorSuffix; - } - if (ContainMidOutputTensor(op_proto, outputs)) { - if (overload_name != "") { - overload_name += "."; - } - overload_name += pt::kContainMidOutputTensorSuffix; - } - return pt::KernelName(op_type, overload_name); -} - -void OperatorWithKernel::ChoosePtKernel( - const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const { - // 1. construct operation name - // TODO(chenweihang): add rules for construct op name - auto kernel_name = - ConstructPtKernelName(Type(), *(Info().proto_), ctx.inputs, ctx.outputs); - - // 2. construct op kernel key - pt_kernel_key_.reset(new pt::KernelKey( - ConstructPtKernelKey(ctx.inputs, Attrs(), dev_ctx.GetPlace()))); - - // 3. 
selecte op kernel - pt_kernel_.reset(new pt::Kernel(pt::KernelFactory::Instance().SelectKernel( - kernel_name, *pt_kernel_key_))); - - // for debug - VLOG(1) << "ChoosePtKernel - kernel name: " << kernel_name - << " | kernel key: " << *pt_kernel_key_ - << " | kernel: " << *pt_kernel_; -} - -void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, - const Scope& scope, - const platform::Place& place) const { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(place); - - // check if op[type] has kernel registered. - auto& all_op_kernels = AllOpKernels(); - auto kernels_iter = all_op_kernels.find(type_); - PADDLE_ENFORCE_NE( - kernels_iter, all_op_kernels.end(), - platform::errors::Unavailable( - "There are no kernels which are registered in the %s operator.", - type_)); - - OpKernelMap& kernels = kernels_iter->second; +OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( + const ExecutionContext& ctx) const { + auto& dev_ctx = ctx.device_context(); - auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, ctx)); + auto expected_kernel_key = this->GetExpectedKernelType(ctx); if (HasAttr("op_device")) { if (Attr("op_device") == "cpu") { expected_kernel_key.place_ = platform::CPUPlace(); @@ -1376,9 +1262,9 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel // will be executed and a warning will be given at the same time. if (SupportGPU()) { - expected_kernel_key.place_ = dev_ctx->GetPlace(); + expected_kernel_key.place_ = dev_ctx.GetPlace(); } else if (SupportNPU()) { - expected_kernel_key.place_ = dev_ctx->GetPlace(); + expected_kernel_key.place_ = dev_ctx.GetPlace(); } else { expected_kernel_key.place_ = platform::CPUPlace(); LOG_FIRST_N(WARNING, 1) @@ -1389,6 +1275,45 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, } VLOG(3) << "op type:" << type_ << ", expected_kernel_key:" << expected_kernel_key; + return expected_kernel_key; +} + +void OperatorWithKernel::ChoosePtKernel(const ExecutionContext& ctx) const { + pt_kernel_signature_.reset( + new KernelSignature(this->GetExpectedPtKernelArgs(ctx))); + + VLOG(1) << KernelSignatureToString(*pt_kernel_signature_.get()); + + kernel_type_.reset(new OpKernelType(InnerGetExpectedKernelType(ctx))); + + auto pt_kernel_name = pt::KernelName(pt_kernel_signature_->first); + auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(*kernel_type_.get()); + pt_kernel_.reset(new pt::Kernel(pt::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key))); + + if (pt_kernel_->IsValid()) { + VLOG(1) << "Static mode ChoosePtKernel - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_kernel_key + << " | kernel: " << *pt_kernel_; + } else { + VLOG(1) << "Static mode ChoosePtKernel - kernel `" << pt_kernel_name + << "` not found."; + } +} + +void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { + // check if op[type] has kernel registered. 
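For orientation, the KernelSignature consumed by ChoosePtKernel above is a pair of the kernel name and a tuple of input/attribute/output argument names. A minimal sketch (illustration only, not part of the patch; it mirrors the scale op override further below, and `sig` is a placeholder name):

```cpp
// Build the signature that GetExpectedPtKernelArgs would return for scale.
framework::KernelSignature sig = std::make_pair(
    "scale",
    std::make_tuple(
        paddle::SmallVector<std::string>({"X"}),
        paddle::SmallVector<std::string>({"scale", "bias", "bias_after_scale"}),
        paddle::SmallVector<std::string>({"Out"})));
// KernelSignatureToString(sig) (see tcmpt_utils.cc below) then yields:
//   Kernel Signature - name: scale; inputs: X; attributes: scale, bias,
//   bias_after_scale; outputs: Out
```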
+ auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + PADDLE_ENFORCE_NE( + kernels_iter, all_op_kernels.end(), + platform::errors::Unavailable( + "There are no kernels which are registered in the %s operator.", + type_)); + + OpKernelMap& kernels = kernels_iter->second; + + auto expected_kernel_key = InnerGetExpectedKernelType(ctx); auto kernel_iter = kernels.find(expected_kernel_key); #ifdef PADDLE_WITH_MKLDNN @@ -1844,60 +1769,23 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( tensor.layout()); } -pt::KernelKey OperatorWithKernel::ConstructPtKernelKey( - const VariableValueMap& inputs, const AttributeMap& attrs, - const platform::Place& ctx_place) const { - // 1. get backend based place and attrs - auto attr_reader = AttrReader(attrs); - pt::Backend backend = pt::TransToPtBackend(ctx_place); - if (attrs.count("use_mkldnn") != 0 && - attr_reader.Get("use_mkldnn") == true) { - backend = pt::Backend::kMKLDNN; - } else if (attrs.count("use_cudnn") != 0 && - attr_reader.Get("use_cudnn") == true) { - backend = pt::Backend::kCUDNN; +KernelSignature OperatorWithKernel::GetExpectedPtKernelArgs( + const ExecutionContext& ctx) const { + if (KernelSignatureMap::Instance().Has(Type())) { + return *(KernelSignatureMap::Instance().GetNullable(Type())); } else { - // do nothing + KernelArgsNameMakerByOpProto maker(Info().proto_); + auto signature = maker.GetKernelSignature(); + KernelSignatureMap::Instance().Insert(Type(), signature); + return signature; } - // TODO(chenweihang): add more rules - // if (HasAttr("op_device")) - - // 2. get layout - // default layout same as tensor default layout, need futher check - pt::DataLayout layout = pt::DataLayout::kNCHW; - if (backend == pt::Backend::kMKLDNN) { - layout = pt::DataLayout::kMKLDNN; - } - - // 3. parse data_type form inputs - proto::VarType::Type dafault_data_type = - static_cast(-1); - proto::VarType::Type data_type = dafault_data_type; - for (auto& var_pair : inputs) { - ParseInputDataType(var_pair.second, var_pair.first, &data_type); - } - PADDLE_ENFORCE_NE( - data_type, dafault_data_type, - platform::errors::NotFound( - "DataType should be indicated by input Variable at %s.", Type())); - pt::DataType dtype = pt::TransToPtDataType(data_type); - - // TODO(chenweihang): polish special dtype rules - if (attrs.count("dtype") != 0 && - attr_reader.Get("dtype") != static_cast(data_type)) { - dtype = pt::TransToPtDataType(static_cast( - attr_reader.Get("dtype"))); - } - - // 4. build pt KernelKey - return pt::KernelKey(backend, layout, dtype); } -pt::KernelContext OperatorWithKernel::ConstructPtKernelContext( +pt::KernelContext OperatorWithKernel::BuildPtKernelContext( const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const { VLOG(1) << RuntimeContextDebugString(ctx); - // TODO(chenweihang): now only work for very simple case (sign op), + // TODO(chenweihang): now only work for very simple case, // many cases need to be deal with later: // 1. the input and output are not tensor // 2. the dispensbale, duplicable input and output @@ -1905,42 +1793,36 @@ pt::KernelContext OperatorWithKernel::ConstructPtKernelContext( // 4. use pt Tensor directly // 5. 
kernel input is not DenseTensor pt::KernelContext op_kernel_ctx(dev_ctx); - auto input_defs = pt_kernel_->args_def().input_defs(); - auto output_defs = pt_kernel_->args_def().output_defs(); - auto attr_defs = pt_kernel_->args_def().attribute_defs(); - - // TODO(chenweihang): use ordered_map for VariableNameMap and VariableValueMap - // If we the VariableValueMap are ordered, we can get tensor by iter the map, - // and its order is same as OpProto - paddle::imperative::KernelArgsNameMakerByOpProto argMaker( - Info().proto_, &ctx.inputs, &ctx.outputs); + auto& input_names = std::get<0>(pt_kernel_signature_->second); + auto& attr_names = std::get<1>(pt_kernel_signature_->second); + auto& output_names = std::get<2>(pt_kernel_signature_->second); - auto& input_names = argMaker.GetInputArgsNames(); - auto& output_names = argMaker.GetOutputArgsNames(); - auto& attr_pairs = argMaker.GetAttrsArgsNamesAndTypes(); + auto input_defs = pt_kernel_->args_def().input_defs(); + auto attr_defs = pt_kernel_->args_def().attribute_defs(); + auto output_defs = pt_kernel_->args_def().output_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), platform::errors::InvalidArgument( - "the size of inputs_args names (%d) must be equal to " + "The size of inputs_args names (%d) must be equal to " "the size of kernel input_defs (%d).", input_names.size(), input_defs.size())); PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(), platform::errors::InvalidArgument( - "the size of outputs_args names (%d) must be equal to " + "The size of outputs_args names (%d) must be equal to " "the size of kernel output_defs (%d).", output_names.size(), output_defs.size())); - PADDLE_ENFORCE_EQ(attr_pairs.size(), attr_defs.size(), + PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(), platform::errors::InvalidArgument( - "the size of attribute_args names (%d) must be equal " + "The size of attribute_args names (%d) must be equal " "to the size of kernel attribute_defs (%d).", - attr_pairs.size(), attr_defs.size())); + attr_names.size(), attr_defs.size())); for (size_t i = 0; i < input_names.size(); ++i) { auto in_def = input_defs.at(i); - VLOG(1) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", " + VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", " << in_def.layout; auto ins_vector = ctx.inputs.at(input_names[i]); @@ -1965,49 +1847,33 @@ pt::KernelContext OperatorWithKernel::ConstructPtKernelContext( op_kernel_ctx.EmplaceBackOutputs(tmp_outputs); } - for (size_t i = 0; i < attr_defs.size(); ++i) { + for (size_t i = 0; i < attr_names.size(); ++i) { + auto& attr = Attrs().at(attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(pt::Scalar))) { - // TODO(chenweihang): support other attrs - // In principle, the attr required by the dynamic mode should be - // passed in from the Python side, and there is no need to look up - // from the default_map, but now this nor work - switch (attr_pairs[i].second) { - case framework::proto::AttrType::INT: - op_kernel_ctx.EmplaceBackAttr( - pt::Scalar(Attr(attr_pairs[i].first))); - break; - case framework::proto::AttrType::FLOAT: - op_kernel_ctx.EmplaceBackAttr( - pt::Scalar(Attr(attr_pairs[i].first))); - break; - case framework::proto::AttrType::BOOLEAN: - op_kernel_ctx.EmplaceBackAttr( - pt::Scalar(Attr(attr_pairs[i].first))); - break; - default: - // TODO(chenweihang): support other attrs type - PADDLE_THROW(platform::errors::Unimplemented( - "unsupported cast op attribute `%s` when construct " - "KernelContext.", - 
attr_pairs[i].first));
+    // TODO(chenweihang): support other attrs later
+    // TODO(zhangyunfei): Scalar should hold scalar type, and we should check
+    // attribute type by attr_defs
+    if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
+      op_kernel_ctx.EmplaceBackAttr(pt::Scalar(BOOST_GET_CONST(float, attr)));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "unsupported cast op attribute `%s` to Scalar when construct "
+          "KernelContext.",
+          attr_names[i]));
+    }
    } else {
-      // TODO(chenweihang): support other attrs
-      // In principle, the attr required by the dynamic mode should be
-      // passed in from the Python side, and there is no need to look up
-      // from the default_map, but now this nor work
+      // TODO(chenweihang): support other attrs later
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        op_kernel_ctx.EmplaceBackAttr(Attr<int>(attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(Attr<float>(attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        op_kernel_ctx.EmplaceBackAttr(Attr<bool>(attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
      } else {
-        // TODO(chenweihang): support other attrs type
        PADDLE_THROW(platform::errors::Unimplemented(
            "unsupported cast op attribute `%s` when construct "
            "KernelContext.",
-            attr_pairs[i].first));
+            attr_names[i]));
      }
    }
  }
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index b844c2cf61407..7581b65e3b68b 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -116,8 +116,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

-OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key);
-
 class ExecutionContext;
 class OperatorBase;

@@ -534,13 +532,15 @@ class OperatorWithKernel : public OperatorBase {
   }

   /* member functions for adapting to tcmpt lib */
-  // TODO(chenweihang): Temporarily as a class method
-  virtual pt::KernelKey ConstructPtKernelKey(
-      const VariableValueMap& inputs, const AttributeMap& attrs,
-      const platform::Place& ctx_place) const;
-
-  virtual pt::KernelContext ConstructPtKernelContext(
-      const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
+  /** In the Tensor calculation library, the new Kernel adopts a clearer and
+   * more streamlined design. The arguments of the Kernel and the input and
+   * output arguments registered in the original OpMaker do not always match,
+   * so we use a map to record the arguments required by the kernel.
+   * When selecting a Kernel during Op execution, the arguments of the
+   * original Op are picked out according to the signature returned by
+   * GetExpectedPtKernelArgs.
+ */ + virtual KernelSignature GetExpectedPtKernelArgs( + const ExecutionContext& ctx) const; private: void RunImpl(const Scope& scope, const platform::Place& place) const final; @@ -563,8 +563,9 @@ class OperatorWithKernel : public OperatorBase { const std::vector& inplace_vars, const Scope& exec_scope) const; - void ChooseKernel(const RuntimeContext& ctx, const Scope& scope, - const platform::Place& place) const; + OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const; + + void ChooseKernel(const ExecutionContext& ctx) const; void HandleComplexGradToRealGrad(const Scope& scope, RuntimeContext* ctx) const; @@ -582,8 +583,10 @@ class OperatorWithKernel : public OperatorBase { const std::string& name) const; /* member functions for adapting to tcmpt lib */ - void ChoosePtKernel(const RuntimeContext& ctx, - const platform::DeviceContext& dev_ctx) const; + void ChoosePtKernel(const ExecutionContext& ctx) const; + + pt::KernelContext BuildPtKernelContext( + const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const; protected: mutable std::unique_ptr kernel_type_; @@ -595,10 +598,11 @@ class OperatorWithKernel : public OperatorBase { mutable bool all_kernels_must_compute_runtime_shape_ = false; mutable std::mutex cache_update_mutex_; mutable bool enable_cache_transfer_scope_ = false; - // TODO(chenweihang): Similar duplicate members are used for new tcmpt lib, - // maybe we have better impl methods + // NOTE(chenweihang): Similar op members are used to adapt to + // new tcmpt kernel, if there is a better design in the future, + // we may polish the implementation here mutable bool run_pt_kernel_ = false; - mutable std::unique_ptr pt_kernel_key_; + mutable std::unique_ptr pt_kernel_signature_; mutable std::unique_ptr pt_kernel_; }; diff --git a/paddle/fluid/framework/tcmpt_utils.cc b/paddle/fluid/framework/tcmpt_utils.cc index 71ef2d3450ae9..a39e653d0349e 100644 --- a/paddle/fluid/framework/tcmpt_utils.cc +++ b/paddle/fluid/framework/tcmpt_utils.cc @@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

+#include <sstream>
+
 #include "paddle/fluid/framework/tcmpt_utils.h"

 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/string/string_helper.h"

 namespace paddle {
 namespace framework {

@@ -62,7 +65,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }

 template <>
@@ -71,21 +74,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
-}
-
-template <>
-void ShareTensorImpl<pt::DenseTensor>(pt::DenseTensor* tensor_impl,
-                                      LoDTensor* out) {
-  out->ResetHolderWithType(tensor_impl->allocation(),
-                           pt::TransToProtoVarType(tensor_impl->type()));
-}
-
-template <>
-void ShareTensorImpl<pt::DenseTensor>(pt::DenseTensor* tensor_impl,
-                                      Tensor* out) {
-  out->ResetHolderWithType(tensor_impl->allocation(),
-                           pt::TransToProtoVarType(tensor_impl->type()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }

 std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
@@ -164,5 +153,115 @@ std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
   return nullptr;
 }

+OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) {
+  proto::VarType::Type data_type = pt::TransToProtoVarType(kernel_key.dtype());
+  platform::Place place = pt::TransToFluidPlace(kernel_key.backend());
+  DataLayout data_layout = pt::TransToFluidDataLayout(kernel_key.layout());
+  LibraryType library_type = LibraryType::kPlain;
+  if (kernel_key.backend() == pt::Backend::kMKLDNN) {
+    library_type = LibraryType::kMKLDNN;
+  } else if (kernel_key.backend() == pt::Backend::kCUDNN) {
+    library_type = LibraryType::kCUDNN;
+  } else {
+    // do nothing
+  }
+  // TODO(chenweihang): the customized_type_value is lost
+  return OpKernelType(data_type, place, data_layout, library_type);
+}
+
+pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type) {
+  pt::Backend backend = pt::TransToPtBackend(kernel_type.place_);
+  if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
+    backend = pt::Backend::kMKLDNN;
+  } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
+    backend = pt::Backend::kCUDNN;
+  } else {
+    // do nothing
+  }
+  pt::DataLayout layout = pt::TransToPtDataLayout(kernel_type.data_layout_);
+  pt::DataType dtype = pt::TransToPtDataType(kernel_type.data_type_);
+  return pt::KernelKey(backend, layout, dtype);
+}
+
+KernelSignatureMap& KernelSignatureMap::Instance() {
+  static KernelSignatureMap g_kernel_signature_map;
+  return g_kernel_signature_map;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetInputArgsNames() {
+  for (int i = 0; i < op_proto_->inputs_size(); ++i) {
+    auto& in = op_proto_->inputs()[i];
+    auto& in_name = in.name();
+    if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
+      VLOG(1) << "Parse PtKernel input: skip extra & quant input - " << in_name;
+      continue;
+    }
+    // If an op contains a dispensable input, it should override the
+    // GetExpectedPtKernelArgs method itself
+    if (in.has_dispensable() && in.dispensable()) {
+      VLOG(1) << "Parse PtKernel input: skip dispensable input - " << in_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel input: " << in_name;
+    input_names_.emplace_back(in_name);
+  }
+  return input_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
for (int i = 0; i < op_proto_->outputs_size(); ++i) {
+    auto& out = op_proto_->outputs()[i];
+    auto& out_name = out.name();
+    // TODO(chenweihang): outputs also need to skip some cases
+    VLOG(1) << "Parse PtKernel output: " << out_name;
+    output_names_.emplace_back(out_name);
+  }
+  return output_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
+  for (int i = 0; i < op_proto_->attrs_size(); ++i) {
+    auto& attr = op_proto_->attrs()[i];
+    auto& attr_name = attr.name();
+    if (attr_name == "use_mkldnn" || attr_name == "op_role" ||
+        attr_name == "op_role_var" || attr_name == "op_namescope" ||
+        attr_name == "op_callstack" || attr_name == "op_device") {
+      VLOG(1) << "Parse PtKernel attribute: skip needless attr - " << attr_name;
+      continue;
+    }
+    if ((attr.has_extra() && attr.extra()) ||
+        (attr.has_quant() && attr.quant())) {
+      VLOG(1) << "Parse PtKernel attribute: skip extra & quant attr - "
+              << attr_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel attribute: " << attr_name;
+    attr_names_.emplace_back(attr_name);
+  }
+
+  return attr_names_;
+}
+
+KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
+  return std::make_pair(
+      op_proto_->type(),
+      std::make_tuple(GetInputArgsNames(), GetAttrsArgsNames(),
+                      GetOutputArgsNames()));
+}
+
+std::string KernelSignatureToString(const KernelSignature& signature) {
+  std::stringstream os;
+  os << "Kernel Signature - name: " << signature.first << "; inputs: "
+     << string::join_strings(std::get<0>(signature.second), ", ")
+     << "; attributes: "
+     << string::join_strings(std::get<1>(signature.second), ", ")
+     << "; outputs: "
+     << string::join_strings(std::get<2>(signature.second), ", ");
+  return os.str();
+}
+
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/tcmpt_utils.h b/paddle/fluid/framework/tcmpt_utils.h
index 0af8cd30bd34d..5ec5476f2b8e5 100644
--- a/paddle/fluid/framework/tcmpt_utils.h
+++ b/paddle/fluid/framework/tcmpt_utils.h
@@ -14,14 +14,25 @@ limitations under the License.
*/
 #pragma once

+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/op_kernel_type.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
-
 #include "paddle/tcmpt/api/include/core.h"
+#include "paddle/utils/flat_hash_map.h"
+#include "paddle/utils/small_vector.h"

 namespace paddle {
 namespace framework {

+/* tensor translate */
+
 template <typename PtTensorImplT, typename VariableT>
 std::shared_ptr<PtTensorImplT> MakeTensorImpl(const VariableT& tensor,
                                               pt::Backend backend,
@@ -38,16 +49,81 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl(const Tensor& tensor,
                                                 const platform::Place& place,
                                                 proto::VarType::Type type);

-template <typename PtTensorImplT>
-void ShareTensorImpl(PtTensorImplT* tensor_impl, LoDTensor* out);
-
-template <typename PtTensorImplT>
-void ShareTensorImpl(PtTensorImplT* tensor_impl, Tensor* out);
-
 std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
     const framework::Variable& variable, const pt::TensorArgDef& arg_def);
 std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
     framework::Variable* variable, const pt::TensorArgDef& arg_def);

+/* Kernel Key translate */
+
+OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key);
+pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type);
+
+/* Kernel Args parse */
+
+// TODO(chenweihang): we can generate this map by proto info in compile time
+class KernelSignatureMap {
+ public:
+  static KernelSignatureMap& Instance();
+
+  bool Has(const std::string& op_type) const {
+    return map_.find(op_type) != map_.end();
+  }
+
+  void Insert(const std::string& op_type, const KernelSignature& signature) {
+    PADDLE_ENFORCE_NE(
+        Has(op_type), true,
+        platform::errors::AlreadyExists(
+            "Operator (%s)'s Kernel Signature has been registered.", op_type));
+    map_.insert({op_type, signature});
+  }
+
+  const KernelSignature* GetNullable(const std::string& op_type) const {
+    auto it = map_.find(op_type);
+    if (it == map_.end()) {
+      return nullptr;
+    } else {
+      return &it->second;
+    }
+  }
+
+ private:
+  KernelSignatureMap() = default;
+  paddle::flat_hash_map<std::string, KernelSignature> map_;
+
+  DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
+};
+
+class KernelArgsNameMaker {
+ public:
+  virtual ~KernelArgsNameMaker() {}
+  virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
+  virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
+  virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
+};
+
+class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
+ public:
+  explicit KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto)
+      : op_proto_(op_proto) {}
+
+  ~KernelArgsNameMakerByOpProto() {}
+
+  const paddle::SmallVector<std::string>& GetInputArgsNames() override;
+  const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
+  const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
+
+  KernelSignature GetKernelSignature();
+
+ private:
+  framework::proto::OpProto* op_proto_;
+
+  paddle::SmallVector<std::string> input_names_;
+  paddle::SmallVector<std::string> output_names_;
+  paddle::SmallVector<std::string> attr_names_;
+};
+
+std::string KernelSignatureToString(const KernelSignature& signature);
+
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/tcmpt_utils_test.cc b/paddle/fluid/framework/tcmpt_utils_test.cc
new file mode 100644
index 0000000000000..f1966789c1dde
--- /dev/null
+++ b/paddle/fluid/framework/tcmpt_utils_test.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/tcmpt_utils.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/variable.h"
+
+namespace paddle {
+namespace framework {
+
+TEST(TcmptUtils, MakeTensor) {
+  // 1. create tensor
+  LoDTensor x;
+  Tensor x2;
+  x.Resize({2});
+  x.mutable_data<float>(platform::CPUPlace());
+  x.data<float>()[0] = 0.2;
+  x.data<float>()[1] = 0.5;
+
+  // 2. test API
+  auto dense_x = MakeTensorImpl<pt::DenseTensor>(x, x.place(), x.type());
+
+  // 3. check result
+  std::vector<float> expect_value = {0.2, 0.5};
+  ASSERT_EQ(dense_x->data<float>()[0], expect_value[0]);
+  ASSERT_EQ(dense_x->data<float>()[1], expect_value[1]);
+  ASSERT_EQ(dense_x->backend(), pt::Backend::kCPU);
+  ASSERT_EQ(dense_x->type(), pt::DataType::kFLOAT32);
+}
+
+TEST(TcmptUtils, VarToPtTensor) {
+  // 1. create Variable
+  Variable v;
+  auto selected_rows = v.GetMutable<SelectedRows>();
+  Tensor* value = selected_rows->mutable_value();
+  auto* data =
+      value->mutable_data<int>(make_ddim({1, 1}), paddle::platform::CPUPlace());
+  data[0] = 123;
+  pt::Backend expect_backend = pt::Backend::kCPU;
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  expect_backend = pt::Backend::kCUDA;
+#endif
+  auto tensor_def = pt::TensorArgDef(expect_backend, pt::DataLayout::kNCHW,
+                                     pt::DataType::kINT32);
+  // 2. test API
+  auto tensor_x = InputVariableToPtTensor(v, tensor_def);
+  // 3. check result
+  ASSERT_EQ(tensor_x->backend(), expect_backend);
+  ASSERT_EQ(tensor_x->type(), pt::DataType::kINT32);
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h
index 1c5469d02c3ef..d0d1b915f2317 100644
--- a/paddle/fluid/framework/type_defs.h
+++ b/paddle/fluid/framework/type_defs.h
@@ -17,11 +17,13 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/variant.h"
+#include "paddle/utils/small_vector.h"

 namespace paddle {
 namespace framework {

@@ -82,5 +84,13 @@ using InferShapeFN = std::function<void(InferShapeContext*)>;
 using InplacePair = std::unordered_map<std::string, std::string>;
 using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;

+// tuple(input_names, attr_names, output_names)
+using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
+                                   paddle::SmallVector<std::string>,
+                                   paddle::SmallVector<std::string>>;
+// TODO(yuanrisheng): implement implicit overload signature, use
+// KernelArgsTuple directly
+using KernelSignature = std::pair<std::string, KernelArgsTuple>;
+
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/imperative/kernel_args_names_maker.h b/paddle/fluid/imperative/kernel_args_names_maker.h
deleted file mode 100644
index 5863f3cae95c2..0000000000000
--- a/paddle/fluid/imperative/kernel_args_names_maker.h
+++ /dev/null
@@ -1,165 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "glog/logging.h" - -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/imperative/type_defs.h" -#include "paddle/utils/small_vector.h" - -namespace paddle { -namespace imperative { -// TODO(chenweihang): now only check single var input -template -static bool IsValidVar(const std::string& name, - const NameVarMap& inputs) { - auto it = inputs.find(name); - if (it == inputs.end()) { - return false; - } - if (it->second.empty()) { - return false; - } - return it->second[0] != nullptr; -} - -class KernelArgsNameMaker { - public: - virtual ~KernelArgsNameMaker() {} - virtual const paddle::SmallVector& GetInputArgsNames() = 0; - virtual const paddle::SmallVector& GetOutputArgsNames() = 0; - virtual const paddle::SmallVector< - std::pair>& - GetAttrsArgsNamesAndTypes() = 0; -}; - -template -class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { - public: - KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto, - const imperative::NameVarMap* inputs, - const imperative::NameVarMap* outputs) - : op_proto_(op_proto), inputs_(inputs), outputs_(outputs) {} - - ~KernelArgsNameMakerByOpProto() {} - - const paddle::SmallVector& GetInputArgsNames() override { - for (int i = 0; i < op_proto_->inputs_size(); ++i) { - auto in = op_proto_->inputs()[i]; - - // TODO(chenweihang): deal with diff param in vector - if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) { - VLOG(1) << "Dygraph PtKernel input: skip extra & quant input - " - << in.name(); - continue; - } - - std::string in_name = in.name(); - if (in.has_dispensable() && in.dispensable()) { - if (this->contain_host_tensor_flags.count(in_name) > 0 && - IsValidVar(in_name, *inputs_)) { - VLOG(1) << "Dygraph PtKernel input: contain host input - " << in_name; - this->contain_host_tensor_flags[in_name] = true; - } else { - VLOG(1) << "Dygraph PtKernel input: skip dispensable input - " - << in_name; - continue; - } - } - - input_names.emplace_back(in.name()); - } - return input_names; - } - - const paddle::SmallVector& GetOutputArgsNames() override { - for (int i = 0; i < op_proto_->outputs_size(); ++i) { - auto out_name = op_proto_->outputs()[i].name(); - VLOG(1) << "Dygraph PtKernel output: " << out_name; - // TODO(chenweihang): outputs also need skip some cases - - output_names.emplace_back(out_name); - } - return output_names; - } - - const paddle::SmallVector>& - GetAttrsArgsNamesAndTypes() override { - for (int i = 0; i < op_proto_->attrs_size(); ++i) { - auto attr = op_proto_->attrs()[i]; - if (attr.name() == "use_mkldnn" || attr.name() == "op_role" || - attr.name() == "op_role_var" || attr.name() == "op_namescope" || - attr.name() == "op_callstack" || attr.name() == "op_device") { - VLOG(1) << "Dygraph PtKernel attribute: skip needless attr - " - << attr.name(); - continue; - } - if ((attr.has_extra() && attr.extra()) || - (attr.has_quant() && attr.quant())) { - VLOG(1) << "Dygraph PtKernel attribute: skip extra & quant attr - " - << attr.name(); - continue; - } - if (attr_to_host_tensor.count(attr.name()) > 0 && - 
contain_host_tensor_flags.at(attr_to_host_tensor.at(attr.name())) == - true) { - VLOG(1) << "Dygraph PtKernel attribute: skip dynaimc attr - " - << attr.name() << ", because " - << attr_to_host_tensor.at(attr.name()) << " exists."; - continue; - } - // TODO(chenweihang): we need better methods to deal with special cases - if (attr.name() == "dtype") { - VLOG(1) << "Dygraph PtKernel attribute: skip " << op_proto_->type() - << "'s dtype attr."; - continue; - } - VLOG(1) << "Dygraph PtKernel attribute: " << attr.name(); - attr_names.emplace_back( - std::pair(attr.name(), - attr.type())); - } - - return attr_names; - } - - private: - framework::proto::OpProto* op_proto_; - - const imperative::NameVarMap* inputs_; - const imperative::NameVarMap* outputs_; - - paddle::SmallVector input_names; - paddle::SmallVector output_names; - paddle::SmallVector> - attr_names; - - // TODO(chenweihang): For scale op, when the input has a `ScaleTensor`, - // the following scale attribute should be skipped, and there are many - // such ops, which require certain rules to process, now only for verify - // scale op - std::unordered_map contain_host_tensor_flags{ - {"ScaleTensor", false}}; - std::unordered_map attr_to_host_tensor{ - {"scale", "ScaleTensor"}}; -}; - -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index f7e57bec1da9e..87e7e754e3ee8 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -22,7 +22,7 @@ #include "paddle/fluid/platform/xpu/xpu_op_list.h" #endif DECLARE_bool(check_nan_inf); -DECLARE_bool(use_pt_kernel); +DECLARE_bool(run_pt_kernel); namespace paddle { namespace imperative { @@ -47,10 +47,9 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { } } -template -static const T& GetAttr(const framework::AttributeMap& attrs, - const framework::AttributeMap& default_attrs, - const std::string& name) { +static const framework::Attribute& GetAttr( + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const std::string& name) { auto it = attrs.find(name); bool found = it != attrs.end(); if (!found) { @@ -60,7 +59,7 @@ static const T& GetAttr(const framework::AttributeMap& attrs, PADDLE_ENFORCE_EQ( found, true, platform::errors::NotFound("(%s) is not found in AttributeMap.", name)); - return BOOST_GET_CONST(T, it->second); + return it->second; } template @@ -108,63 +107,18 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const pt::KernelKey& pt_kernel_key, + const framework::OpKernelType& kernel_type, + const framework::KernelSignature& kernel_signature, const pt::Kernel& pt_kernel, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), - kernel_type_(framework::OpKernelType(framework::proto::VarType::RAW, - platform::CPUPlace())), + kernel_type_(kernel_type), func_(nullptr), dev_ctx_(dev_ctx), run_pt_kernel_(true), - pt_kernel_key_(pt_kernel_key), - pt_kernel_(pt_kernel) { - // TODO(chenweihang): PrepareData still use old impl, so here need save - // old kernel type, trans it later - kernel_type_ = framework::TransPtKernelKeyToOpKernelType(pt_kernel_key_); -} - -template -static framework::VariableValueMap BuildInputMap( - const NameVarMap& ins) { - framework::VariableValueMap inputs; - for (auto& var_pair : ins) { - for (auto& var : var_pair.second) { - 
inputs[var_pair.first].emplace_back(var->MutableVar()); - } - } - return inputs; -} - -// TODO(chenweihang): enhance rules, not all dispensable inputs -// are host tensor, now only for scale kernel verify -template -static bool ContainHostTensor(const framework::proto::OpProto& op_proto, - const NameVarMap& inputs) { - for (int i = 0; i < op_proto.inputs_size(); ++i) { - auto in = op_proto.inputs()[i]; - if (in.has_dispensable() && in.dispensable()) { - return IsValidVar(in.name(), inputs); - } - } - return false; -} - -template -static pt::KernelName ConstructPtKernelName( - const std::string& op_type, const framework::proto::OpProto& op_proto, - const NameVarMap& inputs) { - std::string overload_name; - // TODO(chenweihang): adapt SelectedRows by xiaowei's design - if (ContainHostTensor(op_proto, inputs)) { - if (overload_name != "") { - overload_name += "."; - } - overload_name += pt::kContainHostTensorSuffix; - } - return pt::KernelName(op_type, overload_name); -} + pt_kernel_signature_(kernel_signature), + pt_kernel_(pt_kernel) {} template PreparedOp PrepareImpl(const NameVarMap& ins, @@ -192,30 +146,36 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif // 1. get expected kernel key - if (FLAGS_use_pt_kernel && + auto dygraph_exe_ctx = DygraphExecutionContext( + op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs); + auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + + if (FLAGS_run_pt_kernel && pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) { - auto kernel_name = - ConstructPtKernelName(op.Type(), (*op.Info().proto_), ins); - auto inputs = BuildInputMap(ins); - // we only need attrs here - // auto final_attrs = BuildAttrMap(attrs, default_attrs); - auto pt_kernel_key = op.ConstructPtKernelKey(inputs, attrs, place); - auto pt_kernel = - pt::KernelFactory::Instance().SelectKernel(kernel_name, pt_kernel_key); - // for debug - VLOG(1) << "PrepareImpl - kernel name: " << kernel_name - << " | kernel key: " << pt_kernel_key << " | kernel: " << pt_kernel; + auto pt_kernel_signature = op.GetExpectedPtKernelArgs(dygraph_exe_ctx); + + VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature); + + auto pt_kernel_name = pt::KernelName(pt_kernel_signature.first); + auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(expected_kernel_key); + auto pt_kernel = pt::KernelFactory::Instance().SelectKernel(pt_kernel_name, + pt_kernel_key); + if (pt_kernel.IsValid()) { + VLOG(1) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_kernel_key + << " | kernel: " << pt_kernel; + // TODO(chenweihang): using CPUKernel when miss device kernel case - return PreparedOp(op, ctx, pt_kernel_key, pt_kernel, dev_ctx); + return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, + pt_kernel, dev_ctx); + } else { + VLOG(1) << "Dynamic mode ChoosePtKernel - kernel `" << pt_kernel_name + << "` not found."; } } - auto expected_kernel_key = op.GetExpectedKernelType( - DygraphExecutionContext(op, framework::Scope(), *dev_ctx, ctx, - ins, outs, attrs, default_attrs)); - VLOG(3) << "expected_kernel_key:" << expected_kernel_key; - // 2. check if op[type] has kernel registered. 
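Before the dygraph path looks up a pt kernel, the fluid OpKernelType chosen by GetExpectedKernelType is translated field by field into a pt::KernelKey. A minimal sketch, assuming the TransOpKernelTypeToPtKernelKey declaration from tcmpt_utils.h in this patch:

```cpp
// A plain CPU float32 kernel type with NCHW layout and default (kPlain)
// library type maps onto the equivalent pt::KernelKey.
framework::OpKernelType kernel_type(framework::proto::VarType::FP32,
                                    platform::CPUPlace(),
                                    framework::DataLayout::kNCHW);
pt::KernelKey pt_key = framework::TransOpKernelTypeToPtKernelKey(kernel_type);
// pt_key is {Backend::kCPU, DataLayout::kNCHW, DataType::kFLOAT32};
// LibraryType::kMKLDNN or kCUDNN would map to Backend::kMKLDNN or kCUDNN.
```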
auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); @@ -283,13 +243,13 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, } template -static pt::KernelContext BuildDygraphKernelContext( - const pt::Kernel& pt_kernel, KernelArgsNameMaker* argsNameMaker, - const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs, +static pt::KernelContext BuildDygraphPtKernelContext( + const framework::KernelSignature& pt_kernel_signature, + const pt::Kernel& pt_kernel, const NameVarMap& ins, + const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs, const platform::DeviceContext& dev_ctx) { - // TODO(chenweihang): now only work for very simple case (sign op), + // TODO(chenweihang): now only work for very simple case, // many cases need to be deal with later: // 1. the input and output are not tensor // 2. the dispensbale, duplicable input and output @@ -297,14 +257,15 @@ static pt::KernelContext BuildDygraphKernelContext( // 4. use pt Tensor directly // 5. kernel input is not DenseTensor pt::KernelContext op_kernel_ctx(dev_ctx); + + auto& input_names = std::get<0>(pt_kernel_signature.second); + auto& attr_names = std::get<1>(pt_kernel_signature.second); + auto& output_names = std::get<2>(pt_kernel_signature.second); + auto input_defs = pt_kernel.args_def().input_defs(); auto output_defs = pt_kernel.args_def().output_defs(); auto attr_defs = pt_kernel.args_def().attribute_defs(); - auto& input_names = argsNameMaker->GetInputArgsNames(); - auto& output_names = argsNameMaker->GetOutputArgsNames(); - auto& attr_pairs = argsNameMaker->GetAttrsArgsNamesAndTypes(); - PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), platform::errors::InvalidArgument( "the size of inputs_args names (%d) must be equal to " @@ -317,16 +278,16 @@ static pt::KernelContext BuildDygraphKernelContext( "the size of kernel output_defs (%d).", output_names.size(), output_defs.size())); - PADDLE_ENFORCE_EQ(attr_pairs.size(), attr_defs.size(), + PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(), platform::errors::InvalidArgument( "the size of attribute_args names (%d) must be equal " "to the size of kernel attribute_defs (%d).", - attr_pairs.size(), attr_defs.size())); + attr_names.size(), attr_defs.size())); for (size_t i = 0; i < input_names.size(); ++i) { - auto in_def = input_defs.at(i); + auto& in_def = input_defs.at(i); + auto& ins_vector = ins.at(input_names[i]); - auto ins_vector = ins.at(input_names[i]); std::vector> tmp_inputs; for (auto var : ins_vector) { const auto& variable = var->Var(); @@ -338,12 +299,12 @@ static pt::KernelContext BuildDygraphKernelContext( } for (size_t i = 0; i < output_names.size(); ++i) { - auto out_def = output_defs.at(i); - auto outs_vector = outs.at(output_names[i]); + auto& out_def = output_defs.at(i); + auto& outs_vector = outs.at(output_names[i]); std::vector> tmp_outputs; for (auto var : outs_vector) { - auto variable = var->MutableVar(); + auto* variable = var->MutableVar(); auto pt_out = framework::OutputVariableToPtTensor(variable, out_def); tmp_outputs.emplace_back(pt_out); @@ -351,52 +312,33 @@ static pt::KernelContext BuildDygraphKernelContext( op_kernel_ctx.EmplaceBackOutputs(tmp_outputs); } - for (size_t i = 0; i < attr_defs.size(); ++i) { + for (size_t i = 0; i < attr_names.size(); ++i) { + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(pt::Scalar))) { - // TODO(chenweihang): support 
other attrs
-      // In principle, the attr required by the dynamic mode should be
-      // passed in from the Python side, and there is no need to look up
-      // from the default_map, but now this nor work
-      switch (attr_pairs[i].second) {
-        case framework::proto::AttrType::INT:
-          op_kernel_ctx.EmplaceBackAttr(pt::Scalar(
-              GetAttr<int>(attrs, default_attrs, attr_pairs[i].first)));
-          break;
-        case framework::proto::AttrType::FLOAT:
-          op_kernel_ctx.EmplaceBackAttr(pt::Scalar(
-              GetAttr<float>(attrs, default_attrs, attr_pairs[i].first)));
-          break;
-        case framework::proto::AttrType::BOOLEAN:
-          op_kernel_ctx.EmplaceBackAttr(pt::Scalar(
-              GetAttr<bool>(attrs, default_attrs, attr_pairs[i].first)));
-          break;
-        default:
-          // TODO(chenweihang): support other attrs type
-          PADDLE_THROW(platform::errors::Unimplemented(
-              "unsupported cast op attribute `%s` when construct "
-              "KernelContext.",
-              attr_pairs[i].first));
+      // TODO(chenweihang): support other attrs later
+      // TODO(zhangyunfei): Scalar should hold scalar type, and we should check
+      // attribute type by attr_defs
+      if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
+        op_kernel_ctx.EmplaceBackAttr(pt::Scalar(BOOST_GET_CONST(float, attr)));
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "unsupported cast op attribute `%s` to Scalar when construct "
+            "KernelContext in dygraph.",
+            attr_names[i]));
+      }
    } else {
-      // TODO(chenweihang): support other attrs
-      // In principle, the attr required by the dynamic mode should be
-      // passed in from the Python side, and there is no need to look up
-      // from the default_map, but now this nor work
+      // TODO(chenweihang): support other attrs later
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        op_kernel_ctx.EmplaceBackAttr(
-            GetAttr<int>(attrs, default_attrs, attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(
-            GetAttr<float>(attrs, default_attrs, attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        op_kernel_ctx.EmplaceBackAttr(
-            GetAttr<bool>(attrs, default_attrs, attr_pairs[i].first));
+        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
      } else {
-        // TODO(chenweihang): support other attrs type
        PADDLE_THROW(platform::errors::Unimplemented(
            "unsupported cast op attribute `%s` when construct "
-            "KernelContext.",
-            attr_pairs[i].first));
+            "KernelContext in dygraph.",
+            attr_names[i]));
      }
    }
  }
@@ -446,27 +388,26 @@ static void PreparedOpRunImpl(
 }

 template <typename VarType>
-static void PreparedOpRunPtImpl(const framework::OperatorBase& op,
-                                const pt::KernelKey& pt_kernel_key,
-                                const pt::Kernel& pt_kernel,
-                                platform::DeviceContext* dev_ctx,
-                                const NameVarMap<VarType>& ins,
-                                const NameVarMap<VarType>& outs,
-                                const framework::AttributeMap& attrs,
-                                const framework::AttributeMap& default_attrs) {
+static void PreparedOpRunPtImpl(
+    const framework::OperatorBase& op,
+    const framework::KernelSignature& pt_kernel_signature,
+    const pt::Kernel& pt_kernel, platform::DeviceContext* dev_ctx,
+    const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
+    const framework::AttributeMap& attrs,
+    const framework::AttributeMap& default_attrs) {
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
   static_cast<const framework::OperatorWithKernel&>(op).InferShape(
       &infer_shape_ctx);

-  paddle::imperative::KernelArgsNameMakerByOpProto<VarType> argMaker(
-      op.Info().proto_, &ins, &outs);
-  auto op_kernel_ctx =
BuildDygraphKernelContext( - pt_kernel, &argMaker, ins, outs, attrs, default_attrs, *dev_ctx); + auto op_kernel_ctx = BuildDygraphPtKernelContext( + pt_kernel_signature, pt_kernel, ins, outs, attrs, default_attrs, + *dev_ctx); + pt_kernel(&op_kernel_ctx); - // TODO(chenweihang): add flags - // TODO(chenweihang): deal with complex cases + // TODO(chenweihang): add debug flags later + // TODO(chenweihang): deal with complex cases later } void PreparedOp::Run(const NameVarMap& ins, @@ -474,8 +415,8 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pt_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_key_, pt_kernel_, dev_ctx_, ins, - outs, attrs, default_attrs); + PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, + dev_ctx_, ins, outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); @@ -487,7 +428,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pt_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_key_, pt_kernel_, + PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, outs, attrs, default_attrs); } else { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index d6ea055cecff2..d1a47117f389b 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -21,11 +21,11 @@ #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" -#include "paddle/fluid/imperative/kernel_args_names_maker.h" #include "paddle/tcmpt/api/include/core.h" DECLARE_bool(use_mkldnn); @@ -152,8 +152,9 @@ class PreparedOp { PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const pt::KernelKey& pt_kernel_key, const pt::Kernel& pt_kernel, - platform::DeviceContext* dev_ctx); + const framework::OpKernelType& kernel_type, + const framework::KernelSignature& kernel_signature, + const pt::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, @@ -186,10 +187,11 @@ class PreparedOp { framework::OpKernelType kernel_type_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; - // TODo(chenweihang): Similar duplicate members are used for new tcmpt lib, - // maybe we have better impl methods + // NOTE(chenweihang): Similar op members are used to adapt to + // new tcmpt kernel, if there is a better design in the future, + // we may polish the implementation here bool run_pt_kernel_{false}; - pt::KernelKey pt_kernel_key_; + framework::KernelSignature pt_kernel_signature_; pt::Kernel pt_kernel_; }; diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h index fdbbc586979cd..74fd152e72a57 100644 --- a/paddle/fluid/imperative/type_defs.h +++ b/paddle/fluid/imperative/type_defs.h @@ -20,11 +20,6 @@ limitations under the License. 
*/
 #include
 namespace paddle {
-
-namespace framework {
-class Variable;
-} // namespace framework
-
 namespace imperative {

 class VariableWrapper;
@@ -50,12 +45,6 @@ template <>
 struct NameVarMapTrait<VariableWrapper> {
   using Type = std::map<std::string, std::vector<std::shared_ptr<VariableWrapper>>>;
 };
-
-template <>
-struct NameVarMapTrait<framework::Variable> {
-  using Type = std::map<std::string, std::vector<framework::Variable*>>;
-};
-
 } // namespace details

 template
diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc
index 1e908d5ead9c6..b46a1c3c89b6a 100644
--- a/paddle/fluid/operators/fill_any_like_op.cc
+++ b/paddle/fluid/operators/fill_any_like_op.cc
@@ -47,6 +47,15 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
                                    expected_kernel_type.place_,
                                    tensor.layout());
   }
+
+  framework::KernelSignature GetExpectedPtKernelArgs(
+      const framework::ExecutionContext &ctx) const override {
+    return std::make_pair(
+        "fill_any_like",
+        std::make_tuple(paddle::SmallVector<std::string>({"X"}),
+                        paddle::SmallVector<std::string>({"value"}),
+                        paddle::SmallVector<std::string>({"Out"})));
+  }
 };

 class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index a195452791048..329a649a5a34d 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -70,6 +70,24 @@ class ScaleOp : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::KernelSignature GetExpectedPtKernelArgs(
+      const framework::ExecutionContext &ctx) const override {
+    if (ctx.HasInput("ScaleTensor")) {
+      return std::make_pair(
+          "scale.host",
+          std::make_tuple(
+              paddle::SmallVector<std::string>({"X", "ScaleTensor"}),
+              paddle::SmallVector<std::string>({"bias", "bias_after_scale"}),
+              paddle::SmallVector<std::string>({"Out"})));
+    } else {
+      return std::make_pair(
+          "scale", std::make_tuple(paddle::SmallVector<std::string>({"X"}),
+                                   paddle::SmallVector<std::string>(
+                                       {"scale", "bias", "bias_after_scale"}),
+                                   paddle::SmallVector<std::string>({"Out"})));
+    }
+  }
 };

 class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index b9c87c672df6e..c3d63f6eb2745 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -683,16 +683,16 @@ PADDLE_DEFINE_EXPORTED_bool(
 /**
  * Pt kernel related FLAG
- * Name: FLAGS_use_pt_kernel
+ * Name: FLAGS_run_pt_kernel
  * Since Version: 2.2.0
  * Value Range: bool, default=false
- * Example: FLAGS_use_pt_kernel=true would use the pt kernel to compute in the
+ * Example: FLAGS_run_pt_kernel=true would use the pt kernel to compute in the
  * Op.
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index b9c87c672df6e..c3d63f6eb2745 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -683,16 +683,16 @@ PADDLE_DEFINE_EXPORTED_bool(
 
 /**
  * Pt kernel related FLAG
- * Name: FLAGS_use_pt_kernel
+ * Name: FLAGS_run_pt_kernel
 * Since Version: 2.2.0
  * Value Range: bool, default=false
- * Example: FLAGS_use_pt_kernel=true would use the pt kernel to compute in the
+ * Example: FLAGS_run_pt_kernel=true would use the pt kernel to compute in the
 * Op.
  * Note:
 */
 // TODO(chentianyu03): change default value to false before merge into develop
 // branch
-PADDLE_DEFINE_EXPORTED_bool(use_pt_kernel, true,
+PADDLE_DEFINE_EXPORTED_bool(run_pt_kernel, true,
                             "It controls whether to use pt kernel");
 
 /**
diff --git a/paddle/tcmpt/core/convert_utils.cc b/paddle/tcmpt/core/convert_utils.cc
index d393dcf51c61b..e5b8acba19cf0 100644
--- a/paddle/tcmpt/core/convert_utils.cc
+++ b/paddle/tcmpt/core/convert_utils.cc
@@ -72,7 +72,7 @@ pt::DataType TransToPtDataType(
   }
 }
 
-DataLayout TransToPtLayout(const paddle::framework::DataLayout& layout) {
+DataLayout TransToPtDataLayout(const paddle::framework::DataLayout& layout) {
   switch (layout) {
     case paddle::framework::DataLayout::kNHWC:
       return DataLayout::kNHWC;
diff --git a/paddle/tcmpt/core/convert_utils.h b/paddle/tcmpt/core/convert_utils.h
index 9e8d85c7cfa92..a567775811349 100644
--- a/paddle/tcmpt/core/convert_utils.h
+++ b/paddle/tcmpt/core/convert_utils.h
@@ -32,7 +32,7 @@ namespace pt {
 Backend TransToPtBackend(const paddle::platform::Place& place);
 DataType TransToPtDataType(
     const paddle::framework::proto::VarType::Type& dtype);
-DataLayout TransToPtLayout(const paddle::framework::DataLayout& layout);
+DataLayout TransToPtDataLayout(const paddle::framework::DataLayout& layout);
 
 paddle::platform::Place TransToFluidPlace(const Backend& backend);
 paddle::framework::proto::VarType::Type TransToProtoVarType(
diff --git a/paddle/tcmpt/core/kernel_factory.cc b/paddle/tcmpt/core/kernel_factory.cc
index 3c6daaa776742..a301d6a995ce7 100644
--- a/paddle/tcmpt/core/kernel_factory.cc
+++ b/paddle/tcmpt/core/kernel_factory.cc
@@ -51,6 +51,11 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
       "The kernel `%s` is not registered.", kernel_name));
 
   auto kernel_iter = iter->second.find(kernel_key);
+  if (kernel_key.layout() != pt::DataLayout::kAny) {
+    pt::KernelKey any_layout_kernel_key(
+        kernel_key.backend(), pt::DataLayout::kAny, kernel_key.dtype());
+    kernel_iter = iter->second.find(any_layout_kernel_key);
+  }
   PADDLE_ENFORCE_NE(
       kernel_iter,
       iter->second.end(),
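The new branch in SelectKernelOrThrowError gives layout-agnostic kernels a chance to match: when the caller asks for a concrete layout, the factory retries the lookup with DataLayout::kAny. Note that, as the hunk is written, the kAny lookup replaces the result of the exact-layout lookup unconditionally, which is presumably acceptable while every kernel in this patch is registered under Any. A standalone sketch of the lookup logic follows; the Key and map types are simplified stand-ins, not the real pt::KernelKey/KernelMap:

    // Simplified model of the layout-fallback lookup in SelectKernelOrThrowError.
    #include <map>
    #include <string>
    #include <utility>

    enum class Layout { kAny, kNCHW, kNHWC };
    using KernelMapSketch =
        std::map<std::pair<int /*backend*/, Layout>, std::string /*kernel*/>;

    const std::string* Find(const KernelMapSketch& kernels, int backend,
                            Layout layout) {
      auto it = kernels.find({backend, layout});
      if (layout != Layout::kAny) {
        // Mirrors the patch: for a concrete layout, retry with kAny,
        // discarding whatever the exact lookup returned.
        it = kernels.find({backend, Layout::kAny});
      }
      return it == kernels.end() ? nullptr : &it->second;
    }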
] #include "paddle/fluid/platform/enforce.h" +#include "paddle/utils/flat_hash_map.h" +#include "paddle/utils/small_vector.h" namespace pt { @@ -209,25 +211,30 @@ class KernelArgsDef { attribute_defs_.emplace_back(AttributeArgDef(type_index)); } - const std::vector& input_defs() const { return input_defs_; } + const paddle::SmallVector& input_defs() const { + return input_defs_; + } - const std::vector& output_defs() const { return output_defs_; } + const paddle::SmallVector& output_defs() const { + return output_defs_; + } - const std::vector& attribute_defs() const { + const paddle::SmallVector& attribute_defs() const { return attribute_defs_; } - std::vector& input_defs() { return input_defs_; } + paddle::SmallVector& input_defs() { return input_defs_; } - std::vector& output_defs() { return output_defs_; } + paddle::SmallVector& output_defs() { return output_defs_; } - std::vector& attribute_defs() { return attribute_defs_; } + paddle::SmallVector& attribute_defs() { + return attribute_defs_; + } private: - // TODO(chenweihang): replaced by paddle::small_vector - std::vector input_defs_{{}}; - std::vector output_defs_{{}}; - std::vector attribute_defs_{{}}; + paddle::SmallVector input_defs_{{}}; + paddle::SmallVector output_defs_{{}}; + paddle::SmallVector attribute_defs_{{}}; }; class Kernel { @@ -263,10 +270,10 @@ class Kernel { class KernelFactory { public: // replaced by paddle::flat_hash_map later - using KernelMap = - std::unordered_map, - KernelName::Hash>; + using KernelMap = paddle::flat_hash_map< + KernelName, + paddle::flat_hash_map, + KernelName::Hash>; static KernelFactory& Instance(); diff --git a/paddle/tcmpt/core/kernel_registry.h b/paddle/tcmpt/core/kernel_registry.h index 40ee968dd987c..661d387e9b8e2 100644 --- a/paddle/tcmpt/core/kernel_registry.h +++ b/paddle/tcmpt/core/kernel_registry.h @@ -42,6 +42,13 @@ struct KernelArgsParseFunctor { using Arg = typename std::tuple_element::type; static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) { + // TODO(chenweihang): The fluid Tensor's default layout is NCHW, + // it is not same as kernel's layout, we should fix this error on + // fluid Tensor + auto default_tensor_layout = pt::DataLayout::kNCHW; + if (default_key.layout() != pt::DataLayout::kAny) { + default_tensor_layout = default_key.layout(); + } auto args_type = ParseArgType(Indices{}); for (auto arg_type : args_type) { if (arg_type == std::type_index(typeid(const CPUContext&)) @@ -54,10 +61,10 @@ struct KernelArgsParseFunctor { // do nothing, skip context arg now } else if (arg_type == std::type_index(typeid(const DenseTensor&))) { args_def->AppendInput( - default_key.backend(), default_key.layout(), default_key.dtype()); + default_key.backend(), default_tensor_layout, default_key.dtype()); } else if (arg_type == std::type_index(typeid(DenseTensor*))) { args_def->AppendOutput( - default_key.backend(), default_key.layout(), default_key.dtype()); + default_key.backend(), default_tensor_layout, default_key.dtype()); } else { // Attribute deal with // TODO(chenweihang): now here allow any types of attribute, maybe diff --git a/paddle/tcmpt/kernels/cpu/creation.cc b/paddle/tcmpt/kernels/cpu/creation.cc index 4871e11da2112..37b589d776822 100644 --- a/paddle/tcmpt/kernels/cpu/creation.cc +++ b/paddle/tcmpt/kernels/cpu/creation.cc @@ -24,7 +24,7 @@ void FillAnyLike(const CPUContext& dev_ctx, const DenseTensor& x, const Scalar& val, DenseTensor* out) { - eigen::fill(dev_ctx, out, val.to()); + eigen::fill(dev_ctx, out, val.to()); } } // namespace pt @@ 
diff --git a/paddle/tcmpt/kernels/cpu/creation.cc b/paddle/tcmpt/kernels/cpu/creation.cc
index 4871e11da2112..37b589d776822 100644
--- a/paddle/tcmpt/kernels/cpu/creation.cc
+++ b/paddle/tcmpt/kernels/cpu/creation.cc
@@ -24,7 +24,7 @@ void FillAnyLike(const CPUContext& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
-  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<float>());
+  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
 }
 
 }  // namespace pt
@@ -33,7 +33,7 @@ PT_REGISTER_MODULE(CreationCPU);
 
 PT_REGISTER_KERNEL("fill_any_like",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::FillAnyLike,
                    float,
                    double,
diff --git a/paddle/tcmpt/kernels/cpu/linalg.cc b/paddle/tcmpt/kernels/cpu/linalg.cc
index 8b63219fdd2db..821cd5c092e85 100644
--- a/paddle/tcmpt/kernels/cpu/linalg.cc
+++ b/paddle/tcmpt/kernels/cpu/linalg.cc
@@ -62,7 +62,7 @@ using complex128 = ::paddle::platform::complex<double>;
 
 PT_REGISTER_KERNEL("dot",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::Dot,
                    float,
                    double,
diff --git a/paddle/tcmpt/kernels/cpu/manipulation.cc b/paddle/tcmpt/kernels/cpu/manipulation.cc
index 91f1e941cd028..edf7f5aff0389 100644
--- a/paddle/tcmpt/kernels/cpu/manipulation.cc
+++ b/paddle/tcmpt/kernels/cpu/manipulation.cc
@@ -60,7 +60,7 @@ PT_REGISTER_MODULE(ManipulationCPU);
 // architecture, kernel_name should be "flatten".
 PT_REGISTER_KERNEL("flatten_contiguous_range",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::Flatten,
                    float,
                    double,
@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range",
 
 PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::FlattenWithXShape,
                    float,
                    double,
diff --git a/paddle/tcmpt/kernels/cpu/math.cc b/paddle/tcmpt/kernels/cpu/math.cc
index d304db0a9a34e..4fa14141209a1 100644
--- a/paddle/tcmpt/kernels/cpu/math.cc
+++ b/paddle/tcmpt/kernels/cpu/math.cc
@@ -69,11 +69,11 @@ PT_REGISTER_MODULE(MathCPU);
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::paddle::platform::bfloat16;
-PT_REGISTER_KERNEL("sign", CPU, NCHW, pt::Sign, float, double) {}
-PT_REGISTER_KERNEL("mean", CPU, NCHW, pt::Mean, float, double) {}
+PT_REGISTER_KERNEL("sign", CPU, Any, pt::Sign, float, double) {}
+PT_REGISTER_KERNEL("mean", CPU, Any, pt::Mean, float, double) {}
 PT_REGISTER_KERNEL("scale",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::Scale,
                    float,
                    double,
@@ -85,7 +85,7 @@ PT_REGISTER_KERNEL("scale",
                    int64_t) {}
 PT_REGISTER_KERNEL("scale.host",
                    CPU,
-                   NCHW,
+                   Any,
                    pt::ScaleHost,
                    float,
                    double,
diff --git a/paddle/tcmpt/kernels/cuda/creation.cu b/paddle/tcmpt/kernels/cuda/creation.cu
index 7f082400eaaf7..54afec95735df 100644
--- a/paddle/tcmpt/kernels/cuda/creation.cu
+++ b/paddle/tcmpt/kernels/cuda/creation.cu
@@ -24,7 +24,7 @@ void FillAnyLike(const CUDAContext& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
-  eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<float>());
+  eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
 }
 
 }  // namespace pt
@@ -33,7 +33,7 @@ PT_REGISTER_MODULE(CreationCUDA);
 
 PT_REGISTER_KERNEL("fill_any_like",
                    CUDA,
-                   NCHW,
+                   Any,
                    pt::FillAnyLike,
                    float,
                    double,
diff --git a/paddle/tcmpt/kernels/cuda/linalg.cu b/paddle/tcmpt/kernels/cuda/linalg.cu
index 25d1df5cbc65a..77001d988038d 100644
--- a/paddle/tcmpt/kernels/cuda/linalg.cu
+++ b/paddle/tcmpt/kernels/cuda/linalg.cu
@@ -39,7 +39,7 @@ using complex128 = ::paddle::platform::complex<double>;
 
 PT_REGISTER_KERNEL("dot",
                    CUDA,
-                   NCHW,
+                   Any,
                    pt::Dot,
                    float,
                    double,
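The NCHW to Any edits in these registration macros are the other half of the factory fallback above: each kernel is now registered once, layout-agnostic, instead of being pinned to NCHW. Continuing the simplified Find() sketch from earlier (same hypothetical types), a single kAny entry now satisfies requests for any concrete layout:

    // Continues the KernelMapSketch/Find sketch above.
    #include <cassert>

    void LayoutFallbackDemo() {
      KernelMapSketch kernels;
      kernels[{/*backend=*/0, Layout::kAny}] = "sign";     // one Any registration
      assert(Find(kernels, 0, Layout::kNCHW) != nullptr);  // exact miss, kAny hit
      assert(Find(kernels, 0, Layout::kNHWC) != nullptr);  // ditto for NHWC
    }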
PT_REGISTER_KERNEL("flatten_contiguous_range", CUDA, - NCHW, + Any, pt::Flatten, float, float16, @@ -73,7 +73,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range", PT_REGISTER_KERNEL("flatten_contiguous_range.mid", CUDA, - NCHW, + Any, pt::FlattenWithXShape, float, double, diff --git a/paddle/tcmpt/kernels/cuda/math.cu b/paddle/tcmpt/kernels/cuda/math.cu index 743615d70f996..f0d76744f68bd 100644 --- a/paddle/tcmpt/kernels/cuda/math.cu +++ b/paddle/tcmpt/kernels/cuda/math.cu @@ -121,11 +121,11 @@ void ScaleHost(const CUDAContext& dev_ctx, PT_REGISTER_MODULE(MathCUDA); using float16 = paddle::platform::float16; -PT_REGISTER_KERNEL("sign", CUDA, NCHW, pt::Sign, float, double, float16) {} -PT_REGISTER_KERNEL("mean", CUDA, NCHW, pt::Mean, float, double, float16) {} +PT_REGISTER_KERNEL("sign", CUDA, Any, pt::Sign, float, double, float16) {} +PT_REGISTER_KERNEL("mean", CUDA, Any, pt::Mean, float, double, float16) {} PT_REGISTER_KERNEL("scale", CUDA, - NCHW, + Any, pt::Scale, float, double, @@ -137,7 +137,7 @@ PT_REGISTER_KERNEL("scale", int64_t) {} PT_REGISTER_KERNEL("scale.host", CUDA, - NCHW, + Any, pt::ScaleHost, float, double, diff --git a/paddle/utils/small_vector.h b/paddle/utils/small_vector.h index f51a3b623ce3b..e9e7996babcf7 100644 --- a/paddle/utils/small_vector.h +++ b/paddle/utils/small_vector.h @@ -3,6 +3,8 @@ // 1. remove macro // 2. remove LLVM_LIKELY and LLVM_UNLIKELY // 3. add at(index) method for small vector +// 4. wrap the call to max and min with parenthesis to prevent the macro +// expansion to fix the build error on windows platform //===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===// // @@ -90,7 +92,7 @@ class SmallVectorBase { /// The maximum value of the Size_T used. static constexpr size_t SizeTypeMax() { - return std::numeric_limits::max(); + return (std::numeric_limits::max)(); } SmallVectorBase() = delete; @@ -309,7 +311,7 @@ class SmallVectorTemplateCommon size_type size_in_bytes() const { return size() * sizeof(T); } size_type max_size() const { - return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T)); + return (std::min)(this->SizeTypeMax(), size_type(-1) / sizeof(T)); } size_t capacity_in_bytes() const { return capacity() * sizeof(T); } @@ -727,7 +729,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { } // Assign over existing elements. - std::fill_n(this->begin(), std::min(NumElts, this->size()), Elt); + std::fill_n(this->begin(), (std::min)(NumElts, this->size()), Elt); if (NumElts > this->size()) std::uninitialized_fill_n(this->end(), NumElts - this->size(), Elt); else if (NumElts < this->size()) @@ -1393,7 +1395,7 @@ static void report_at_maximum_capacity(size_t MaxSize) { // Note: Moving this function into the header may cause performance regression. template static size_t getNewCapacity(size_t MinSize, size_t TSize, size_t OldCapacity) { - constexpr size_t MaxSize = std::numeric_limits::max(); + constexpr size_t MaxSize = (std::numeric_limits::max)(); // Ensure we can fit the new capacity. // This is only going to be applicable when the capacity is 32 bit. @@ -1408,7 +1410,7 @@ static size_t getNewCapacity(size_t MinSize, size_t TSize, size_t OldCapacity) { // In theory 2*capacity can overflow if the capacity is 64 bit, but the // original capacity would never be large enough for this to be a problem. size_t NewCapacity = 2 * OldCapacity + 1; // Always grow. 
diff --git a/paddle/utils/small_vector.h b/paddle/utils/small_vector.h
index f51a3b623ce3b..e9e7996babcf7 100644
--- a/paddle/utils/small_vector.h
+++ b/paddle/utils/small_vector.h
@@ -3,6 +3,8 @@
 // 1. remove macro
 // 2. remove LLVM_LIKELY and LLVM_UNLIKELY
 // 3. add at(index) method for small vector
+// 4. wrap the calls to max and min in parentheses to prevent macro
+// expansion, fixing the build error on the windows platform
 
 //===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===//
 //
@@ -90,7 +92,7 @@ class SmallVectorBase {
   /// The maximum value of the Size_T used.
   static constexpr size_t SizeTypeMax() {
-    return std::numeric_limits<Size_T>::max();
+    return (std::numeric_limits<Size_T>::max)();
   }
 
   SmallVectorBase() = delete;
@@ -309,7 +311,7 @@ class SmallVectorTemplateCommon
   size_type size_in_bytes() const { return size() * sizeof(T); }
   size_type max_size() const {
-    return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T));
+    return (std::min)(this->SizeTypeMax(), size_type(-1) / sizeof(T));
   }
 
   size_t capacity_in_bytes() const { return capacity() * sizeof(T); }
@@ -727,7 +729,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     }
 
     // Assign over existing elements.
-    std::fill_n(this->begin(), std::min(NumElts, this->size()), Elt);
+    std::fill_n(this->begin(), (std::min)(NumElts, this->size()), Elt);
     if (NumElts > this->size())
       std::uninitialized_fill_n(this->end(), NumElts - this->size(), Elt);
     else if (NumElts < this->size())
@@ -1393,7 +1395,7 @@ static void report_at_maximum_capacity(size_t MaxSize) {
 // Note: Moving this function into the header may cause performance regression.
 template <class Size_T>
 static size_t getNewCapacity(size_t MinSize, size_t TSize, size_t OldCapacity) {
-  constexpr size_t MaxSize = std::numeric_limits<Size_T>::max();
+  constexpr size_t MaxSize = (std::numeric_limits<Size_T>::max)();
 
   // Ensure we can fit the new capacity.
   // This is only going to be applicable when the capacity is 32 bit.
@@ -1408,7 +1410,7 @@ static size_t getNewCapacity(size_t MinSize, size_t TSize, size_t OldCapacity) {
   // In theory 2*capacity can overflow if the capacity is 64 bit, but the
   // original capacity would never be large enough for this to be a problem.
   size_t NewCapacity = 2 * OldCapacity + 1;  // Always grow.
-  return std::min(std::max(NewCapacity, MinSize), MaxSize);
+  return (std::min)((std::max)(NewCapacity, MinSize), MaxSize);
 }
 
 // Note: Moving this function into the header may cause performance regression.
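For reference, the clamp in getNewCapacity (now macro-safe) works like this: double the old capacity plus one, raise it to at least MinSize, then cap it at MaxSize. A worked example with illustrative numbers:

    #include <algorithm>
    #include <cstddef>

    std::size_t GrowthDemo(std::size_t MinSize, std::size_t OldCapacity,
                           std::size_t MaxSize) {
      std::size_t NewCapacity = 2 * OldCapacity + 1;  // e.g. 8 -> 17
      return (std::min)((std::max)(NewCapacity, MinSize), MaxSize);
    }
    // GrowthDemo(32, 8, 1024)  == 32    (17 raised to MinSize)
    // GrowthDemo(4, 8, 1024)   == 17    (doubling wins)
    // GrowthDemo(4, 600, 1024) == 1024  (1201 capped at MaxSize)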