Skip to content

Commit

Permalink
[NewIR]fix new ir edit distance bug (#55294)
Browse files Browse the repository at this point in the history
* fix edit distance bug

* add op define kernel data type

* fix bug

* update

* add header

* add op test to cmake
  • Loading branch information
phlrain authored Jul 13, 2023
1 parent 6f7ceca commit 2194e4c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ void HandleForSpecialOp(ir::Operation* op,
auto feed_list = feed_var->Get<paddle::framework::FeedList>();
auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
out_tensor->ShareDataWith(in_tensor);
out_tensor->set_lod(in_tensor.lod());
}

if (op_name == "builtin.combine") {
Expand Down
18 changes: 16 additions & 2 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,23 @@
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h"

namespace paddle {
namespace dialect {

const int init_on_gpu_threashold = 1000;

std::unordered_map<std::string, phi::DataType> Str2PhiDataType = {
{"DataType::FLOAT16", phi::DataType::FLOAT16},
{"DataType::BFLOAT16", phi::DataType::BFLOAT16},
{"DataType::FLOAT32", phi::DataType::FLOAT32},
{"DataType::FLOAT64", phi::DataType::FLOAT64},
{"DataType::INT16", phi::DataType::INT16},
{"DataType::INT32", phi::DataType::INT32},
{"DataType::INT64", phi::DataType::INT64},
{"DataType::INT8", phi::DataType::INT8},
{"DataType::BOOL", phi::DataType::BOOL},
};

phi::KernelKey GetKernelKey(
ir::Operation* op,
const phi::Place& place,
Expand Down Expand Up @@ -67,7 +78,10 @@ phi::KernelKey GetKernelKey(
auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id();

if (input_map.count(slot_name)) {
auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);

Expand Down
3 changes: 2 additions & 1 deletion test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,8 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
FLAGS_new_executor_static_build=true)
endforeach()

set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2)
set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2
test_edit_distance_op)

foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS})
py_test_modules(
Expand Down

0 comments on commit 2194e4c

Please sign in to comment.