perf: conv1d quantization #1601

Merged 5 commits on Mar 25, 2024
1 change: 1 addition & 0 deletions include/ctranslate2/layers/common.h
@@ -180,6 +180,7 @@ namespace ctranslate2 {
const ops::Conv1D _conv_op;
const StorageView& _weight;
const StorageView* _bias;
+ const StorageView* _qscale;
};

}
17 changes: 11 additions & 6 deletions include/ctranslate2/ops/conv1d.h
@@ -13,11 +13,13 @@ namespace ctranslate2 {
void operator()(const StorageView& input,
const StorageView& weight,
const StorageView& bias,
- StorageView& output) const;
+ StorageView& output,
+ const StorageView* qscale = nullptr) const;

void operator()(const StorageView& input,
const StorageView& weight,
- StorageView& output) const;
+ StorageView& output,
+ const StorageView* qscale = nullptr) const;

private:
dim_t _stride;
@@ -27,17 +29,20 @@
void operator()(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
- StorageView& output) const;
+ StorageView& output,
+ const StorageView* qscale) const;

template <Device D, typename T>
void compute(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
- StorageView& output) const;
+ StorageView& output,
+ const StorageView* qscale = nullptr) const;

- void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output) const;
+ void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output,
+ const StorageView* qscale) const;

void im2col(const StorageView& input, StorageView& output, dim_t kernel_size) const;
void im2col_transposed(const StorageView& input, StorageView& output, dim_t kernel_size) const;
};

}
1 change: 1 addition & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -45,6 +45,7 @@ def has_bias(self):
class Conv1DSpec(model_spec.LayerSpec):
def __init__(self):
self.weight = None
+ self.weight_scale = model_spec.OPTIONAL
self.bias = model_spec.OPTIONAL


8 changes: 8 additions & 0 deletions python/ctranslate2/specs/model_spec.py
@@ -220,12 +220,20 @@ def _quantize(spec, name, value):
"int8_bfloat16",
):
value = value.to("float32").numpy()
+ # For conv1d layer we need to reshape to 2D before calculating scale
+ old_shape = None
+ if len(value.shape) == 3:
+ old_shape = value.shape
+ value = value.reshape(value.shape[0], -1)
amax = np.amax(np.absolute(value), axis=1)
amax[amax == 0] = 127.0
scale = 127.0 / amax
value *= np.expand_dims(scale, 1)
value = np.rint(value)
value = value.astype(np.int8)
+ # reshape back to old shape
+ if old_shape:
+ value = value.reshape(old_shape)
scale = NumpyVariable(scale)
value = NumpyVariable(value)
elif quantization in ("float16", "bfloat16", "float32"):
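For reference, the hunk above amounts to per-output-channel int8 quantization of a 3-D conv weight: flatten to 2-D, compute one scale per output channel, quantize, then restore the original shape. A minimal NumPy sketch of that logic (the function name and the toy tensor are illustrative, not part of the PR):

import numpy as np

def quantize_conv1d_weight(weight):
    # weight has shape (out_channels, in_channels, kernel_size)
    old_shape = weight.shape
    flat = weight.reshape(old_shape[0], -1)        # one row per output channel
    amax = np.amax(np.absolute(flat), axis=1)
    amax[amax == 0] = 127.0                        # avoid division by zero
    scale = 127.0 / amax
    q = np.rint(flat * np.expand_dims(scale, 1)).astype(np.int8)
    return q.reshape(old_shape), scale             # int8 weight plus per-channel scale

w = np.arange(-6, 6, dtype=np.float32).reshape(2, 3, 2)
q, scale = quantize_conv1d_weight(w)               # q.shape == (2, 3, 2), scale.shape == (2,)

The resulting per-channel scale is what populates the new weight_scale field added to Conv1DSpec above.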
31 changes: 26 additions & 5 deletions src/cpu/kernels.cc
@@ -572,14 +572,35 @@ namespace ctranslate2 {

const auto amax = reduce_amax<TARGET_ISA>(x, depth);
const auto scale = (amax != 0.f ? int8_max / amax : 1.f);
+ using VecType = Vec<float, TARGET_ISA>;
+ const dim_t remaining = depth % VecType::width;
+ depth -= remaining;
+ auto vec_a_scale = VecType::load(scale);

if (shift_to_uint8) {
+ auto vec_int8_min = VecType::load(int8_min);
auto* dst = reinterpret_cast<uint8_t*>(y);
- for (dim_t j = 0; j < depth; ++j)
- dst[j] = round_func(x[j] * scale - int8_min);
+ for (dim_t j = 0; j < depth; j += VecType::width) {
+ auto v = VecType::load(x + j);
+ v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
+ VecType::convert_and_store(v, dst + j, VecType::width);
+ }
+ if (remaining) {
+ auto v = VecType::load(x + depth, remaining);
+ v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
+ VecType::convert_and_store(v, dst + depth, remaining);
+ }
} else {
- for (dim_t j = 0; j < depth; ++j)
- y[j] = round_func(x[j] * scale);
+ for (dim_t j = 0; j < depth; j += VecType::width) {
+ auto v = VecType::load(x + j);
+ v = round_func(VecType::mul(v, vec_a_scale));
+ VecType::convert_and_store(v, y + j, VecType::width);
+ }
+ if (remaining) {
+ auto v = VecType::load(x + depth, remaining);
+ v = round_func(VecType::mul(v, vec_a_scale));
+ VecType::convert_and_store(v, y + depth, remaining);
+ }
}

return scale;
@@ -612,7 +633,7 @@ namespace ctranslate2 {
bool shift_to_uint8,
bool round_before_cast) {
if (round_before_cast)
- quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, std::nearbyintf);
+ quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, Vec<float, TARGET_ISA>::round);
else
quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, identity());
}
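As a plain reference for the vectorized path above: each row is scaled by 127 / amax, optionally shifted into the uint8 range, and rounded before the cast; the SIMD loop processes VecType::width elements per iteration and handles the remainder in a separate tail. A NumPy sketch of the per-row math (names are illustrative, not from the kernel):

import numpy as np

def quantize_row(x, shift_to_uint8, round_before_cast=True):
    int8_max, int8_min = 127.0, -128.0
    amax = np.max(np.abs(x))
    scale = int8_max / amax if amax != 0.0 else 1.0
    y = x * scale
    if shift_to_uint8:
        y = y - int8_min                 # map [-128, 127] onto [0, 255]
    if round_before_cast:
        y = np.rint(y)                   # same rounding as Vec::round / std::nearbyintf
    return y.astype(np.uint8 if shift_to_uint8 else np.int8), scale

q, s = quantize_row(np.array([0.1, -2.5, 3.0], dtype=np.float32), shift_to_uint8=False)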
8 changes: 8 additions & 0 deletions src/cpu/vec.h
@@ -140,6 +140,14 @@
return a;
}

+ static inline float round(float a) {
+ return std::nearbyintf(a);
+ }
+
+ template<typename U>
+ static inline void convert_and_store(float v, U* a, dim_t count) {
+ *a = v;
+ }
};

template <typename T, CpuIsa ISA = CpuIsa::GENERIC>
11 changes: 11 additions & 0 deletions src/cpu/vec_avx.h
@@ -189,6 +189,17 @@
return reduce_m256(a, max);
}

+ static inline value_type round(value_type a) {
+ return _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
+ }
+
+ template<typename T>
+ static void convert_and_store(value_type v, T* a, dim_t count) {
+ auto i32 = _mm256_cvttps_epi32(v);
+ int32_t tmp[8];
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(tmp), i32);
+ std::copy(tmp, tmp + count, a);
+ }
};

}
13 changes: 13 additions & 0 deletions src/cpu/vec_avx512.h
@@ -143,6 +143,19 @@
return _mm512_reduce_max_ps(a);
}

+ static inline value_type round(value_type a) {
+ return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
+ }
+
+ static inline void convert_and_store(value_type v, int8_t* a, const dim_t count) {
+ auto i32 = _mm512_cvttps_epi32(v);
+ _mm512_mask_cvtsepi32_storeu_epi8(a, get_length_mask(count), i32);
+ }
+
+ static inline void convert_and_store(value_type v, uint8_t* a, const dim_t count) {
+ auto u32 = _mm512_cvttps_epu32(v);
+ _mm512_mask_cvtusepi32_storeu_epi8(a, get_length_mask(count), u32);
+ }
};

}
35 changes: 33 additions & 2 deletions src/cpu/vec_neon.h
@@ -159,7 +159,38 @@ namespace ctranslate2 {
return vmaxvq_f32(a);
}

- };
+ static inline value_type round(value_type v) {
+ #ifdef __aarch64__
+ return vrndiq_f32(v);
+ #else
+ float temp[4] = {std::nearbyintf(v[0]), std::nearbyintf(v[1]), std::nearbyintf(v[2]), std::nearbyintf(v[3])};
+ return load(temp);
+ #endif
+ }

- }
+ static inline void convert_and_store(value_type v, int8_t *a, dim_t count) {
+ //convert float32x4_t to int32x4_t
+ auto i32x4 = vcvtq_s32_f32(v);
+ //then convert to int16x4_t
+ auto i16x4 = vqmovn_s32(i32x4);
+ //finally convert to int8x8_t
+ auto i8x8 = vqmovn_s16(vcombine_s16(i16x4, vdup_n_s16(0)));
+ int8_t tmp[8];
+ vst1_s8(tmp, i8x8);
+ std::copy(tmp, tmp + count, a);
+ }
+
+ static inline void convert_and_store(value_type v, uint8_t *a, dim_t count) {
+ //convert float32x4_t to uint32x4_t
+ auto u32x4 = vcvtq_u32_f32(v);
+ //then convert to uint16x4_t
+ auto u16x4 = vqmovn_u32(u32x4);
+ //finally convert to uint8x8_t
+ auto u8x8 = vqmovn_u16(vcombine_u16(u16x4, vdup_n_u16(0)));
+ uint8_t tmp[8];
+ vst1_u8(tmp, u8x8);
+ std::copy(tmp, tmp + count, a);
+ }
+ };
}
}
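The convert_and_store overloads added to the AVX, AVX-512 and NEON backends all narrow rounded float lanes to 8-bit integers and write only the first count lanes; the AVX-512 and NEON variants saturate during the narrowing. A NumPy sketch of the intended behaviour, under the assumption that inputs were already rounded by round():

import numpy as np

def convert_and_store(v, out, count, signed=True):
    i = np.trunc(v).astype(np.int64)               # cvttps-style truncation toward zero
    lo, hi = (-128, 127) if signed else (0, 255)
    out[:count] = np.clip(i, lo, hi)[:count].astype(out.dtype)   # saturating narrow, partial store

out = np.zeros(4, dtype=np.int8)
convert_and_store(np.array([1.0, -2.0, 3.0, 250.0], dtype=np.float32), out, count=3)
# out -> [1, -2, 3, 0]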
7 changes: 4 additions & 3 deletions src/layers/common.cc
@@ -437,7 +437,8 @@
dim_t dilation)
: _conv_op(stride, padding, dilation)
, _weight(model.get_variable(scope + "/weight"))
- , _bias(model.get_variable_if_exists(scope + "/bias")) {
+ , _bias(model.get_variable_if_exists(scope + "/bias"))
+ , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) {
}

DataType Conv1D::output_type() const {
@@ -454,9 +455,9 @@

void Conv1D::operator()(const StorageView& input, StorageView& output) const {
if (_bias)
- _conv_op(input, _weight, *_bias, output);
+ _conv_op(input, _weight, *_bias, output, _qscale);
else
- _conv_op(input, _weight, output);
+ _conv_op(input, _weight, output, _qscale);
}

}
25 changes: 23 additions & 2 deletions src/models/model.cc
@@ -198,8 +198,29 @@

// Convert "weight" variables to the expected compute type.
// Other float variables (e.g. biases) may be converted to another float type.
- if (is_quantizable(name))
- ensure_dtype(name, variable, weight_dtype);
+ if (is_quantizable(name)) {
+ auto variable_weight_dtype = weight_dtype;
+ // For conv layer, we need to reshape to ensure dtype as its weights are 3D.
+ auto is_conv = name.find("conv") != std::string::npos;
+ auto kernel_size = -1;
+ if (is_conv) {
+ kernel_size = variable.dim(2);
+ variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
+ // For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype.
+ if (device == Device::CUDA
+ #ifdef CT2_WITH_DNNL
+ || true
+ #endif
+ ) {
+ variable_weight_dtype = float_dtype;
+ }
+ }
+ ensure_dtype(name, variable, variable_weight_dtype);
+ // Undo reshape for conv weights
+ if (is_conv) {
+ variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
+ }
+ }
else if (is_convertible(variable, name)
&& is_float_type(variable.dtype())
&& variable.dtype() != float_dtype)
3 changes: 1 addition & 2 deletions src/models/whisper.cc
@@ -34,8 +34,7 @@
}

bool WhisperModel::is_quantizable(const std::string& variable_name) const {
- return (Model::is_quantizable(variable_name)
- && variable_name.find("conv") == std::string::npos);
+ return Model::is_quantizable(variable_name);
}

bool WhisperModel::is_linear_weight(const std::string& variable_name) const {
22 changes: 9 additions & 13 deletions src/ops/conv1d.cc
@@ -15,28 +15,24 @@
void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
const StorageView& bias,
- StorageView& output) const {
- operator()(input, weight, &bias, output);
+ StorageView& output,
+ const StorageView* qscale) const {
+ operator()(input, weight, &bias, output, qscale);
}

void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
- StorageView& output) const {
- operator()(input, weight, nullptr, output);
+ StorageView& output,
+ const StorageView* qscale) const {
+ operator()(input, weight, nullptr, output, qscale);
}

void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
- StorageView& output) const {
+ StorageView& output,
+ const StorageView* qscale) const {
PROFILE("Conv1D");

- if (input.dtype() != weight.dtype())
- throw std::invalid_argument("Conv1D: input dtype is "
- + dtype_name(input.dtype())
- + " but expected dtype "
- + dtype_name(weight.dtype()));

const dim_t batch_size = input.dim(0);
const dim_t input_length = input.dim(2);
const dim_t out_channels = weight.dim(0);
@@ -47,7 +43,7 @@
output.resize({batch_size, out_channels, output_length});

DEVICE_AND_FLOAT_DISPATCH("Conv1D", input.device(), input.dtype(),
- (compute<D, T>(input, weight, bias, output)));
+ (compute<D, T>(input, weight, bias, output, qscale)));
}

}
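The dtype-equality check is dropped here because the weight may now be int8 while the input stays in a float type, and the per-channel weight scale is threaded through to compute. How compute_with_gemm consumes qscale is not part of this diff; purely as a hypothetical illustration of the role of the scale, dequantizing an int8 weight before the GEMM could look like the sketch below (names and shapes are assumptions, not the PR's implementation):

import numpy as np

def conv1d_gemm_dequant(cols, int8_weight, qscale, bias=None):
    # cols: im2col output, shape (in_channels * kernel_size, output_length)
    # int8_weight: (out_channels, in_channels * kernel_size), qscale: (out_channels,)
    w = int8_weight.astype(np.float32) / qscale[:, None]   # undo the per-channel scaling
    out = w @ cols                                          # GEMM
    if bias is not None:
        out += bias[:, None]
    return out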