Add Sparse op sparse_relu (#40959)
zhangkaihuo authored Mar 29, 2022
1 parent 054fc99 commit c544a18
Showing 11 changed files with 346 additions and 3 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/kernels/sparse/CMakeLists.txt
@@ -1,3 +1,3 @@

-set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function custom_kernel)
-register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel")
+set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function custom_kernel copy_kernel)
+register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse")
70 changes: 70 additions & 0 deletions paddle/phi/kernels/sparse/sparse_activation_grad_kernel.cc
@@ -0,0 +1,70 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h"
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SparseReluGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
SparseCooTensor* x_grad) {
DenseTensor non_zero_indices =
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_indices());
DenseTensor non_zero_elements =
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());
phi::Copy(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
false,
&non_zero_indices);
phi::ReluGradKernel<T, Context>(dev_ctx,
x.non_zero_elements(),
out_grad.non_zero_elements(),
&non_zero_elements);
x_grad->SetMember(non_zero_indices, non_zero_elements, x.dims(), true);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(sparse_relu_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SparseReluGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(sparse_relu_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SparseReluGradKernel,
float,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#endif
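
For intuition: because ReLU never changes which entries of a COO tensor are non-zero, the gradient kernel above simply copies the indices of x and applies the dense ReLU gradient to the stored values only. A minimal NumPy sketch of that element-wise rule (illustrative data, not Paddle code):

import numpy as np

# COO values of x and the incoming gradient (same sparsity pattern as x).
x_values = np.array([-1., 2., -3., 4., 5.])
out_grad_values = np.array([0.1, 0.2, 0.3, 0.4, 0.5])

# ReLU'(x) is 1 where x > 0 and 0 elsewhere, applied only to the stored values;
# the indices of x_grad are a copy of the indices of x.
x_grad_values = out_grad_values * (x_values > 0)
print(x_grad_values)  # [0.  0.2 0.  0.4 0.5]
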
29 changes: 29 additions & 0 deletions paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h
@@ -0,0 +1,29 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/sparse_coo_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SparseReluGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
SparseCooTensor* x_grad);

} // namespace sparse
} // namespace phi
66 changes: 66 additions & 0 deletions paddle/phi/kernels/sparse/sparse_activation_kernel.cc
@@ -0,0 +1,66 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_activation_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SparseReluKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
DenseTensor non_zero_indices =
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_indices());
DenseTensor non_zero_elements =
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());
phi::Copy(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
false,
&non_zero_indices);
phi::ReluKernel<T, Context>(
dev_ctx, x.non_zero_elements(), &non_zero_elements);
out->SetMember(non_zero_indices, non_zero_elements, x.dims(), true);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(sparse_relu,
CPU,
ALL_LAYOUT,
phi::sparse::SparseReluKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(sparse_relu,
GPU,
ALL_LAYOUT,
phi::sparse::SparseReluKernel,
float,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#endif
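
The forward kernel above follows the same pattern: the indices and dense shape are carried over unchanged, and the dense ReLU kernel runs on the non-zero values only. A minimal NumPy sketch using the same 3x4 example as the unit tests below (illustrative only):

import numpy as np

dense_x = np.array([[0., -1., 0., 2.],
                    [0., 0., -3., 0.],
                    [4., 5., 0., 0.]])

# COO form with sparse_dim = 2: coordinates of the non-zero entries plus their values.
indices = np.argwhere(dense_x != 0).T   # shape (2, nnz)
values = dense_x[dense_x != 0]          # [-1.  2. -3.  4.  5.]

# Sparse ReLU: only the values change; indices and the dense shape stay the same.
out_values = np.maximum(values, 0.)     # [0. 2. 0. 4. 5.]
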
39 changes: 39 additions & 0 deletions paddle/phi/kernels/sparse/sparse_activation_kernel.h
@@ -0,0 +1,39 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SparseReluKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out);

template <typename T, typename Context>
SparseCooTensor SparseRelu(const Context& dev_ctx, const SparseCooTensor& x) {
DenseTensor indices, values;
SparseCooTensor coo(indices, values, x.dims());
SparseReluKernel<T, Context>(dev_ctx, x, &coo);
return coo;
}

} // namespace sparse
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/tests/kernels/CMakeLists.txt
@@ -15,6 +15,7 @@ cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils)
cc_test(test_sparse_pool_dev_api SRCS test_sparse_pool_dev_api.cc DEPS phi phi_api_utils)
+cc_test(test_sparse_activation_dev_api SRCS test_sparse_activation_dev_api.cc DEPS phi phi_api_utils)

cc_test(test_math_function SRCS test_math_function.cc DEPS math_function)
if(WITH_GPU)
83 changes: 83 additions & 0 deletions paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
@@ -0,0 +1,83 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <memory>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_activation_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

namespace phi {
namespace tests {

TEST(DEV_API, sparse_relu) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.Init();

DenseTensor dense_x =
phi::Empty(dev_ctx_cpu,
DenseTensorMeta(DataType::FLOAT32, {3, 4}, DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_coo = sparse::DenseToSparseCoo<float>(dev_ctx_cpu, dense_x, 2);

auto sparse_out = sparse::SparseRelu<float>(dev_ctx_cpu, sparse_coo);
DenseTensor dense_out =
phi::EmptyLike<float>(dev_ctx_cpu, sparse_out.non_zero_elements());
ReluKernel<float>(dev_ctx_cpu, sparse_coo.non_zero_elements(), &dense_out);

int cmp = memcmp(dense_out.data<float>(),
sparse_out.non_zero_elements().data<float>(),
dense_out.numel() * sizeof(float));
ASSERT_EQ(cmp, 0);
// backward
DenseTensor dense_grad_x = phi::EmptyLike<float>(dev_ctx_cpu, dense_out);
ReluGradKernel<float>(
dev_ctx_cpu, sparse_coo.non_zero_elements(), dense_out, &dense_grad_x);
SparseCooTensor sparse_grad_x(
phi::EmptyLike<int>(dev_ctx_cpu, sparse_coo.non_zero_indices()),
phi::EmptyLike<int>(dev_ctx_cpu, sparse_coo.non_zero_elements()),
{3, 4});

SparseCooTensor sparse_out_grad(
sparse_coo.non_zero_indices(), dense_out, {3, 4});
sparse::SparseReluGradKernel<float>(
dev_ctx_cpu, sparse_coo, sparse_out_grad, &sparse_grad_x);

cmp = memcmp(dense_grad_x.data<float>(),
sparse_grad_x.non_zero_elements().data<float>(),
dense_grad_x.numel() * sizeof(float));
ASSERT_EQ(cmp, 0);
}

} // namespace tests
} // namespace phi
40 changes: 40 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_activation_op.py
@@ -0,0 +1,40 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard


class TestSparseActivation(unittest.TestCase):
    def test_sparse_relu(self):
        with _test_eager_guard():
            x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
            dense_x = paddle.to_tensor(x, dtype='float32')
            dense_shape = [3, 4]
            stop_gradient = True
            sparse_dim = 2
            sparse_coo_x = dense_x.to_sparse_coo(sparse_dim)
            # TODO(zhangkaihuo): change to test the corresponding API: paddle.sparse.relu(sparse_coo_x)
            sparse_act_out = _C_ops.final_state_sparse_relu(sparse_coo_x)
            correct_result = [0, 2, 0, 4, 5]
            actual_result = sparse_act_out.non_zero_elements().numpy()
            assert np.array_equal(correct_result, actual_result)


if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion python/paddle/tensor/to_string.py
@@ -317,7 +317,7 @@ def tensor_to_string(tensor, prefix='Tensor'):

    _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"

-    if not tensor._is_dense_tensor_hold_allocation():
+    if not tensor._is_initialized():
        return "Tensor(Not initialized)"

    if tensor.is_sparse():
8 changes: 8 additions & 0 deletions python/paddle/utils/code_gen/sparse_api.yaml
@@ -21,3 +21,11 @@
  args : (Tensor x)
  output : Tensor(out@SparseCsrTensor)
  invoke : to_sparse_csr_impl(x)
+
+- api : relu
+  args : (Tensor x)
+  output : Tensor(out@SparseCooTensor)
+  kernel :
+    func : sparse_relu
+    layout : x
+  backward : sparse_relu_grad
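
As exercised in the Python unit test above, this generated sparse API is currently reached in eager mode through the C-ops binding; the public paddle.sparse.relu wrapper is still a TODO in that test. A minimal usage sketch mirroring the test:

import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    dense_x = paddle.to_tensor(
        [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], dtype='float32')
    sparse_x = dense_x.to_sparse_coo(2)
    out = _C_ops.final_state_sparse_relu(sparse_x)
    # out.non_zero_elements() -> [0., 2., 0., 4., 5.]
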
7 changes: 7 additions & 0 deletions python/paddle/utils/code_gen/sparse_bw_api.yaml
@@ -4,3 +4,10 @@
  output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
  kernel :
    func : sparse_conv3d_grad
+
+- backward_api : sparse_relu_grad
+  forward : sparse_relu(Tensor x) -> Tensor(out@SparseCooTensor)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad@SparseCooTensor)
+  kernel :
+    func : sparse_relu_grad
