TensorRT dot product attention ops (#949)
* add detr support
* fix softmax
* add placeholder
* add implement
* add docs and ut
* update testcase
* update docs
* update docs
q.yao authored on Sep 5, 2022
1 parent e21cad8 · commit 9541be9
Showing 8 changed files with 525 additions and 34 deletions.
...deploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp (183 additions, 0 deletions)
// Copyright (c) OpenMMLab. All rights reserved
#include "scaled_dot_product_attention.hpp"

#include <assert.h>

#include <chrono>

#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"};
}  // namespace

ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name)
    : TRTPluginBase(name), mask_dim(0) {}

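// Nothing is serialized for this plugin (getSerializationSize() returns 0), so the
// payload is deliberately ignored here; mask_dim is re-derived in configurePlugin().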
ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data,
                                                           size_t length)
    : TRTPluginBase(name), mask_dim(0) {}

ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {}

nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  if (outputIndex == 0) return inputs[0];
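  // Output 0 matches the query shape (batch * heads, Nt, E); output 1 built
  // below is the attention weight matrix of shape (batch * heads, Nt, Ns).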
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = inputs[1].d[1];

  return ret;
}

bool ScaledDotProductAttentionTRT::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  if (pos == 0) {
    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
  } else {
    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
  }
}

// Attach the plugin object to an execution context and grant the plugin
// access to some context resources.
void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext,
                                                   cublasContext *cublasContext,
                                                   IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
  _cublas_handle = cublasContext;
  _cudnn_handle = cudnnContext;
  cudnnCreateTensorDescriptor(&_x_desc);
  cudnnCreateTensorDescriptor(&_y_desc);
  cudnnCreateTensorDescriptor(&_mask_desc);
}

// Detach the plugin object from its execution context.
void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT {
  cudnnDestroyTensorDescriptor(_y_desc);
  cudnnDestroyTensorDescriptor(_x_desc);
  cudnnDestroyTensorDescriptor(_mask_desc);
}

void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                                                   int nbInputs,
                                                   const nvinfer1::DynamicPluginTensorDesc *out,
                                                   int nbOutputs) TRT_NOEXCEPT {
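  // A fourth input, when present, is the additive attention mask; record its
  // rank so enqueue() knows whether the mask must be broadcast over the batch.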
  if (nbInputs != 4) {
    mask_dim = 0;
  } else {
    mask_dim = in[3].desc.dims.nbDims;
  }
}

int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                          const nvinfer1::PluginTensorDesc *outputDesc,
                                          const void *const *inputs, void *const *outputs,
                                          void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1;
  if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1;
  int B = inputDesc[0].dims.d[0];  // batch * heads
  int Nt = inputDesc[0].dims.d[1];
  int Ns = inputDesc[1].dims.d[1];
  int E = inputDesc[0].dims.d[2];  // embedding size

  const void *query = inputs[0];
  const void *key = inputs[1];
  const void *value = inputs[2];
  const void *mask = nullptr;

  int mask_dims[3];
  mask_dims[0] = 0;
  if (mask_dim > 0) {
    mask = inputs[3];
    // check if the mask needs broadcast
    if (mask_dim == 2) {
      mask_dims[0] = 1;
      mask_dims[1] = inputDesc[3].dims.d[0];
      mask_dims[2] = inputDesc[3].dims.d[1];
    } else {
      mask_dims[0] = inputDesc[3].dims.d[0];
      mask_dims[1] = inputDesc[3].dims.d[1];
      mask_dims[2] = inputDesc[3].dims.d[2];
    }
  }

  void *output = outputs[0];
  void *attn = outputs[1];

  auto data_type = inputDesc[0].type;
  cudnnDataType_t cudnn_dtype{};
  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      dot_product_attention_impl<float>((float *)query, (float *)key, (float *)value, (float *)mask,
                                        (float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0],
                                        _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream,
                                        _cublas_handle, _cudnn_handle);
      break;
    default:
      return 1;
  }

  return 0;
}

nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; }

size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {}

const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator);
}  // namespace mmdeploy
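
For reference, the plugin computes output = softmax(Q @ K^T / sqrt(E) + mask) @ V and additionally returns the softmax weights as a second output. The sketch below is a plain CPU restatement of that math, not part of this commit: sdpa_reference is an illustrative name, and the mask (if any) is assumed pre-broadcast to (B, Nt, Ns).

#include <algorithm>
#include <cmath>

// CPU reference for row-major tensors:
//   q: (B, Nt, E), k/v: (B, Ns, E), mask: (B, Nt, Ns) or nullptr,
//   attn: (B, Nt, Ns), out: (B, Nt, E).
void sdpa_reference(const float* q, const float* k, const float* v, const float* mask, float* attn,
                    float* out, int B, int Nt, int Ns, int E) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(E));
  for (int b = 0; b < B; ++b) {
    for (int t = 0; t < Nt; ++t) {
      float* row = attn + (b * Nt + t) * Ns;
      // scaled logits, plus the optional additive mask
      for (int s = 0; s < Ns; ++s) {
        float acc = 0.f;
        for (int e = 0; e < E; ++e) acc += q[(b * Nt + t) * E + e] * k[(b * Ns + s) * E + e];
        row[s] = acc * scale + (mask ? mask[(b * Nt + t) * Ns + s] : 0.f);
      }
      // numerically stable row-wise softmax
      const float mx = *std::max_element(row, row + Ns);
      float sum = 0.f;
      for (int s = 0; s < Ns; ++s) sum += (row[s] = std::exp(row[s] - mx));
      for (int s = 0; s < Ns; ++s) row[s] /= sum;
      // weighted sum over the values
      for (int e = 0; e < E; ++e) {
        float acc = 0.f;
        for (int s = 0; s < Ns; ++s) acc += row[s] * v[(b * Ns + s) * E + e];
        out[(b * Nt + t) * E + e] = acc;
      }
    }
  }
}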
...deploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp (73 additions, 0 deletions)
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class ScaledDotProductAttentionTRT : public TRTPluginBase {
 public:
  ScaledDotProductAttentionTRT(const std::string &name);

  ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length);

  ScaledDotProductAttentionTRT() = delete;

  ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override;

  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                               const nvinfer1::DynamicPluginTensorDesc *out,
                               int nbOutputs) TRT_NOEXCEPT override;
  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
  void attachToContext(cudnnContext *cudnn, cublasContext *cublas,
                       nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;

 private:
  int mask_dim;
  cublasHandle_t _cublas_handle{};
  cudnnHandle_t _cudnn_handle{};
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{};
};

class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase {
 public:
  ScaledDotProductAttentionTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
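
Since the creator is registered via REGISTER_TENSORRT_PLUGIN, the plugin is discoverable by name and version through TensorRT's global registry. A minimal sketch of resolving it at network-build time, assuming the default plugin namespace; create_sdpa_plugin is an illustrative helper, not part of this commit:

#include <NvInferRuntime.h>

nvinfer1::IPluginV2 *create_sdpa_plugin() {
  auto *creator = getPluginRegistry()->getPluginCreator("ScaledDotProductAttentionTRT", "1");
  if (creator == nullptr) return nullptr;
  nvinfer1::PluginFieldCollection fc{};  // createPlugin() consumes no fields
  return creator->createPlugin("sdpa", &fc);
}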
.../backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu (103 additions, 0 deletions)
// Copyright (c) OpenMMLab. All rights reserved
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <cmath>
#include <vector>

#include "common_cuda_helper.hpp"
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa,
                                            cublasOperation_t transb, int m, int n, int k,
                                            const scalar_t* alpha, const scalar_t* A, int lda,
                                            long long int strideA, const scalar_t* B, int ldb,
                                            long long int strideB, const scalar_t* beta,
                                            scalar_t* C, int ldc, long long int strideC,
                                            int batchCount);

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
                                                   cublasOperation_t transb, int m, int n, int k,
                                                   const float* alpha, const float* A, int lda,
                                                   long long int strideA, const float* B, int ldb,
                                                   long long int strideB, const float* beta,
                                                   float* C, int ldc, long long int strideC,
                                                   int batchCount) {
  return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa,
                                                    cublasOperation_t transb, int m, int n, int k,
                                                    const __half* alpha, const __half* A, int lda,
                                                    long long int strideA, const __half* B, int ldb,
                                                    long long int strideB, const __half* beta,
                                                    __half* C, int ldc, long long int strideC,
                                                    int batchCount) {
  return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle) {
  {
    // Q @ K
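    // cuBLAS is column-major: the row-major (Nt, E) query and (Ns, E) key are
    // read as column-major (E, Nt) and (E, Ns) matrices, so CUBLAS_OP_T on the
    // key yields a column-major (Ns, Nt) product, i.e. attn = alpha * Q @ K^T
    // stored row-major as (Nt, Ns). alpha folds in the 1 / sqrt(E) scaling.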
    const int m = Ns;
    const int n = Nt;
    const int k = E;
    const auto alpha = scalar_t(1.0f / sqrt(float(E)));
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k,
                                 Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B);
  }

  if (mask_dims != nullptr && mask_dims[0] != 0) {
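    // Add the mask onto the logits: the mask is bound as a
    // (1, mask_dims[0], mask_dims[1], mask_dims[2]) tensor, and cudnnAddTensor
    // broadcasts it over the (1, B, Nt, Ns) view of attn when its batch dim is
    // 1; beta = 1 accumulates into attn instead of overwriting it.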
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(1);
    cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0],
                               mask_dims[1], mask_dims[2]);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns);
    cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn);
  }

  {
    // softmax attention
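    // Viewing attn as (B * Nt, Ns, 1, 1) lets CUDNN_SOFTMAX_MODE_INSTANCE
    // normalize each length-Ns row independently; x and y alias the same
    // buffer, so the softmax runs in place.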
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
                        x_desc, attn, &beta, y_desc, attn);
  }

  {
    // attn @ v
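    // Same column-major trick as above: the row-major (Ns, E) value is read as
    // column-major (E, Ns), so a plain N/N product gives output = attn @ V,
    // stored row-major as (Nt, E).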
    const int m = E;
    const int n = Nt;
    const int k = Ns;
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m,
                                 Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m,
                                 Nt * E, B);
  }
}

template void dot_product_attention_impl<float>(
    const float* query, const float* key, const float* value, const float* mask, float* attn,
    float* output, int B, int Nt, int Ns, int E, const int* mask_dims,
    cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
    cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream,
    cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle);
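
Note that only the float instantiation of dot_product_attention_impl is emitted, even though a __half wrapper around cublasHgemmStridedBatched is already defined above; this matches the kFLOAT-only paths in supportsFormatCombination and enqueue, leaving room for a later fp16 path.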
...backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp (17 additions, 0 deletions)
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle);

#endif
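
Finally, a hedged usage sketch of the kernel entry point itself, assuming caller-owned handles, device buffers, and the default CUDA stream; it is not part of this commit, and error checking is omitted for brevity:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>

#include "scaled_dot_product_attention_kernel.hpp"

int main() {
  const int B = 8, Nt = 4, Ns = 6, E = 32;  // batch * heads, target len, source len, embed dim
  float *q, *k, *v, *attn, *out;
  cudaMalloc(&q, sizeof(float) * B * Nt * E);
  cudaMalloc(&k, sizeof(float) * B * Ns * E);
  cudaMalloc(&v, sizeof(float) * B * Ns * E);
  cudaMalloc(&attn, sizeof(float) * B * Nt * Ns);
  cudaMalloc(&out, sizeof(float) * B * Nt * E);
  // ... fill q/k/v with real data; they are uninitialized here ...

  cublasHandle_t cublas;
  cublasCreate(&cublas);
  cudnnHandle_t cudnn;
  cudnnCreate(&cudnn);
  cudnnTensorDescriptor_t x_desc, y_desc, mask_desc;
  cudnnCreateTensorDescriptor(&x_desc);
  cudnnCreateTensorDescriptor(&y_desc);
  cudnnCreateTensorDescriptor(&mask_desc);

  int mask_dims[3] = {0, 0, 0};  // mask_dims[0] == 0 means "no mask", as in enqueue()
  dot_product_attention_impl<float>(q, k, v, /*mask=*/nullptr, attn, out, B, Nt, Ns, E, mask_dims,
                                    x_desc, y_desc, mask_desc, CUDNN_DATA_FLOAT,
                                    /*stream=*/0, cublas, cudnn);
  cudaStreamSynchronize(0);

  cudnnDestroyTensorDescriptor(mask_desc);
  cudnnDestroyTensorDescriptor(y_desc);
  cudnnDestroyTensorDescriptor(x_desc);
  cudnnDestroy(cudnn);
  cublasDestroy(cublas);
  cudaFree(out);
  cudaFree(attn);
  cudaFree(v);
  cudaFree(k);
  cudaFree(q);
  return 0;
}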