diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 98b376b55f864..d5b8bd8fc258c 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -2584,8 +2584,8 @@
    no_need_buffer : input
 
 - backward_op : weight_only_linear_grad
-  forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype) -> Tensor(out)
-  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype)
+  forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch) -> Tensor(out)
+  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch)
   output : Tensor(x_grad)
   infer_meta :
     func : WeightOnlyLinearGradInferMeta
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index c55e8ffc132e6..4e3d1c6a5682a 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2814,7 +2814,7 @@
     data_type : out_dtype
 
 - op : weight_only_linear
-  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype)
+  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80)
   output : Tensor(out)
   infer_meta :
     func : WeightOnlyLinearInferMeta
@@ -2825,7 +2825,7 @@
   backward: weight_only_linear_grad
 
 - op : weight_quantize
-  args : (Tensor x, str algo="weight_only_int8")
+  args : (Tensor x, str algo = "weight_only_int8", int arch = 80)
   output : Tensor(out), Tensor(scale)
   infer_meta :
     func : WeightQuantizeInferMeta
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 4c5e130aab7a0..e7a4e16fb912c 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1162,7 +1162,13 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& weight_scale,
                                    const MetaTensor& out_grad,
                                    const std::string& weight_dtype,
+                                   const int32_t arch,
                                    MetaTensor* x_grad) {
+  PADDLE_ENFORCE_EQ(
+      arch,
+      80,
+      phi::errors::InvalidArgument(
+          "Currently weight_only_linear_grad only supports arch = 80."));
")); x_grad->set_dims(x.dims()); x_grad->set_dtype(x.dtype()); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 13dd392344f97..85d70286226a7 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -451,6 +451,7 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x, const MetaTensor& weight_scale, const MetaTensor& out_grad, const std::string& weight_dtype, + const int32_t arch, MetaTensor* x_grad); void YoloLossGradInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 09b643a030998..7106aaaad5df9 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3858,6 +3858,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& weight_scale, const std::string& weight_dtype, + const int32_t arch, MetaTensor* out) { auto x_dims = x.dims(); auto w_dims = weight.dims(); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index ee62d6d51d655..e885e8292fc9f 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -717,6 +717,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& weight_scale, const std::string& weight_dtype, + const int32_t arch, MetaTensor* out); void WeightedSampleNeighborsInferMeta(const MetaTensor& row, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index f61e2c1badd70..23aef3ffe4ed0 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5118,8 +5118,14 @@ void UnStackInferMeta(const MetaTensor& x, void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, + const int32_t arch, MetaTensor* out, MetaTensor* scale) { + PADDLE_ENFORCE_EQ( + ((arch == 80) || (arch == 70)), + true, + phi::errors::InvalidArgument("Currently, arch only support 70, 80.")); + auto x_dims = x.dims(); PADDLE_ENFORCE_EQ( x_dims.size(), diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1fe7968bcd189..daab02f2b46b1 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -467,6 +467,7 @@ void QuantizeXPUInferMeta(const MetaTensor& x, void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, + const int32_t arch, MetaTensor* out, MetaTensor* scale); diff --git a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc index 8db05de311082..9b23537764209 100644 --- a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc +++ b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc @@ -27,7 +27,13 @@ void quant_compute(const DeviceContext& dev_ctx, const DenseTensor& x, DenseTensor* out, DenseTensor* scale, - const std::string& algo) { + const std::string& algo, + const int32_t arch) { + PADDLE_ENFORCE_EQ( + ((arch == 80) || (arch == 70)), + true, + phi::errors::InvalidArgument("Currently, arch only support 70, 80.")); + const auto x_dims = x.dims(); PADDLE_ENFORCE_EQ( x_dims.size(), @@ -43,7 +49,14 @@ void quant_compute(const DeviceContext& dev_ctx, float* scale_data = scale->data(); DenseTensor x_int(out->type()); - x_int.Resize({static_cast(m), static_cast(n)}); + if (arch == 80) { + x_int.Resize({static_cast(m), static_cast(n)}); + } else { + // phi::Copy may change tensor meta info, here we transpose the quanted + // data's shape. 
+    x_int.Resize({static_cast(n), static_cast(m)});
+  }
+
   dev_ctx.template Alloc(&x_int);
   D* x_int_data = x_int.data();
@@ -64,13 +77,20 @@ void quant_compute(const DeviceContext& dev_ctx,
     funcs::Transpose trans;
     trans(dev_ctx, x_int, out, axis);
   } else {
-    permute_B_rows_for_mixed_gemm(
-        int_processed_data, x_int_data, std::vector{m, n});
-    subbyte_transpose_impl(
-        int_processed_2_data, int_processed_data, std::vector{m, n});
-    interleave_column_major_tensor(
-        out_data, int_processed_2_data, std::vector{m, n});
-    add_bias_and_interleave_inplace(out_data, num);
+    if (arch == 70) {
+      // Note(Zhengzekang): In sm70, we only need RowMajor layout, just add
+      // bias to make it unsigned.
+      add_bias_and_interleave_inplace(x_int_data, num);
+      phi::Copy(dev_ctx, x_int, dev_ctx.GetPlace(), false, out);
+    } else if (arch == 80) {
+      permute_B_rows_for_mixed_gemm(
+          int_processed_data, x_int_data, std::vector{m, n});
+      subbyte_transpose_impl(
+          int_processed_2_data, int_processed_data, std::vector{m, n});
+      interleave_column_major_tensor(
+          out_data, int_processed_2_data, std::vector{m, n});
+      add_bias_and_interleave_inplace(out_data, num);
+    }
   }
 }
@@ -78,14 +98,15 @@ template
 void WeightQuantizeKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const std::string& algo,
+                          const int32_t arch,
                           DenseTensor* out,
                           DenseTensor* scale) {
   dev_ctx.template Alloc(out);
   dev_ctx.template Alloc(scale);
   if (algo == "weight_only_int8" || algo == "llm.int8") {
-    quant_compute(dev_ctx, x, out, scale, algo);
+    quant_compute(dev_ctx, x, out, scale, algo, arch);
   } else if (algo == "weight_only_int4") {
-    quant_compute(dev_ctx, x, out, scale, algo);
+    quant_compute(dev_ctx, x, out, scale, algo, arch);
   } else {
     phi::errors::Unimplemented(
         "The algo must be in ['weight_only_int8', 'weight_only_int4', "
diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h
index 8c09a73f0cd64..79e91546d008f 100644
--- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h
+++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h
@@ -106,7 +106,8 @@ static bool is_valid_split_k_factor(const int64_t m,
 static std::vector get_candidate_tiles(
     const bool is_weight_only,
     const bool is_weight_only_encoder,
-    const bool simt_configs_only) {
+    const bool simt_configs_only,
+    const int sm) {
   std::vector simt_configs{
       CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
@@ -116,11 +117,29 @@
       CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
   };
 
-  std::vector quant_B_configs{
+  std::vector quant_B_configs_sm70{
       CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
       CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
-      CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
   };
+  std::vector quant_B_configs_sm80{
+      CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
+      CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
+      CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
+
+  std::vector quant_B_configs;
+  switch (sm) {
+    case 80:
+      quant_B_configs = quant_B_configs_sm80;
+      break;
+    case 75:
+    case 70:
+      quant_B_configs = quant_B_configs_sm70;
+      break;
+    default:
+      quant_B_configs = quant_B_configs_sm70;
+      break;
+  }
+
   std::vector encoder_quant_B_configs{
       CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64
       // CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64
@@ -138,7 +157,7 @@ static std::vector get_candidate_configs(
     const bool is_weight_only_encoder,
     const bool simt_configs_only) {
   std::vector tiles = get_candidate_tiles(
-      is_weight_only, is_weight_only_encoder, simt_configs_only);
+      is_weight_only, is_weight_only_encoder, simt_configs_only, sm);
   std::vector candidate_configs;
   const int min_stages = 2;
diff --git a/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
index 7ebe0c983a344..cd9db409792d0 100644
--- a/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
@@ -32,8 +32,15 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx,
                                 const DenseTensor& weight_scale,
                                 const DenseTensor& out_grad,
                                 const std::string& weight_dtype,
+                                const int32_t arch,
                                 DenseTensor* x_grad) {
 #if defined(PADDLE_WITH_CUTLASS)
+  PADDLE_ENFORCE_EQ(
+      arch,
+      80,
+      phi::errors::InvalidArgument(
+          "Currently weight_only_linear_grad only supports arch = 80."));
+
   int n = weight_scale.dims()[0];
   int k = weight.dims()[1];
   dev_ctx.template Alloc(x_grad);
diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
index 0d2ab397ad130..9933b46457480 100644
--- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
+++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
@@ -30,7 +30,18 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
                             const paddle::optional& bias,
                             const DenseTensor& weight_scale,
                             const std::string& weight_dtype,
+                            const int32_t arch,
                             DenseTensor* out) {
+#if defined(PADDLE_WITH_CUTLASS)
+  PADDLE_ENFORCE_EQ(
+      ((arch == 80) || (arch == 70)),
+      true,
+      phi::errors::InvalidArgument("Currently, arch only supports 70 and 80."));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Please compile with cutlass to make cutlass available"));
+#endif
+
   dev_ctx.template Alloc(out);
   const T* x_data = x.data();
   const int8_t* weight_data = weight.data();
@@ -43,8 +54,13 @@
   int k = w_dims[1];
   int m = x.numel() / k;
 
-  // m > 1: run gemm
-  if (m > 1 || weight_dtype == "int4") {
+  // m > 1: run gemm.
+  if (m > 1 || weight_dtype == "int4" || (arch == 70)) {
+/*
+Note(Zhengzekang):
+If using arch = 70, we always dispatch to the weightonly Gemm; the sm70
+weightonly gemv is not supported yet, because the sm70 weight layout is RowMajor.
+*/
 #if defined(PADDLE_WITH_CUTLASS)
     if (weight_dtype == "int8") {
       auto mixed_gemm_runner =
diff --git a/paddle/phi/kernels/weight_only_linear_grad_kernel.h b/paddle/phi/kernels/weight_only_linear_grad_kernel.h
index 518ef43c98d0f..af05059c488f3 100644
--- a/paddle/phi/kernels/weight_only_linear_grad_kernel.h
+++ b/paddle/phi/kernels/weight_only_linear_grad_kernel.h
@@ -26,6 +26,7 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx,
                                 const DenseTensor& weight_scale,
                                 const DenseTensor& out_grad,
                                 const std::string& weight_dtype,
+                                const int32_t arch,
                                 DenseTensor* x_grad);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/weight_only_linear_kernel.h b/paddle/phi/kernels/weight_only_linear_kernel.h
index 4e0de2ec9a645..17037fb531f06 100644
--- a/paddle/phi/kernels/weight_only_linear_kernel.h
+++ b/paddle/phi/kernels/weight_only_linear_kernel.h
@@ -25,5 +25,6 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
                             const paddle::optional& bias,
                             const DenseTensor& weight_scale,
                             const std::string& weight_dtype,
+                            const int32_t arch,
                             DenseTensor* out);
 }  // namespace phi
diff --git a/paddle/phi/kernels/weight_quantize_kernel.h b/paddle/phi/kernels/weight_quantize_kernel.h
index ba4277e84e637..b906e68a40338 100644
--- a/paddle/phi/kernels/weight_quantize_kernel.h
+++ b/paddle/phi/kernels/weight_quantize_kernel.h
@@ -22,6 +22,7 @@ template
 void WeightQuantizeKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const std::string& algo,
+                          const int32_t arch,
                           DenseTensor* out,
                           DenseTensor* scale);
 
diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py
index 862dfcdf3d1b4..e783de05fe77d 100644
--- a/python/paddle/nn/quant/quantized_linear.py
+++ b/python/paddle/nn/quant/quantized_linear.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _C_ops
+from paddle import _C_ops, version
 from paddle.base.data_feeder import check_dtype
 from paddle.base.framework import convert_np_dtype_to_dtype_
+from paddle.device.cuda import get_device_capability
 from paddle.framework import (
     LayerHelper,
     in_dynamic_mode,
@@ -22,7 +23,20 @@
 )
 
 
-def weight_quantize(x, algo="weight_only_int8"):
+def _get_arch_info():
+    # Get the SM version from the current device.
+    cuda_version = version.cuda()
+    if cuda_version is not None and cuda_version != 'False':
+        major, minor = get_device_capability()
+        arch = int(major * 10 + minor)
+        return arch
+    else:
+        raise ValueError(
+            "Paddle is not compiled with CUDA, so the SM version cannot be read from the device. Please compile Paddle with CUDA."
+        )
+
+
+def weight_quantize(x, algo="weight_only_int8", arch=None):
     """
     Quantization function for weight_only and llm.int8's weight.
 
@@ -30,6 +44,7 @@
         x (Tensor): The input Tensor to be quantized, the data type is float16 or bfloat16.
         algo (str): The algo that is x will be apply, must be one of 'weight_only_int8',
             'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'.
+        arch (int): The compute architecture of the target device. For example, A100 is 80 and V100 is 70. If arch is not assigned, it is read from the current device. Default: None.
 
     Returns:
         out (Tensor): The Tensor which is the quantitative results, the data type is int8, the shape is transposition of x.
@@ -49,9 +64,15 @@
         >>> print(scale.shape)
         [32]
     """
+    if arch is None:
+        arch = _get_arch_info()
+
+    assert (
+        arch == 70 or arch == 80
+    ), "Currently weight_quantize only supports SM70/80."
     if in_dynamic_mode():
-        return _C_ops.weight_quantize(x, algo)
+        return _C_ops.weight_quantize(x, algo, arch)
     else:
         type = "weight_quantize"
         helper = LayerHelper(type, **locals())
@@ -62,7 +83,7 @@
            type=type,
            inputs={"x": x},
            outputs={'out': out, "scale": scale},
-           attrs={"algo": algo},
+           attrs={"algo": algo, "arch": arch},
        )
        return (out, scale)
@@ -114,11 +135,7 @@ def weight_dequantize(x, scale, algo="weight_only_int8", out_dtype='float16'):
 
 
 def weight_only_linear(
-    x,
-    weight,
-    bias=None,
-    weight_scale=None,
-    weight_dtype="int8",
+    x, weight, bias=None, weight_scale=None, weight_dtype="int8", arch=None
 ):
     """
     Applies matrix multiplication of two tensors and then bias addition if provided.
@@ -131,6 +148,7 @@ def weight_only_linear(
            be performed. Otherwise, The bias is added to the matrix multiplication result.
        weight_scale (Tensor|None): The input scale Tensor Provided to weight for dequantization. Its rank must be 1.
        weight_dtype(str): The dtype of weight Tensor, must be one of 'int8', 'int4', Defaulted to 'int8'.
+       arch (int): The compute architecture of the target device. For example, A100 is 80 and V100 is 70. If arch is not assigned, it is read from the current device. Default: None.
 
     Returns:
         Tensor: the output Tensor, the data type is the same as that of x.
@@ -150,9 +168,16 @@ def weight_only_linear(
        ...     print(out.shape)
        [1, 2, 32]
    """
+    if arch is None:
+        arch = _get_arch_info()
+
+    assert (
+        arch == 70 or arch == 80
+    ), "Currently weight_only_linear only supports SM70/80."
+
     if in_dynamic_mode():
         out = _C_ops.weight_only_linear(
-            x, weight, bias, weight_scale, weight_dtype
+            x, weight, bias, weight_scale, weight_dtype, arch
         )
         return out
     else:
@@ -170,7 +195,7 @@ def weight_only_linear(
        }
        if bias is not None:
            inputs["bias"] = [bias]
-       attrs = {'weight_dtype': weight_dtype}
+       attrs = {'weight_dtype': weight_dtype, 'arch': arch}
 
        out = helper.create_variable_for_type_inference(dtype)
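
Usage sketch (not part of this diff): the snippet below illustrates how the new `arch` argument threads through the Python API changed above. It assumes a CUDA build of Paddle with this patch applied and an SM70 (V100) or SM80 (A100) GPU; shapes follow the docstring examples, and `arch=80` is only an example value.

# Illustrative sketch only: requires a CUDA build with this patch and an
# SM70 (V100) or SM80 (A100) device.
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

# Activations and an fp16 weight of shape [in_features, out_features].
x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16')
w = paddle.rand([64, 32], dtype='float16')

# `arch` may be passed explicitly (80 for A100, 70 for V100) or left as None,
# in which case _get_arch_info() reads it from the current device.
qweight, scale = weight_quantize(w, algo='weight_only_int8', arch=80)
print(qweight.shape)  # [32, 64] -- transposed, quantized weight
print(scale.shape)    # [32]

# Use the same arch for the matmul as was used for quantization.
out = weight_only_linear(
    x, qweight, weight_scale=scale, weight_dtype='int8', arch=80
)
print(out.shape)  # [1, 2, 32]

Since the SM70 path keeps the quantized weight in RowMajor layout while the SM80 path permutes and interleaves it, a weight quantized for one arch is presumably not reusable on the other; the kernel also always takes the GEMM path on SM70 because no SM70 weight-only gemv exists yet.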