weight_quantize/weight_only_linear support Volta Arch #58082

Merged · 12 commits · Nov 22, 2023
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -2584,8 +2584,8 @@
no_need_buffer : input

- backward_op : weight_only_linear_grad
forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype)
forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch)
output : Tensor(x_grad)
infer_meta :
func : WeightOnlyLinearGradInferMeta
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
@@ -2814,7 +2814,7 @@
data_type : out_dtype

- op : weight_only_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80)
output : Tensor(out)
infer_meta :
func : WeightOnlyLinearInferMeta
@@ -2825,7 +2825,7 @@
backward: weight_only_linear_grad

- op : weight_quantize
args : (Tensor x, str algo="weight_only_int8")
args : (Tensor x, str algo = "weight_only_int8", int arch = 80)
output : Tensor(out), Tensor(scale)
infer_meta :
func : WeightQuantizeInferMeta
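The new `arch` attribute defaults to 80 in ops.yaml, so existing callers keep today's Ampere behavior; targeting Volta requires passing `arch = 70` explicitly. The sketch below is an illustration (not code from this PR) of how a caller could derive the attribute from the device's compute capability via the CUDA runtime, mapping Turing (sm75) onto the sm70 path the same way the CUTLASS heuristic changed in this PR does.

// Hedged illustration (not part of this PR): derive the `arch` attribute for
// weight_quantize / weight_only_linear from the current GPU's compute
// capability. The kernels added here accept only 70 or 80, and the CUTLASS
// heuristic maps Turing (sm75) onto the sm70 tile configs, so we do the same.
#include <cstdio>
#include <cuda_runtime.h>

int GetWeightOnlyArch(int device_id = 0) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_id);
  const int sm = prop.major * 10 + prop.minor;  // e.g. 70 (V100), 80 (A100)
  if (sm >= 80) return 80;
  if (sm >= 70) return 70;
  return 80;  // fall back to the ops.yaml default
}

int main() {
  std::printf("arch attribute = %d\n", GetWeightOnlyArch());
  return 0;
}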
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1162,7 +1162,13 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
const MetaTensor& weight_scale,
const MetaTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
MetaTensor* x_grad) {
PADDLE_ENFORCE_EQ(
arch,
80,
phi::errors::InvalidArgument(
"Currently weightonly linear grad only support arch = 80. "));
x_grad->set_dims(x.dims());
x_grad->set_dtype(x.dtype());
}
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.h
@@ -451,6 +451,7 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
const MetaTensor& weight_scale,
const MetaTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
MetaTensor* x_grad);

void YoloLossGradInferMeta(const MetaTensor& x,
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3858,6 +3858,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.h
@@ -717,6 +717,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
MetaTensor* out);

void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -5118,8 +5118,14 @@ void UnStackInferMeta(const MetaTensor& x,

void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
const int32_t arch,
MetaTensor* out,
MetaTensor* scale) {
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 70)),
true,
phi::errors::InvalidArgument("Currently, arch only support 70, 80."));

auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
@@ -467,6 +467,7 @@ void QuantizeXPUInferMeta(const MetaTensor& x,

void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
const int32_t arch,
MetaTensor* out,
MetaTensor* scale);

43 changes: 32 additions & 11 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -27,7 +27,13 @@ void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
DenseTensor* scale,
const std::string& algo) {
const std::string& algo,
const int32_t arch) {
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 70)),
true,
phi::errors::InvalidArgument("Currently, arch only support 70, 80."));

const auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
@@ -43,7 +49,14 @@
float* scale_data = scale->data<float>();

DenseTensor x_int(out->type());
x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
if (arch == 80) {
x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
} else {
// phi::Copy may change the tensor's meta info, so give the quantized data
// the transposed shape here.
x_int.Resize({static_cast<int64_t>(n), static_cast<int64_t>(m)});
}

dev_ctx.template Alloc<D>(&x_int);
D* x_int_data = x_int.data<D>();

@@ -64,28 +77,36 @@
funcs::Transpose<DeviceContext, int8_t, 2> trans;
trans(dev_ctx, x_int, out, axis);
} else {
permute_B_rows_for_mixed_gemm<bits>(
int_processed_data, x_int_data, std::vector<size_t>{m, n});
subbyte_transpose_impl<bits>(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor<bits>(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_inplace<bits>(out_data, num);
if (arch == 70) {
// Note(Zhengzekang): On sm70 we only need the RowMajor layout, so just add
// the bias to make the values unsigned.
add_bias_and_interleave_inplace<bits>(x_int_data, num);
phi::Copy(dev_ctx, x_int, dev_ctx.GetPlace(), false, out);
} else if (arch == 80) {
permute_B_rows_for_mixed_gemm<bits>(
int_processed_data, x_int_data, std::vector<size_t>{m, n});
subbyte_transpose_impl<bits>(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor<bits>(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_inplace<bits>(out_data, num);
}
}
}

template <typename T, typename Context>
void WeightQuantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& algo,
const int32_t arch,
DenseTensor* out,
DenseTensor* scale) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
if (algo == "weight_only_int8" || algo == "llm.int8") {
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo);
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo, arch);
} else if (algo == "weight_only_int4") {
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, algo);
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, algo, arch);
} else {
phi::errors::Unimplemented(
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
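On sm70 the CPU quantize path above skips the permute/transpose/interleave pipeline and only calls `add_bias_and_interleave_inplace` on the row-major tensor before copying it out. Below is a minimal standalone sketch (a simplification, not the actual Paddle helper) of the signed-to-unsigned shift that step performs on int8 data; the real helper also interleaves elements within each 32-bit register, which is omitted here.

// Simplified sketch of the signed -> unsigned shift applied by
// add_bias_and_interleave_inplace<8> to the sm70 row-major weights.
// The real helper additionally interleaves elements within each 32-bit
// register, which is omitted in this illustration.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void add_bias_inplace_int8(int8_t* data, size_t num) {
  for (size_t i = 0; i < num; ++i) {
    // Map [-128, 127] onto [0, 255]; the GEMM reinterprets the bytes as uint8.
    data[i] = static_cast<int8_t>(static_cast<int>(data[i]) + 128);
  }
}

int main() {
  std::vector<int8_t> w = {-128, -1, 0, 127};
  add_bias_inplace_int8(w.data(), w.size());
  for (int8_t v : w) std::printf("%u ", static_cast<uint8_t>(v));  // 0 127 128 255
  std::printf("\n");
  return 0;
}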
@@ -106,7 +106,8 @@ static bool is_valid_split_k_factor(const int64_t m,
static std::vector<CutlassTileConfig> get_candidate_tiles(
const bool is_weight_only,
const bool is_weight_only_encoder,
const bool simt_configs_only) {
const bool simt_configs_only,
const int sm) {
std::vector<CutlassTileConfig> simt_configs{
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

@@ -116,11 +117,29 @@
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
};

std::vector<CutlassTileConfig> quant_B_configs{
std::vector<CutlassTileConfig> quant_B_configs_sm70{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
};
std::vector<CutlassTileConfig> quant_B_configs_sm80{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};

std::vector<CutlassTileConfig> quant_B_configs;
switch (sm) {
case 80:
quant_B_configs = quant_B_configs_sm80;
break;
case 75:
case 70:
quant_B_configs = quant_B_configs_sm70;
break;
default:
quant_B_configs = quant_B_configs_sm70;
break;
}

std::vector<CutlassTileConfig> encoder_quant_B_configs{
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64
// CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64
@@ -138,7 +157,7 @@ static std::vector<CutlassGemmConfig> get_candidate_configs(
const bool is_weight_only_encoder,
const bool simt_configs_only) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(
is_weight_only, is_weight_only_encoder, simt_configs_only);
is_weight_only, is_weight_only_encoder, simt_configs_only, sm);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
@@ -32,8 +32,15 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx,
const DenseTensor& weight_scale,
const DenseTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
DenseTensor* x_grad) {
#if defined(PADDLE_WITH_CUTLASS)
PADDLE_ENFORCE_EQ(
arch,
80,
phi::errors::InvalidArgument(
"Currently weightonly linear grad only support arch = 80. "));

int n = weight_scale.dims()[0];
int k = weight.dims()[1];
dev_ctx.template Alloc<T>(x_grad);
20 changes: 18 additions & 2 deletions paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
@@ -30,7 +30,18 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
DenseTensor* out) {
#if defined(PADDLE_WITH_CUTLASS)
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 70)),
true,
phi::errors::InvalidArgument("Currently, arch only support 70, 80."));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Please compile with cutlass to make cutlass available"));
#endif

dev_ctx.template Alloc<T>(out);
const T* x_data = x.data<T>();
const int8_t* weight_data = weight.data<int8_t>();
Expand All @@ -43,8 +54,13 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
int k = w_dims[1];
int m = x.numel() / k;

// m > 1: run gemm
if (m > 1 || weight_dtype == "int4") {
// m > 1: run gemm.
if (m > 1 || weight_dtype == "int4" || (arch == 70)) {
/*
Note(Zhengzekang):
With arch = 70 we always dispatch to the weight-only GEMM; the sm70
weight-only GEMV is not supported yet because the sm70 weight layout is RowMajor.
*/
#if defined(PADDLE_WITH_CUTLASS)
if (weight_dtype == "int8") {
auto mixed_gemm_runner =
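The dispatch in this kernel reduces to a single predicate: take the CUTLASS mixed GEMM whenever the weight-only GEMV path cannot be used, i.e. for m > 1, for int4 weights, or on sm70 where the weights stay row-major. A hedged restatement for reference (not the PR's code):

#include <cstdio>
#include <string>

// Condensed restatement (illustration only) of the dispatch condition in
// WeightOnlyLinearKernel: fall back to the CUTLASS mixed GEMM whenever the
// weight-only GEMV path is unavailable.
bool UseMixedGemm(int m, const std::string& weight_dtype, int arch) {
  return m > 1 || weight_dtype == "int4" || arch == 70;
}

int main() {
  // Decode-style case (m == 1, int8 weights) on sm80 can use the GEMV path.
  std::printf("%d\n", UseMixedGemm(1, "int8", 80));  // 0
  // The same case on sm70 must take the GEMM path.
  std::printf("%d\n", UseMixedGemm(1, "int8", 70));  // 1
  return 0;
}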
1 change: 1 addition & 0 deletions paddle/phi/kernels/weight_only_linear_grad_kernel.h
@@ -26,6 +26,7 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx,
const DenseTensor& weight_scale,
const DenseTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
DenseTensor* x_grad);

} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/weight_only_linear_kernel.h
@@ -25,5 +25,6 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
DenseTensor* out);
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/weight_quantize_kernel.h
@@ -22,6 +22,7 @@ template <typename T, typename Context>
void WeightQuantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& algo,
const int32_t arch,
DenseTensor* out,
DenseTensor* scale);
