Skip to content

Commit

Permalink
[PHI] Move set_value kernel to phi (#40195)
Browse files Browse the repository at this point in the history
* save code

* fix bug of set_value

* add coverage test
  • Loading branch information
zyfncg authored Mar 9, 2022
1 parent 63fb034 commit cd28cdd
Show file tree
Hide file tree
Showing 14 changed files with 1,701 additions and 212 deletions.
65 changes: 64 additions & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const {
return var != nullptr;
}

bool ExecutionContext::HasInputs(const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (const auto* input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}

bool ExecutionContext::HasOutput(const std::string& name) const {
auto* var = OutputVar(name);
return var != nullptr;
Expand Down Expand Up @@ -2189,6 +2203,51 @@ void OperatorWithKernel::BuildPhiKernelContext(
std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
}

} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]);
Expand All @@ -2212,7 +2271,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ class ExecutionContext {

virtual bool HasInput(const std::string& name) const;

virtual bool HasInputs(const std::string& name) const;

virtual bool HasOutput(const std::string& name) const;

virtual size_t InputSize(const std::string& name) const {
Expand Down Expand Up @@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
: ctx_(ctx) {}

bool HasInput(const std::string& name) const override {
return ctx_.HasInput(name);
return ctx_.HasInputs(name);
}

bool HasOutput(const std::string& name) const override {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/imperative/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
return (it != var_map_in_.end() && it->second.size() > 0);
}

bool HasInputs(const std::string& name) const override {
auto it = var_map_in_.find(name);
return (it != var_map_in_.end() && it->second.size() > 0);
}

bool HasOutput(const std::string& name) const override {
auto it = var_map_out_.find(name);
return (it != var_map_out_.end() && it->second.size() > 0);
Expand Down
61 changes: 60 additions & 1 deletion paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext(
}

for (size_t i = 0; i < attr_names.size(); ++i) {
VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute
Expand Down Expand Up @@ -409,6 +410,60 @@ void BuildDygraphPhiKernelContext(
experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
}

} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<bool>))) {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
Expand All @@ -432,7 +487,11 @@ void BuildDygraphPhiKernelContext(
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
ops::SetValueGradMaker<paddle::imperative::OpBase>,
ops::SetValueOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
ops::SetValueKernel<plat::CPUDeviceContext, float>,
ops::SetValueKernel<plat::CPUDeviceContext, double>,
ops::SetValueKernel<plat::CPUDeviceContext, bool>);

REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);

REGISTER_OP_CPU_KERNEL(
Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/operators/set_value_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CUDADeviceContext, int>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);

REGISTER_OP_CUDA_KERNEL(
set_value_grad,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
Expand Down
Loading

0 comments on commit cd28cdd

Please sign in to comment.