Skip to content

Commit

Permalink
[Phi] Move determinant op kernel into phi (PaddlePaddle#40539)
Browse files Browse the repository at this point in the history
* add determinant phi kernel

* remove original determinant op kernel

* add determinant grad [hi kernel

* fix determinant test failed

* remove original determinant grad op kernel
  • Loading branch information
chenwhql authored and liqitong-a committed Mar 17, 2022
1 parent a74d1ef commit 8d0d6c1
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 246 deletions.
8 changes: 0 additions & 8 deletions paddle/fluid/operators/determinant_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,6 @@ REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,

REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)

REGISTER_OP_CPU_KERNEL(determinant,
ops::DeterminantKernel<plat::CPUDeviceContext, float>,
ops::DeterminantKernel<plat::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
determinant_grad, ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);

REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantOpMaker,
ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,
Expand Down
8 changes: 0 additions & 8 deletions paddle/fluid/operators/determinant_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ limitations under the License. */

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
ops::DeterminantKernel<plat::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
determinant_grad,
ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
Expand Down
237 changes: 8 additions & 229 deletions paddle/fluid/operators/determinant_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
Expand All @@ -40,232 +43,6 @@ T sign(T val) {
return static_cast<T>(T(0) < val) - (val < T(0));
}

template <typename T>
class EigenMatrix {};

template <>
class EigenMatrix<float> {
public:
using MatrixType = Eigen::MatrixXf;
};

template <>
class EigenMatrix<double> {
public:
using MatrixType = Eigen::MatrixXd;
};

inline int64_t GetBatchCount(const framework::DDim dims) {
int64_t batch_count = 1;
auto dim_size = dims.size();
PADDLE_ENFORCE_GE(
dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));

// Cumulative multiplying each dimension until the last 2 to get the batch
// count,
// for example a tensor with shape [3,3,3,3], the batch count of matrices is
// 9.
for (int64_t i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}

return batch_count;
}

template <typename T>
struct DeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
int64_t rank, int64_t batch_count, Tensor* output) {
std::vector<T> input_vec;
std::vector<T> output_vec;
framework::TensorToVector(input, ctx.device_context(), &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
}
}
output_vec.push_back(matrix.determinant());
}
framework::TensorFromVector(output_vec, output);
}
};
template <typename DeviceContext, typename T>
class DeterminantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto input_dim = vectorize(input->dims());
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");

auto batch_count = GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE(
input_dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
input_dim[input_dim_size - 2],
platform::errors::InvalidArgument(
"the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
auto output_dims = phi::slice_ddim(input->dims(), 0, input_dim_size - 2);
if (input_dim_size > 2) {
output->Resize(output_dims);
} else {
// when input is a two-dimension matrix, The det value is a number.
output->Resize({1});
}
VLOG(2) << "output dim:" << output->dims();
}
};

template <typename T>
struct FoundZeroFunctor {
FoundZeroFunctor(const T* x, int64_t numel, bool* res)
: x_(x), numel_(numel), res_(res) {}
HOSTDEVICE void operator()(size_t idx) const {
if (*res_ || idx >= static_cast<size_t>(numel_)) {
// founded zero number
return;
}
*res_ = (x_[idx] == static_cast<T>(0));
}
const T* x_;
int64_t numel_;
bool* res_;
};

template <typename DeviceContext, typename T>
inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
const framework::Tensor* det) {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto numel = det->numel();

framework::Tensor dev_tensor;
auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());

// set false
phi::funcs::SetConstant<DeviceContext, bool> zero;
zero(dev_ctx, &dev_tensor, false);

// find whether zero
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
for_range(functor);

// copy to host
dev_ctx.Wait();
framework::Tensor cpu_tensor;
framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);

// if founded zero, the matrix is not invertible
// else the matrix is invertible
auto* res = cpu_tensor.data<bool>();
return !(*res);
}

template <typename DeviceContext, typename T>
class DeterminantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& orig_dev_ctx = context.template device_context<DeviceContext>();
const auto* input = context.Input<framework::Tensor>("Input");
const auto* det = context.Input<framework::Tensor>("Out");
const auto* grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ddet =
context.Output<framework::Tensor>(framework::GradVarName("Input"));

auto input_dims_size = input->dims().size();
if (input_dims_size > 2) {
PADDLE_ENFORCE_EQ(
grad->dims().size() + 2, input_dims_size,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
PADDLE_ENFORCE_EQ(
grad->dims().size(), 1,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else {
// checked in forward, pass
}

auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);

// Check Whether the matrix is invertible
// (matrix A not invertible) == (det(A)=0)
if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
ddet->Resize(input->dims());
phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
ddet);
return;
}

// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
// -1)

// First: inverse(A)
framework::Tensor inverse_A;
// A must be square matrices!
inverse_A.Resize(input->dims());
inverse_A.mutable_data<T>(context.GetPlace());

phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(orig_dev_ctx, *input, &inverse_A);

VLOG(3) << "inverse(A) dims: " << inverse_A.dims();

// Second: inverse(A).transpose(-2, -1)
framework::Tensor transpose_inverse_A =
phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);

VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();

// Third: dA * |A|
auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();

// Fourth: unsqueeze(dA * |A|, [-1, -2])
auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();

// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);

VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();

framework::TensorCopy(res, context.GetPlace(), ddet);

ddet->Resize(input->dims());
VLOG(3) << "d|A| dims: " << ddet->dims();
}
};

template <typename T>
struct SlogDeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
Expand All @@ -280,7 +57,7 @@ struct SlogDeterminantFunctor {
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
typename phi::detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
Expand Down Expand Up @@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");

auto batch_count = GetBatchCount(input->dims());
auto batch_count = phi::detail::GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE(
input_dim_size, 2,
Expand Down Expand Up @@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
// (matrix A not invertible) == (absslogdet(A)=0)
auto slogdet_vec = slogdet->Split(1, 0);
auto absslogdet_val = slogdet_vec[0];
if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
if (!phi::detail::CheckMatrixInvertible<
T, typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
dev_ctx, &absslogdet_val)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
dslogdet->Resize(input->dims());
Expand Down
7 changes: 6 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel)
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
Expand All @@ -46,6 +50,7 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)

# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
Expand Down
25 changes: 25 additions & 0 deletions paddle/phi/kernels/cpu/determinant_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/determinant_grad_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"

PD_REGISTER_KERNEL(determinant_grad,
CPU,
ALL_LAYOUT,
phi::DeterminantGradKernel,
float,
double) {}
21 changes: 21 additions & 0 deletions paddle/phi/kernels/cpu/determinant_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/determinant_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"

PD_REGISTER_KERNEL(
determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
Loading

0 comments on commit 8d0d6c1

Please sign in to comment.