[DML EP] Add SimplifiedLayerNorm and SkipSimplifiedLayerNorm (microso…
PatriceVignola authored Apr 19, 2024
1 parent a747a00 commit b8c90be
Showing 8 changed files with 46 additions and 21 deletions.
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -1058,7 +1058,7 @@ Do not modify directly.*
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float)|
-|||1+|**T** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
+|||1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Less|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
@@ -1224,6 +1224,7 @@ Do not modify directly.*
|||6+|**T** = tensor(float), tensor(float16)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
@@ -1306,6 +1307,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||
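Note: the simplified variants documented above drop the mean subtraction of standard layer normalization and normalize by the root-mean-square of the input (RMSNorm-style). A rough sketch of the intended math, with H the size of the normalized axis, gamma the scale input, and epsilon the op attribute (the contrib-op spec, not this note, is authoritative):

```latex
y = \frac{x}{\sqrt{\tfrac{1}{H}\sum_{i=1}^{H} x_i^{2} + \epsilon}} \odot \gamma
```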
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
@@ -9,7 +9,7 @@ namespace Dml
class DmlOperatorLayerNormalization : public DmlOperator
{
public:
-DmlOperatorLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
+DmlOperatorLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext, bool simplified)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
@@ -128,17 +128,18 @@ class DmlOperatorLayerNormalization : public DmlOperator
outputCastOpDesc.Desc = &outputCastDesc;
}

-DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
+DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputCastOpDesc.Desc ? &inputCastOutputDmlTensorDesc : &inputDesc;
operatorDesc.ScaleTensor = scaleCastOpDesc.Desc ? &scaleCastOutputDmlTensorDesc : &scaleDesc;
operatorDesc.BiasTensor = biasCastOpDesc.Desc ? &biasCastOutputDmlTensorDesc : (biasDesc.Desc ? &biasDesc : nullptr);
operatorDesc.OutputTensor = outputCastOpDesc.Desc ? &outputCastOutputDmlTensorDesc : &outputDesc;
operatorDesc.Axes = onnxAxes.data();
operatorDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
-operatorDesc.NormalizeVariance = true;
+operatorDesc.UseMean = !simplified;
+operatorDesc.UseVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = nullptr;
-DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
+DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &operatorDesc };

// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
@@ -258,7 +259,19 @@ void CALLBACK QueryLayerNormalization(IMLOperatorSupportQueryContextPrivate* con
*isSupported = context->GetOutputCount() == 1;
}

-DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, DmlOperatorLayerNormalization);
-DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization17, DmlOperatorLayerNormalization);
+// A specific type of operation for registration.
+template <bool simplified>
+class LayerNormalizationTemplate : public DmlOperatorLayerNormalization
+{
+public:
+    LayerNormalizationTemplate(const MLOperatorKernelCreationContext& kernelCreationContext)
+        : DmlOperatorLayerNormalization(kernelCreationContext, simplified)
+    {
+    }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization, LayerNormalizationTemplate<false>);
+DML_OP_DEFINE_CREATION_FUNCTION(LayerNormalization17, LayerNormalizationTemplate<false>);
+DML_OP_DEFINE_CREATION_FUNCTION(SimplifiedLayerNormalization, LayerNormalizationTemplate<true>);

} // namespace Dml
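To make the new `simplified` flag concrete: the graph now emits DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2 with `UseMean = !simplified`, so the simplified path skips mean subtraction and only divides by the epsilon-adjusted root mean square. Below is a minimal host-side C++ sketch of the math the two paths target over a single normalized axis; the helper name and signature are illustrative only and are not part of DirectML or this commit.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper, not ONNX Runtime or DirectML code.
// simplified == true skips the mean subtraction, mirroring UseMean = !simplified.
std::vector<float> LayerNormReference(const std::vector<float>& x,
                                      const std::vector<float>& scale,
                                      const std::vector<float>* bias,  // pass nullptr when absent
                                      float epsilon,
                                      bool simplified)
{
    const size_t n = x.size();

    float mean = 0.0f;
    if (!simplified)
    {
        for (float v : x) mean += v;
        mean /= static_cast<float>(n);
    }

    // With mean == 0 this is simply the mean of squares (the RMSNorm denominator).
    float variance = 0.0f;
    for (float v : x)
    {
        const float centered = v - mean;
        variance += centered * centered;
    }
    variance /= static_cast<float>(n);

    const float invStdDev = 1.0f / std::sqrt(variance + epsilon);
    std::vector<float> y(n);
    for (size_t i = 0; i < n; ++i)
    {
        y[i] = (x[i] - mean) * invStdDev * scale[i];
        if (bias != nullptr)
        {
            y[i] += (*bias)[i];
        }
    }
    return y;
}
```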
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
@@ -6,6 +6,7 @@
namespace Dml
{

+template <bool simplified>
class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
@@ -83,17 +84,18 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator
inputSkipBiasAddDesc.OutputTensor = &inputDesc;
DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };

-DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
+DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC mvnDesc = {};
mvnDesc.InputTensor = &inputDesc;
mvnDesc.ScaleTensor = &gammaDesc;
mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
mvnDesc.OutputTensor = &outputDesc;
mvnDesc.Axes = axes.data();
mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(axes.size());
-mvnDesc.NormalizeVariance = true;
+mvnDesc.UseMean = !simplified;
+mvnDesc.UseVariance = true;
mvnDesc.Epsilon = epsilon;
mvnDesc.FusedActivation = nullptr;
-DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };
+DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &mvnDesc };

// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
@@ -223,6 +225,7 @@ void CALLBACK QuerySkipLayerNormalization(IMLOperatorSupportQueryContextPrivate*
*isSupported = true;
}

-DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
+DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization<false>);
+DML_OP_DEFINE_CREATION_FUNCTION(SkipSimplifiedLayerNormalization, DmlOperatorSkipLayerNormalization<true>);

} // namespace Dml
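The skip variant builds a two-stage graph: an element-wise add of input + skip (+ optional bias) feeding the same normalization operator, with UseMean = false in the simplified case. A rough per-row reference of what SkipSimplifiedLayerNormalization computes, using hypothetical names and making no claim to match the kernel's tensor-layout handling:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference, not ONNX Runtime or DirectML code.
void SkipSimplifiedLayerNormReference(const std::vector<float>& input,
                                      const std::vector<float>& skip,
                                      const std::vector<float>* bias,  // optional
                                      const std::vector<float>& gamma,
                                      float epsilon,
                                      std::vector<float>& output,
                                      std::vector<float>& inputSkipBiasSum)
{
    const size_t n = input.size();
    inputSkipBiasSum.resize(n);
    output.resize(n);

    // Stage 1: input + skip (+ bias), also exposed as the optional input_skip_bias_sum output.
    for (size_t i = 0; i < n; ++i)
    {
        inputSkipBiasSum[i] = input[i] + skip[i] + (bias != nullptr ? (*bias)[i] : 0.0f);
    }

    // Stage 2: simplified normalization, i.e. no mean subtraction, then scale by gamma.
    float meanSquare = 0.0f;
    for (float v : inputSkipBiasSum)
    {
        meanSquare += v * v;
    }
    meanSquare /= static_cast<float>(n);

    const float invStdDev = 1.0f / std::sqrt(meanSquare + epsilon);
    for (size_t i = 0; i < n; ++i)
    {
        output[i] = inputSkipBiasSum[i] * invStdDev * gamma[i];
    }
}
```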
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -336,6 +336,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(BiasAdd);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(GroupNorm);
+DML_OP_EXTERN_CREATION_FUNCTION(SimplifiedLayerNormalization);
+DML_OP_EXTERN_CREATION_FUNCTION(SkipSimplifiedLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(RNN);
DML_OP_EXTERN_CREATION_FUNCTION(GRU);
@@ -548,7 +550,7 @@ constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
-constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
+constexpr static std::array<const char*, 3> typeNameListLayerNormContrib = { "T", "U", "V" };
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
@@ -612,7 +614,7 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListIntege
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
-constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
+constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListLayerNormalizationContrib = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerNormalization = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
@@ -1110,7 +1112,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
+{REG_INFO( 7, SimplifiedLayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
+{REG_INFO_MS( 1, SkipSimplifiedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, BiasSplitGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1613,7 +1613,9 @@ using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper;
+using ShapeInferenceHelper_SkipSimplifiedLayerNormalization = SkipLayerNormHelper;
using ShapeInferenceHelper_EmbedLayerNormalization = EmbedLayerNormalizationHelper;
+using ShapeInferenceHelper_SimplifiedLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -124,6 +124,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Upsample = 7;
static const int sc_sinceVer_Xor = 7;
static const int sc_sinceVer_LayerNormalization = 1;
+static const int sc_sinceVer_SimplifiedLayerNormalization = 1;

// Special operators
static const int sc_sinceVer_MemcpyToHost = 1;
@@ -454,6 +455,7 @@ namespace OperatorHelper
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
+static const int sc_sinceVer_SkipSimplifiedLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
static const int sc_sinceVer_BiasSplitGelu = 1;
static const int sc_sinceVer_NhwcConv = 1;
10 changes: 6 additions & 4 deletions onnxruntime/test/contrib_ops/layer_norm_test.cc
@@ -43,8 +43,6 @@ static void TestLayerNorm(const std::vector<int64_t>& x_dims,
// TODO keep_dims is not implemented, default behavior is to keep ones for reduced dimensions
ASSERT_NE(keep_dims, 0);

-const std::vector<int64_t>& stats_dims = keep_dims ? n_and_ones_dims : n_dims;

CompareOpTester test(op.c_str(), opset);
test.AddAttribute("axis", axis);
test.AddAttribute("keep_dims", keep_dims);
@@ -65,16 +63,20 @@
}

std::vector<float> Y_data = FillZeros<float>(n_x_m_dims);
test.AddOutput<float>("output", n_x_m_dims, Y_data);

+#ifndef USE_DML
+// DML doesn't support more than one output for these ops yet
+const std::vector<int64_t>& stats_dims = keep_dims ? n_and_ones_dims : n_dims;
std::vector<float> mean_data = FillZeros<float>(stats_dims);
std::vector<float> var_data = FillZeros<float>(stats_dims);

test.AddOutput<float>("output", n_x_m_dims, Y_data);

// the Main and InvStdDev outputs are training specific
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0) {
test.AddOutput<float>("mean", stats_dims, mean_data);
}
test.AddOutput<float>("var", stats_dims, var_data);
+#endif

#ifdef USE_CUDA
test.CompareWithCPU(kCudaExecutionProvider);
5 changes: 1 addition & 4 deletions onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -731,8 +731,6 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) {
true);
}

-// SkipSimplifiedLayerNorm has not been enabled for DML yet
-#if !defined(USE_DML)
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
@@ -768,9 +766,8 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
true,
true);
}
-#endif

-#if !defined(USE_ROCM) && !defined(USE_DML)
+#if !defined(USE_ROCM)
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
int batch_size = 2;
int sequence_length = 2;
