Skip to content

Commit

Permalink
New ir support data transfer (#54763)
Browse files Browse the repository at this point in the history
* add kernel dialect

* change DenseTensorTypeStorage to DenseTensorType

* add test case

* add first pd_op to kernel dialect

* lower pd op to kernel dialect

* update

* update

* remove useless code

* add attribute print test

* fix bug

* update

* update

* update

* update

* polish code

* fix bug

* polish code and add python test

* add test

* fix test error

* add env flag

* fix bug

* revert test env

* change cc_test_old to cc_test

* fix build_static bug

* fix type test error

* update cmake

* disable test in windows

* update

* update

* fix bug

* split file

* fix conflict

* polish code and fix conflict

* support place transformer

* finish bug

* add gpu flags

* fix with cuda macro

* update

* add scope guard

* polish code
  • Loading branch information
phlrain authored Jun 27, 2023
1 parent 9307d35 commit b58869f
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,9 @@ void BuildOpFuncList(
.data();

VLOG(6) << "finish process infer meta context";
auto t1 =
phi::KernelFactory::Instance().SelectKernel(kernel_name, kernel_key);
op_func_node.phi_kernel_ = new phi::Kernel(t1);
auto t1 = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
op_func_node.phi_kernel_ = new phi::Kernel(t1.kernel);

PADDLE_ENFORCE_EQ(op_func_node.phi_kernel_->IsValid(),
true,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,12 @@ void NewIRInterpreter::RunInstruction(const Instruction& instr_node) {
if (instr_node.PreDefineContext()) {
VLOG(5) << "run new ir selected kernel";
auto op_func_node = const_cast<OpFuncNode*>((instr_node.OpFunc()));
VLOG(5) << "begin to run op " << op_func_node->phi_op_name_;
op_func_node->infer_shape_interface_->infer_shape_(
&(op_func_node->infer_meta_context_));
VLOG(5) << "after run infer meta";
(*(op_func_node->phi_kernel_))(&(op_func_node->kernel_context_));
VLOG(5) << "after run kernel";
} else if (!instr_node.IsArtificial()) {
RunOperator(instr_node);
CheckGC(instr_node);
Expand Down
95 changes: 85 additions & 10 deletions paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
#include "paddle/fluid/ir/dialect/kernel_type.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace dialect {

const int init_on_gpu_threashold = 1000;

phi::KernelKey GetKernelKey(
ir::Operation* op,
const phi::Place& place,
Expand Down Expand Up @@ -100,7 +104,31 @@ phi::KernelKey GetKernelKey(
// parse all the input tensor
if (tensor_input_number == 0 || op->name() == "pd.full_") {
// all the information have to get from attribute and context
kernel_backend = paddle::experimental::ParseBackend(place);

if (op->name() == "pd.uniform") {
// try to process uniform, use shape to determin backend
// TODO(phlrain): shuold support other initilize op
auto define_op = op->operand(0).source().GetDefiningOp();
if (define_op->name() == "pd.full_int_array") {
auto shape = define_op->attributes()
.at("value")
.dyn_cast<dialect::IntArrayAttribute>()
.data()
.GetData();

size_t numel = 1;
for (auto& s : shape) {
numel *= s;
}
if (numel > init_on_gpu_threashold) {
kernel_backend = phi::Backend::GPU;
}
}
}

if (kernel_backend == phi::Backend::UNDEFINED) {
kernel_backend = paddle::experimental::ParseBackend(place);
}
}
}

Expand Down Expand Up @@ -166,6 +194,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
phi::Place cpu_place(phi::AllocationType::CPU);

ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleKernelDialect>();

std::unordered_map<ir::Operation*, ir::Operation*> map_op_pair;
Expand Down Expand Up @@ -220,22 +249,68 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
// constuct input
std::vector<ir::OpResult> vec_inputs;

if ((*it)->name() != "pd.full" && (*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i).source();
auto new_in = map_value_pair.at(cur_in);

vec_inputs.push_back(new_in);
}
}

paddle::dialect::OpYamlInfoInterface op_info_interface =
(*it)->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
std::string kernel_fn_str;
std::vector<paddle::dialect::OpInputInfo> input_info;
if (op_info_interface) {
auto op_info_res = op_info_interface.GetOpInfo();
auto runtime_info = std::get<3>(op_info_res);
kernel_fn_str = runtime_info.kernel_func[0];
input_info = std::get<0>(op_info_res);
}

if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i).source();
auto new_in = map_value_pair.at(cur_in);

auto new_in_type = new_in.type();

auto& kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN(
kernel_fn_str, kernel_key);

if (kernel.IsValid()) {
if (new_in_type.isa<dialect::AllocatedDenseTensorType>()) {
// allocated type
auto place =
new_in_type.dyn_cast<dialect::AllocatedDenseTensorType>()
.place();

if ((i < input_info.size()) &&
(!input_info[i].is_mutable_attribute) &&
(place != phi::TransToPhiPlace(kernel_key.backend()))) {
if (paddle::experimental::NeedTransformPlace(
place, kernel.InputAt(i).backend, {})) {
VLOG(6) << "need trans from " << place << " to "
<< kernel_key.backend();
// build memcopy op
auto copy_kernel_key = kernel_key;
copy_kernel_key.set_backend(phi::Backend::GPU);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"op_name", ir::StrAttribute::get(ctx, "pd.memcpy_h2d")},
{"kernel_name", ir::StrAttribute::get(ctx, "memcpy_h2d")},
{"kernel_key",
dialect::KernelAttribute::get(ctx, copy_kernel_key)},
{"dst_place_type", ir::Int32Attribute::get(ctx, 1)}};

ir::Operation* op1 = ir::Operation::Create(
{new_in}, op1_attribute, {new_in_type}, op1_info);

program->block()->push_back(op1);

new_in = op1->result(0);
}
}
} else if (new_in_type.isa<ir::VectorType>()) {
// [ todo need update here, support combine data transfomer]
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support allocated dense tensor type for now"));
}
}
vec_inputs.push_back(new_in);
}
}

std::unordered_map<std::string, ir::Attribute> op1_attribute{
Expand Down
17 changes: 0 additions & 17 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,6 @@ inline bool NeedTransformDataType(const DataType& input,
target == DataType::COMPLEX64 || target == DataType::COMPLEX128);
}

inline bool NeedTransformPlace(const phi::Place& input,
const Backend& target,
const TransformFlag& transform_flag) {
// NOTE(dev): The default value of TransformFlag is True, if it is set with
// False
// somewhere such as ops.yaml or backward.yaml that means we should skip data
// transform. Because "stop_transform_" has highest priority.
if (!transform_flag.need_trans_backend()) {
return false;
}
bool ret = input.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU));
return ret;
}

inline bool NeedTransformLayout(const DataLayout& input,
const DataLayout& target,
const phi::Place& place,
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/api/lib/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,22 @@ void TransDataBackend(const phi::SelectedRows* tensor,
Backend target_backend,
phi::SelectedRows* out);

inline bool NeedTransformPlace(const phi::Place& input,
                               const Backend& target,
                               const TransformFlag& transform_flag) {
  // NOTE(dev): "stop_transform_" has highest priority. When backend
  // transform has been disabled (e.g. via ops.yaml or backward.yaml, which
  // sets the flag to false), skip the data transform entirely.
  if (!transform_flag.need_trans_backend()) {
    return false;
  }
  // Pinned host memory always needs to be moved onto the target device.
  if (input.GetType() == AllocationType::GPUPINNED) {
    return true;
  }
  // ALL_BACKEND accepts any placement, so no transform is required.
  if (target == Backend::ALL_BACKEND) {
    return false;
  }
  // GPUDNN kernels consume plain GPU memory, so compare against GPU.
  const Backend effective_target =
      (target == Backend::GPUDNN) ? Backend::GPU : target;
  return phi::TransToPhiBackend(input) != effective_target;
}

} // namespace experimental
} // namespace paddle
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,24 @@
func : mean
backward : mean_grad

# Device-to-host copy op used by the new-IR data-transfer lowering;
# dst_place_type selects the destination place — presumably matches the
# memcpy_d2h kernel's encoding, verify against its implementation.
- op : memcpy_d2h
  args : (Tensor x, int dst_place_type)
  output : Tensor
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : memcpy_d2h

# Host-to-device counterpart of memcpy_d2h; inserted by
# PdOpLowerToKernelPass when a CPU-placed input feeds a GPU kernel.
- op : memcpy_h2d
  args : (Tensor x, int dst_place_type)
  output : Tensor
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : memcpy_h2d

- op : min
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
Expand Down
45 changes: 45 additions & 0 deletions paddle/phi/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,58 @@ const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
if (iter == kernels_.end()) {
return empty_kernel;
}

auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
phi::KernelKey any_layout_kernel_key(
kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}

#if defined(PADDLE_WITH_CUSTOM_DEVICE)
if (kernel_iter == iter->second.end() &&
kernel_key.backend() > phi::Backend::NUM_BACKENDS) {
kernel_iter = iter->second.find({phi::Backend::CUSTOM,
phi::DataLayout::ALL_LAYOUT,
kernel_key.dtype()});
}
#endif

if (kernel_iter == iter->second.end()) {
return empty_kernel;
}

return kernel_iter->second;
}

const Kernel& KernelFactory::SelectKernelWithGPUDNN(
const std::string& kernel_name, const KernelKey& const_kernel_key) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
return empty_kernel;
}
KernelKey kernel_key = KernelKey(const_kernel_key);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (kernel_key.backend() == Backend::GPUDNN) {
auto kernel_iter = iter->second.find(
{Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()});
if (kernel_iter != iter->second.end()) {
return kernel_iter->second;
}
kernel_key =
KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype());
}
#endif

auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
phi::KernelKey any_layout_kernel_key(
kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}

#if defined(PADDLE_WITH_CUSTOM_DEVICE)
if (kernel_iter == iter->second.end() &&
kernel_key.backend() > phi::Backend::NUM_BACKENDS) {
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ class KernelFactory {
const Kernel& SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;

const Kernel& SelectKernelWithGPUDNN(const std::string& kernel_name,
const KernelKey& kernel_key) const;

KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;

const KernelArgsDef& GetFirstKernelArgsDef(
Expand Down
76 changes: 76 additions & 0 deletions test/cpp/new_executor/standalone_executor_new_ir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,81 @@ TEST(StandaloneExecutor, run_2) {
EXPECT_EQ(res3, true);
}

#ifdef PADDLE_WITH_CUDA
// Verifies that PdOpLowerToKernelPass inserts data-transfer (memcpy) ops
// when a CPU-placed tensor feeds a kernel selected on the GPU backend, and
// that the lowered program executes correctly end to end.
TEST(StandaloneExecutor, data_transfer) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
  ir::Program program(ctx);
  ir::Builder builder(ctx, program.block());
  ir::Block* block = program.block();

  // Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
  // phi::DataType dtype, float min, float max, int seed, phi::Place place)
  // Single-element tensor: numel stays below init_on_gpu_threashold (1000),
  // so the lowering pass keeps its kernel on CPU.
  paddle::dialect::UniformOp uniform1 =
      builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{1},
                                                phi::DataType::FLOAT32,
                                                0.0,
                                                1.0,
                                                2,
                                                phi::CPUPlace());
  EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  // Presumably uniform plus its attribute-producing ops expand to 4 ops —
  // confirm against UniformOp::Build if this count changes.
  EXPECT_EQ(block->size(), 4u);

  // Def: B = paddle::dialect::UniformOp(...)
  // 100x100 = 10000 elements exceeds the threshold, so the pass selects a
  // GPU kernel for it — forcing a CPU->GPU transfer for A's result when the
  // two are added together below.
  paddle::dialect::UniformOp uniform2 =
      builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{100, 100},
                                                phi::DataType::FLOAT32,
                                                0.0,
                                                1.0,
                                                2,
                                                phi::CPUPlace());
  EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  EXPECT_EQ(block->size(), 8u);

  // Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
  paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
      uniform1->result(0), uniform2->result(0));
  EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  EXPECT_EQ(block->size(), 9u);

  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

  auto place = platform::CPUPlace();
  Scope scope;

  ProgramDesc prog_desc;

  InterpreterCore test_core(place, std::move(kernel_program), &scope);

  test_core.Run({});

  // NOTE(review): "inner_var_9" depends on the interpreter's internal
  // variable-naming scheme — confirm it stays stable, or resolve the output
  // variable through the scope/value mapping instead of a hard-coded name.
  auto out_tensor = scope.Var("inner_var_9")->Get<phi::DenseTensor>();

  // The add result may live on GPU; copy it back to CPU before reading.
  auto& pool = phi::DeviceContextPool::Instance();
  phi::DenseTensor out;
  phi::DeviceContext* dev_ctx = pool.Get(out_tensor.place());
  phi::Copy(*dev_ctx, out_tensor, place, true, &out);

  // Expected values are fixed by the seed (2) passed to both uniform ops.
  bool res0 = simple_cmp(out.data<float>()[0], 0.903649);
  bool res1 = simple_cmp(out.data<float>()[1], 1.07367);
  bool res2 = simple_cmp(out.data<float>()[2], 1.10631);
  bool res3 = simple_cmp(out.data<float>()[3], 1.68683);
  EXPECT_EQ(res0, true);
  EXPECT_EQ(res1, true);
  EXPECT_EQ(res2, true);
  EXPECT_EQ(res3, true);
}
#endif

} // namespace framework
} // namespace paddle

0 comments on commit b58869f

Please sign in to comment.