TensorRT dot product attention ops (#949)
* add detr support

* fix softmax

* add placeholder

* add implement

* add docs and ut

* update testcase

* update docs

* update docs
q.yao authored Sep 5, 2022
1 parent e21cad8 commit 9541be9
Showing 8 changed files with 525 additions and 34 deletions.
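For reference, the operation implemented by the new plugin can be summarized from the kernel below: with query of shape (B, Nt, E), key and value of shape (B, Ns, E), and an optional additive mask broadcastable to (B, Nt, Ns), the two plugin outputs are

\[
\mathrm{attn} = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{E}} + \mathrm{mask}\right), \qquad
\mathrm{output} = \mathrm{attn}\, V .
\]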
@@ -0,0 +1,183 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "scaled_dot_product_attention.hpp"

#include <assert.h>

#include <chrono>

#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"};
} // namespace

ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name)
    : TRTPluginBase(name), mask_dim(0) {}

ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data,
                                                           size_t length)
    : TRTPluginBase(name), mask_dim(0) {}

ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {}

nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  if (outputIndex == 0) return inputs[0];
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = inputs[1].d[1];

  return ret;
}

bool ScaledDotProductAttentionTRT::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  if (pos == 0) {
    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
  } else {
    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
  }
}

// Attach the plugin object to an execution context and grant the plugin
// access to the context's cuDNN/cuBLAS resources.
void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext,
                                                   cublasContext *cublasContext,
                                                   IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
  _cublas_handle = cublasContext;
  _cudnn_handle = cudnnContext;
  cudnnCreateTensorDescriptor(&_x_desc);
  cudnnCreateTensorDescriptor(&_y_desc);
  cudnnCreateTensorDescriptor(&_mask_desc);
}

// Detach the plugin object from its execution context.
void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT {
  cudnnDestroyTensorDescriptor(_y_desc);
  cudnnDestroyTensorDescriptor(_x_desc);
  cudnnDestroyTensorDescriptor(_mask_desc);
}

void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                                                   int nbInputs,
                                                   const nvinfer1::DynamicPluginTensorDesc *out,
                                                   int nbOutputs) TRT_NOEXCEPT {
  if (nbInputs != 4) {
    mask_dim = 0;
  } else {
    mask_dim = in[3].desc.dims.nbDims;
  }
}

int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                          const nvinfer1::PluginTensorDesc *outputDesc,
                                          const void *const *inputs, void *const *outputs,
                                          void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1;
  if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1;
  int B = inputDesc[0].dims.d[0];  // batch * heads
  int Nt = inputDesc[0].dims.d[1];
  int Ns = inputDesc[1].dims.d[1];
  int E = inputDesc[0].dims.d[2];  // embedding size

  const void *query = inputs[0];
  const void *key = inputs[1];
  const void *value = inputs[2];
  const void *mask = nullptr;

  int mask_dims[3];
  mask_dims[0] = 0;
  if (mask_dim > 0) {
    mask = inputs[3];
    // check whether the mask needs broadcasting over the batch dimension
    if (mask_dim == 2) {
      mask_dims[0] = 1;
      mask_dims[1] = inputDesc[3].dims.d[0];
      mask_dims[2] = inputDesc[3].dims.d[1];
    } else {
      mask_dims[0] = inputDesc[3].dims.d[0];
      mask_dims[1] = inputDesc[3].dims.d[1];
      mask_dims[2] = inputDesc[3].dims.d[2];
    }
  }

  void *output = outputs[0];
  void *attn = outputs[1];

  auto data_type = inputDesc[0].type;
  cudnnDataType_t cudnn_dtype{};
  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      dot_product_attention_impl<float>((float *)query, (float *)key, (float *)value, (float *)mask,
                                        (float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0],
                                        _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream,
                                        _cublas_handle, _cudnn_handle);
      break;
    default:
      return 1;
  }

  return 0;
}

nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; }

size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {}

const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator);
} // namespace mmdeploy
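Since the plugin registers itself via REGISTER_TENSORRT_PLUGIN, it can be instantiated from the global TensorRT plugin registry once the mmdeploy plugin library is loaded. A minimal sketch of doing so by plugin name and version; the layer name "sdpa" and the helper function are illustrative, and the creator consumes no plugin fields:

// Sketch: look up the creator by the PLUGIN_NAME / PLUGIN_VERSION defined above
// and build a plugin instance for insertion into a TensorRT network.
#include <NvInferRuntime.h>

nvinfer1::IPluginV2 *create_sdpa_plugin() {
  auto *creator = getPluginRegistry()->getPluginCreator("ScaledDotProductAttentionTRT", "1");
  if (creator == nullptr) return nullptr;
  nvinfer1::PluginFieldCollection fc{};  // empty: createPlugin above ignores the field collection
  return creator->createPlugin("sdpa", &fc);
}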
@@ -0,0 +1,73 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class ScaledDotProductAttentionTRT : public TRTPluginBase {
 public:
  ScaledDotProductAttentionTRT(const std::string &name);

  ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length);

  ScaledDotProductAttentionTRT() = delete;

  ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override;

  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                               const nvinfer1::DynamicPluginTensorDesc *out,
                               int nbOutputs) TRT_NOEXCEPT override;
  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
  void attachToContext(cudnnContext *cudnn, cublasContext *cublas,
                       nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;

 private:
  int mask_dim;
  cublasHandle_t _cublas_handle{};
  cudnnHandle_t _cudnn_handle{};
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{};
};

class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase {
 public:
  ScaledDotProductAttentionTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
@@ -0,0 +1,103 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <cmath>
#include <vector>

#include "common_cuda_helper.hpp"
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa,
                                            cublasOperation_t transb, int m, int n, int k,
                                            const scalar_t* alpha, const scalar_t* A, int lda,
                                            long long int strideA, const scalar_t* B, int ldb,
                                            long long int strideB, const scalar_t* beta,
                                            scalar_t* C, int ldc, long long int strideC,
                                            int batchCount);

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
                                                   cublasOperation_t transb, int m, int n, int k,
                                                   const float* alpha, const float* A, int lda,
                                                   long long int strideA, const float* B, int ldb,
                                                   long long int strideB, const float* beta,
                                                   float* C, int ldc, long long int strideC,
                                                   int batchCount) {
  return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa,
                                                    cublasOperation_t transb, int m, int n, int k,
                                                    const __half* alpha, const __half* A, int lda,
                                                    long long int strideA, const __half* B, int ldb,
                                                    long long int strideB, const __half* beta,
                                                    __half* C, int ldc, long long int strideC,
                                                    int batchCount) {
  return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle) {
  {
    // attn = Q @ K^T / sqrt(E)
    const int m = Ns;
    const int n = Nt;
    const int k = E;
    const auto alpha = scalar_t(1.0f / sqrt(float(E)));
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k,
                                 Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B);
  }

  // optional additive mask, broadcast over the batch dimension when mask_dims[0] == 1
  if (mask_dims != nullptr && mask_dims[0] != 0) {
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(1);
    cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0],
                               mask_dims[1], mask_dims[2]);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns);
    cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn);
  }

  {
    // softmax over the last dimension (Ns) of attn
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
                        x_desc, attn, &beta, y_desc, attn);
  }

  {
    // output = attn @ V
    const int m = E;
    const int n = Nt;
    const int k = Ns;
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m,
                                 Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m,
                                 Nt * E, B);
  }
}

template void dot_product_attention_impl<float>(
    const float* query, const float* key, const float* value, const float* mask, float* attn,
    float* output, int B, int Nt, int Ns, int E, const int* mask_dims,
    cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
    cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream,
    cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle);
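Because cuBLAS is column-major while the buffers are row-major, the (m, n, k) and transpose choices above effectively compute attn = Q Kᵀ / sqrt(E) and output = attn V for each of the B batched matrices. A minimal CPU sketch of the same step sequence, useful for checking the plugin on small shapes; the function name and the mask_broadcast flag are illustrative and not part of the op library:

// Minimal CPU reference for the same computation (illustrative only).
// query: (B, Nt, E), key/value: (B, Ns, E), mask: (B or 1, Nt, Ns) or nullptr,
// attn: (B, Nt, Ns), output: (B, Nt, E); all buffers contiguous row-major.
#include <algorithm>
#include <cmath>

void dot_product_attention_ref(const float* query, const float* key, const float* value,
                               const float* mask, bool mask_broadcast, float* attn, float* output,
                               int B, int Nt, int Ns, int E) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(E));
  for (int b = 0; b < B; ++b) {
    for (int t = 0; t < Nt; ++t) {
      float* row = attn + (b * Nt + t) * Ns;
      // scaled Q @ K^T, plus the optional additive mask
      for (int s = 0; s < Ns; ++s) {
        float acc = 0.f;
        for (int e = 0; e < E; ++e) acc += query[(b * Nt + t) * E + e] * key[(b * Ns + s) * E + e];
        row[s] = acc * scale;
        if (mask != nullptr) row[s] += mask[((mask_broadcast ? 0 : b) * Nt + t) * Ns + s];
      }
      // numerically stable softmax over the Ns dimension
      float max_val = *std::max_element(row, row + Ns);
      float sum = 0.f;
      for (int s = 0; s < Ns; ++s) {
        row[s] = std::exp(row[s] - max_val);
        sum += row[s];
      }
      for (int s = 0; s < Ns; ++s) row[s] /= sum;
      // output = attn @ V
      for (int e = 0; e < E; ++e) {
        float acc = 0.f;
        for (int s = 0; s < Ns; ++s) acc += row[s] * value[(b * Ns + s) * E + e];
        output[(b * Nt + t) * E + e] = acc;
      }
    }
  }
}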
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle);

#endif