Add a quantized variant of whisper (huggingface#1017)
* Add the quantized-whisper model.

* Quantized the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
LaurentMazare authored Oct 2, 2023
1 parent 263a172 commit e04c789
Showing 5 changed files with 519 additions and 62 deletions.
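With this change the whisper example can run either the original safetensors weights or a quantized GGUF checkpoint. As a hedged usage sketch (the `--quantized` flag comes from the diff below; invoking the example this way is an assumption about the standard candle workflow):

cargo run --example whisper --release -- --quantized

When the flag is set, the example fetches `config-tiny.json` and `model-tiny-q40.gguf` from the `lmz/candle-whisper` repo instead of the usual `config.json` and `model.safetensors`.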
120 changes: 82 additions & 38 deletions candle-examples/examples/whisper/main.rs
@@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod multilingual;
-use candle_transformers::models::whisper::{self as m, audio, model};
-use model::{Config, Whisper};
+use candle_transformers::models::whisper::{self as m, audio, Config};

+pub enum Model {
+    Normal(m::model::Whisper),
+    Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these.
+impl Model {
+    pub fn config(&self) -> &Config {
+        match self {
+            Self::Normal(m) => &m.config,
+            Self::Quantized(m) => &m.config,
+        }
+    }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.encoder.forward(x, flush),
+            Self::Quantized(m) => m.encoder.forward(x, flush),
+        }
+    }
+
+    pub fn decoder_forward(
+        &mut self,
+        x: &Tensor,
+        xa: &Tensor,
+        flush: bool,
+    ) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.forward(x, xa, flush),
+            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+        }
+    }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.final_linear(x),
+            Self::Quantized(m) => m.decoder.final_linear(x),
+        }
+    }
+}

#[allow(dead_code)]
#[derive(Debug, Clone)]
@@ -41,7 +81,7 @@ struct Segment {
}

struct Decoder {
-    model: Whisper,
+    model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
@@ -60,7 +100,7 @@ struct Decoder {
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
-        model: Whisper,
+        model: Model,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
@@ -72,9 +112,9 @@ impl Decoder {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
-        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
            .map(|i| {
-                if model.config.suppress_tokens.contains(&i)
+                if model.config().suppress_tokens.contains(&i)
                    || timestamps && i == no_timestamps_token
                {
                    f32::NEG_INFINITY
@@ -109,11 +149,11 @@ impl Decoder {

fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
-        let audio_features = model.encoder.forward(mel, true)?;
+        let audio_features = model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
-        let sample_len = model.config.max_target_positions / 2;
+        let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
@@ -133,21 +173,20 @@
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
-            let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
+            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;

// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
-                let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}

let (_, seq_len, _) = ys.dims3()?;
            let logits = model
-                .decoder
-                .final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
@@ -176,7 +215,7 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
-            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
@@ -333,6 +372,7 @@ impl WhichModel {
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}

fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
@@ -382,6 +422,9 @@ struct Args {
#[arg(long)]
tracing: bool,

+    #[arg(long)]
+    quantized: bool,

/// Language.
#[arg(long)]
language: Option<String>,
@@ -413,31 +456,21 @@ fn main() -> Result<()> {
None
};
let device = candle_examples::device(args.cpu)?;
-    let (default_model, default_revision) = args.model.model_and_revision();
+    let (default_model, default_revision) = if args.quantized {
+        ("lmz/candle-whisper", "main")
+    } else {
+        args.model.model_and_revision()
+    };
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
-    let path = std::path::PathBuf::from(default_model.clone());
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};

-    let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
-        let mut config_filename = path.clone();
-        config_filename.push("config.json");
-        let mut tokenizer_filename = path.clone();
-        tokenizer_filename.push("tokenizer.json");
-        let mut model_filename = path;
-        model_filename.push("model.safetensors");
-        (
-            config_filename,
-            tokenizer_filename,
-            model_filename,
-            std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
-        )
-    } else {
+    let (config_filename, tokenizer_filename, weights_filename, input) = {
let api = Api::new()?;
let dataset = api.dataset("Narsil/candle-examples".to_string());
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@@ -451,12 +484,17 @@ fn main() -> Result<()> {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
dataset.get("samples_jfk.wav")?
};
-        (
-            repo.get("config.json")?,
-            repo.get("tokenizer.json")?,
-            repo.get("model.safetensors")?,
-            sample,
-        )
+        let config = if args.quantized {
+            repo.get("config-tiny.json")?
+        } else {
+            repo.get("config.json")?
+        };
+        let model = if args.quantized {
+            repo.get("model-tiny-q40.gguf")?
+        } else {
+            repo.get("model.safetensors")?
+        };
+        (config, repo.get("tokenizer.json")?, model, sample)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

@@ -481,10 +519,16 @@
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());

-    let vb =
-        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
-    let mut model = Whisper::load(&vb, config)?;
+    let mut model = if args.quantized {
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
+    } else {
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
+        Model::Normal(m::model::Whisper::load(&vb, config)?)
+    };

let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
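The comment at the top of the new `Model` wrapper ("Maybe we should use some traits...") hints at an alternative design. A minimal sketch of that trait-based dispatch, with hypothetical names (nothing below exists in candle; it only illustrates the idea):

use candle::{Result, Tensor};
use candle_transformers::models::whisper::Config;

// Hypothetical trait that both Whisper variants could implement; the example
// would then hold a Box<dyn WhisperModel> instead of matching on an enum.
pub trait WhisperModel {
    fn config(&self) -> &Config;
    fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> Result<Tensor>;
    fn decoder_forward(&mut self, x: &Tensor, xa: &Tensor, flush: bool) -> Result<Tensor>;
    fn decoder_final_linear(&self, x: &Tensor) -> Result<Tensor>;
}

The enum that was committed keeps dispatch static and both concrete types visible at the call site, at the cost of one match per forwarded method.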
19 changes: 13 additions & 6 deletions candle-examples/examples/whisper/multilingual.rs
@@ -1,4 +1,3 @@
-use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;

@@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
];

/// Returns the token id for the selected language.
-pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
+pub fn detect_language(
+    model: &mut super::Model,
+    tokenizer: &Tokenizer,
+    mel: &Tensor,
+) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
-    let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
+    let mel = mel.narrow(
+        2,
+        0,
+        usize::min(seq_len, model.config().max_source_positions),
+    )?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
-    let audio_features = model.encoder.forward(&mel, true)?;
+    let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    let ys = model.decoder.forward(&tokens, &audio_features, true)?;
-    let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
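For reference, the refactored function is driven from main.rs exactly as before, only through the new `Model` wrapper (this call appears in the main.rs diff above):

let language_token = multilingual::detect_language(&mut model, &tokenizer, &mel)?;

Internally it runs the encoder once, takes a single decoder step on the start-of-transcript token, restricts the resulting logits to the 99 `<|xx|>` language tokens, and softmaxes them to pick the most probable language.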
20 changes: 20 additions & 0 deletions candle-transformers/src/models/whisper/mod.rs
@@ -1,5 +1,25 @@
pub mod audio;
pub mod model;
+pub mod quantized_model;
+
+use serde::Deserialize;
+
+// The names in comments correspond to the original implementation:
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    pub num_mel_bins: usize,            // n_mels
+    pub max_source_positions: usize,    // n_audio_ctx
+    pub d_model: usize,                 // n_audio_state
+    pub encoder_attention_heads: usize, // n_audio_head
+    pub encoder_layers: usize,          // n_audio_layer
+    pub vocab_size: usize,              // n_vocab
+    pub max_target_positions: usize,    // n_text_ctx
+    // pub n_text_state: usize,
+    pub decoder_attention_heads: usize, // n_text_head
+    pub decoder_layers: usize,          // n_text_layer
+    pub suppress_tokens: Vec<u32>,
+}

pub const DTYPE: candle::DType = candle::DType::F32;

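With `Config` hoisted into the shared `mod.rs`, the regular and quantized models now deserialize the same JSON. A minimal sketch of parsing such a file, using the standard whisper-tiny hyperparameters (assumed values; the actual `config-tiny.json` may carry extra fields, which serde ignores by default):

use candle_transformers::models::whisper::Config;

fn tiny_config() -> anyhow::Result<Config> {
    // Standard whisper-tiny shapes; the suppress_tokens list is shortened here.
    let json = r#"{
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "d_model": 384,
        "encoder_attention_heads": 6,
        "encoder_layers": 4,
        "vocab_size": 51865,
        "max_target_positions": 448,
        "decoder_attention_heads": 6,
        "decoder_layers": 4,
        "suppress_tokens": [1, 2, 7]
    }"#;
    Ok(serde_json::from_str::<Config>(json)?)
}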
19 changes: 1 addition & 18 deletions candle-transformers/src/models/whisper/model.rs
@@ -1,23 +1,6 @@
+use super::Config;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;

-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
-    pub num_mel_bins: usize,            // n_mels
-    pub max_source_positions: usize,    // n_audio_ctx
-    pub d_model: usize,                 // n_audio_state
-    pub encoder_attention_heads: usize, // n_audio_head
-    pub encoder_layers: usize,          // n_audio_layer
-    pub vocab_size: usize,              // n_vocab
-    pub max_target_positions: usize,    // n_text_ctx
-    // pub n_text_state: usize,
-    pub decoder_attention_heads: usize, // n_text_head
-    pub decoder_layers: usize,          // n_text_layer
-    pub suppress_tokens: Vec<u32>,
-}

fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
candle-transformers/src/models/whisper/quantized_model.rs (new file containing the quantized Whisper implementation; its diff did not render on this page)
