Skip to content

Commit

Permalink
fix: handle quantization of conv weights while loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ebraraktas committed Mar 23, 2024
1 parent ac99886 commit bfb6261
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
25 changes: 23 additions & 2 deletions src/models/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,29 @@ namespace ctranslate2 {

// Convert "weight" variables to the expected compute type.
// Other float variables (e.g. biases) may be converted to another float type.
if (is_quantizable(name))
ensure_dtype(name, variable, weight_dtype);
if (is_quantizable(name)) {
auto variable_weight_dtype = weight_dtype;
// For conv layer, we need to reshape to ensure dtype as its weights are 3D.
auto is_conv = name.find("conv") != std::string::npos;
auto kernel_size = -1;
if (is_conv) {
kernel_size = variable.dim(2);
variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
// For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype.
if (device == Device::CUDA
#ifdef CT2_WITH_DNNL
|| true
#endif
) {
variable_weight_dtype = float_dtype;
}
}
ensure_dtype(name, variable, variable_weight_dtype);
// Undo reshape for conv weights
if (is_conv) {
variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
}
}
else if (is_convertible(variable, name)
&& is_float_type(variable.dtype())
&& variable.dtype() != float_dtype)
Expand Down
3 changes: 1 addition & 2 deletions src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ namespace ctranslate2 {
}

bool WhisperModel::is_quantizable(const std::string& variable_name) const {
return (Model::is_quantizable(variable_name)
&& variable_name.find("conv") == std::string::npos);
return Model::is_quantizable(variable_name);
}

bool WhisperModel::is_linear_weight(const std::string& variable_name) const {
Expand Down

0 comments on commit bfb6261

Please sign in to comment.