Skip to content

Commit

Permalink
Merge pull request #2 from ckaznable/vad
Browse files Browse the repository at this point in the history
feat: use silero-vad v3.1 models to split voice in audio
  • Loading branch information
ckaznable authored May 19, 2023
2 parents 479f07b + 11a55be commit 43e2299
Show file tree
Hide file tree
Showing 10 changed files with 1,333 additions and 98 deletions.
945 changes: 936 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ keywords = ["youtube", "cli", "whisper"]

[dependencies]
mpeg2ts = "0.2"
wait-timeout = "0.2.0"
whisper-rs = { git = "https://github.com/ckaznable-archive/whisper-rs", branch = "cuda", features = ["cuda"] }
whisper-rs = { git = "https://github.com/ckaznable-archive/whisper-rs", branch = "cuda" }
symphonia = { version = "0.5.2", features=["aac", "mpa"] }
clap = { version = "4.2.5", features = ["derive"] }
clap = { version = "4.2.7", features = ["derive"] }
ringbuf = "0.3.3"

tract-onnx = "0.17.9"
rubato = "0.12.0"

[profile.release]
opt-level = 'z' # Optimize for size
Expand Down
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# Youtube Text Live Streaming in CLI

This project is currently a work in progress (WIP). It aims to enable streaming YouTube videos and converting the audio into text, displaying it in the command line interface (CLI). The project utilizes the [whisper-rs](https://github.com/tazz4843/whisper-rs), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) and [yt-dlp](https://github.com/yt-dlp/yt-dlp) libraries and is being developed in Rust.
This project is currently a work in progress (WIP). It aims to enable streaming YouTube videos and converting the audio into text, displaying it in the command line interface (CLI). The project utilizes the [whisper-rs](https://github.com/tazz4843/whisper-rs), [whisper.cpp](https://github.com/ggerganov/whisper.cpp), [silero-vad](https://github.com/snakers4/silero-vad) and [yt-dlp](https://github.com/yt-dlp/yt-dlp) libraries and is being developed in Rust.

Please note that the project is still under active development, and certain features or functionalities may be incomplete or subject to change. Contributions, suggestions, and bug reports are welcome.

## Requirement

- [yt-dlp](https://github.com/yt-dlp/yt-dlp)

Using yt-dlp for youtube streaming
This project uses yt-dlp for YouTube streaming

- whisper model

This project uses whisper for ASR (Automatic Speech Recognition)

Then you can follow the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) README to download whisper models

Using the base or small model is suggested

## Usage

Expand Down
5 changes: 5 additions & 0 deletions models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Models

- silero_vad.onnx

Model taken from [silero-vad](https://github.com/snakers4/silero-vad)
Binary file added models/silero_vad.onnx
Binary file not shown.
32 changes: 28 additions & 4 deletions src/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use mpeg2ts::{
es::StreamType,
ts::{ReadTsPacket, TsPacketReader, TsPayload},
};
use rubato::{InterpolationParameters, InterpolationType, Resampler, SincFixedIn, WindowFunction};
use symphonia::core::{
audio::AudioBuffer,
codecs::{DecoderOptions, CODEC_TYPE_NULL},
Expand All @@ -13,6 +14,8 @@ use symphonia::core::{
probe::Hint,
};

// Sample rate (Hz) of the audio extracted from YouTube's MPEG-TS stream
// segments; PCM at this rate is later resampled to 16 kHz (see
// `resample_to_16k`).
pub const YOUTUBE_TS_SAMPLE_RATE: u16 = 22050;

#[derive(Debug)]
pub enum Error {
Format,
Expand Down Expand Up @@ -182,16 +185,37 @@ fn extract_ts_audio(raw: &[u8]) -> Vec<u8> {
}
}
}
Ok(None) => {
break
}
_ => ()
Ok(None) => break,
_ => (),
}
}

data
}

/// Resamples a mono f32 PCM buffer from `input_sample_rate` Hz to 16 kHz
/// using a windowed-sinc resampler, returning the resampled samples.
///
/// Panics if the resampler cannot be constructed or processing fails
/// (both treated as unrecoverable configuration errors here).
pub fn resample_to_16k(input: &[f32], input_sample_rate: f64) -> Vec<f32> {
    // Sinc interpolation setup: 256-tap kernel, Blackman-Harris window,
    // linear interpolation between oversampled points.
    let sinc_params = InterpolationParameters {
        sinc_len: 256,
        f_cutoff: 0.95,
        interpolation: InterpolationType::Linear,
        oversampling_factor: 256,
        window: WindowFunction::BlackmanHarris2,
    };

    let ratio = 16000. / input_sample_rate;
    let mut resampler =
        SincFixedIn::<f32>::new(ratio, 2.0, sinc_params, input.len(), 1).unwrap();

    // rubato wants one Vec per channel; this pipeline is single-channel.
    let frames = vec![input.to_vec()];
    let mut resampled = resampler.process(&frames, None).unwrap();

    // Hand back ownership of the single (mono) output channel.
    resampled.remove(0)
}

#[cfg(test)]
mod tests {
use std::{fs::File, io::Read};
Expand Down
234 changes: 166 additions & 68 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
use audio::{YOUTUBE_TS_SAMPLE_RATE, resample_to_16k};
use clap::Parser;
use ringbuf::LocalRb;
use ringbuf::{Consumer, LocalRb, SharedRb, HeapRb, Producer, Rb};
use speech::{SpeechConfig, WhisperPayload};
use vad::{VadState, WINDOW_SIZE_SAMPLES, split_audio_data_with_window_size};
use std::{
error::Error,
ffi::c_int,
io::{BufRead, BufReader},
mem::MaybeUninit,
process::{Child, ChildStdout, Command, Stdio},
time::{Duration, Instant},
sync::{
mpsc::{self, Receiver, SyncSender},
Arc,
},
thread::{self, JoinHandle},
time::Instant,
};
use wait_timeout::ChildExt;
use whisper_rs::WhisperContext;

use util::Log;

mod audio;
mod speech;
mod util;
mod vad;

/// Simple program to greet a person
#[derive(Parser, Debug)]
type F32Consumer = Consumer<f32, Arc<SharedRb<f32, Vec<MaybeUninit<f32>>>>>;
type SegmentProducer = Producer<vad::VadSegment, Arc<SharedRb<vad::VadSegment, Vec<MaybeUninit<vad::VadSegment>>>>>;
type SegmentConsumer = Consumer<vad::VadSegment, Arc<SharedRb<vad::VadSegment, Vec<MaybeUninit<vad::VadSegment>>>>>;

enum ThreadState {
End,
Sync,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// path of whisper model
Expand Down Expand Up @@ -46,67 +62,28 @@ fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
let logger = Log::new(args.verbose);

// 1Mb buffer
let rb_size = 1024 * 1024;
let mut rb = LocalRb::<u8, Vec<_>>::new(rb_size);
let (mut ts_prod, mut ts_cons) = rb.split_ref();

let ctx = WhisperContext::new(&args.model).expect("failed to load model");
let mut state = ctx.create_state().expect("failed to create state");

let mut last_processed = Instant::now();
let mut streaming_time = 0.0f64;
let collect_time = Duration::from_secs(5);

let (mut child, stdout) = get_yt_dlp_stdout(&args.url);
let mut reader = BufReader::new(stdout);

let mut process = || {
if ts_cons.is_empty() {
logger.error(audio::Error::Empty.to_string());
return;
}
// local buffer for ts file in 1Mb
let rb_size = 1024 * 1024;
let rb = LocalRb::<u8, Vec<_>>::new(rb_size);
let (mut prod, mut cons) = rb.split();

let data = ts_cons.pop_iter().collect::<Vec<u8>>();
logger.verbose(format!(
"Reading {}kb data from yt-dlp",
data.len() / 1024
));
// shared buffer f32 transformed pcm in 30s audio data
let rb_size = YOUTUBE_TS_SAMPLE_RATE as usize * 30;
let rb = HeapRb::<f32>::new(rb_size);
let (mut ts_prod, ts_cons) = rb.split();

match audio::get_audio_data(&data) {
Ok((audio_data, dur)) => {
logger.verbose(format!(
"Get {}kb audio data and duration {:.3}s from ts",
audio_data.len() / 1024,
dur
));
// shared buffer for vad output in 20 segment
let rb = HeapRb::<vad::VadSegment>::new(20);
let (vad_prod, vad_cons ) = rb.split();

let config = SpeechConfig::new(args.threads as c_int, Some(&args.lang));
let mut payload: WhisperPayload = WhisperPayload::new(&audio_data, config);
let running_calc = Instant::now();
let (tx, rx) = mpsc::sync_channel::<ThreadState>(1);
let (vad_tx, vad_rx) = mpsc::sync_channel::<ThreadState>(1);

let segment_time = (streaming_time * 1000.0) as i64;
streaming_time += dur;

speech::process(
&mut state,
&mut payload,
&mut |segment, start| {
println!(
"[{}] {}",
util::format_timestamp_to_time(segment_time + start),
segment
);
},
);

logger.verbose(format!("whisper process time: {}s", running_calc.elapsed().as_secs()));
}
Err(err) => {
logger.error(err.to_string());
}
}
};
let handle_vad = evoke_vad_thread(args.clone(), (vad_tx.clone(), rx), (vad_prod, ts_cons));
let handle_whisper = evoke_whisper_thread(args, vad_rx, vad_cons);

loop {
let buf = reader.fill_buf()?;
Expand All @@ -115,20 +92,42 @@ fn main() -> Result<(), Box<dyn Error>> {
}

let len = buf.len();
ts_prod.push_slice(buf);

if last_processed.elapsed() >= collect_time || ts_prod.is_full() {
process();
last_processed = Instant::now();
prod.push_slice(buf);

if prod.is_full() || prod.len() > 128000 {
let data = cons.pop_iter().collect::<Vec<u8>>();
logger.verbose(format!("Reading {}kb data from yt-dlp", data.len() / 1024));

match audio::get_audio_data(&data) {
Ok((audio_data, dur)) => {
logger.verbose(format!(
"Get {}kb audio data and duration {:.3}s from ts",
audio_data.len() / 1024,
dur
));

ts_prod.push_slice(&audio_data);
if let Err(e) = tx.try_send(ThreadState::Sync) {
match e {
mpsc::TrySendError::Full(_) => (),
mpsc::TrySendError::Disconnected(_) => break,
}
}
}
Err(err) => {
logger.error(err.to_string());
}
}
}

reader.consume(len);
}

process();
child
.wait_timeout(Duration::from_secs(3))
.expect("failed to wait on yt-dlp");
tx.send(ThreadState::End).unwrap();
vad_tx.send(ThreadState::End).unwrap();
child.kill().expect("failed to kill yt-dlp process");
handle_vad.join().unwrap();
handle_whisper.join().unwrap();

Ok(())
}
Expand All @@ -149,3 +148,102 @@ fn get_yt_dlp_stdout(url: &str) -> (Child, ChildStdout) {

(child, stdout)
}

/// Spawns the worker thread that runs voice-activity detection (VAD)
/// over incoming PCM audio and forwards detected voice segments to the
/// whisper thread.
///
/// `channel` is `(tx, rx)`: `rx` receives `ThreadState::Sync` whenever new
/// PCM data has been pushed into the consumer side of `rb`; `tx` is used to
/// wake the whisper thread after new segments are produced. Any message
/// other than `Sync` (or a closed channel) ends the loop and the thread.
/// `rb` is `(segment producer, f32 PCM consumer)`.
fn evoke_vad_thread(
    args: Args,
    channel: (SyncSender<ThreadState>, Receiver<ThreadState>),
    rb: (SegmentProducer, F32Consumer),
) -> JoinHandle<()> {
    let logger = Log::new(args.verbose);
    let (tx, rx) = channel;
    let (mut prod, mut cons) = rb;

    thread::spawn(move || {
        let mut vad_state = VadState::new().unwrap();
        // Carry-over buffer: holds the tail of the previous batch that did
        // not fill a whole VAD window, so it can be prepended next round.
        let mut rb = LocalRb::<f32, Vec<_>>::new(WINDOW_SIZE_SAMPLES);

        while let Ok(ThreadState::Sync) = rx.recv() {
            if cons.is_empty() {
                logger.error("empty pcm data");
                continue;
            }

            // Drain all pending PCM and convert it to the 16 kHz rate the
            // VAD model operates on.
            let data = cons.pop_iter().collect::<Vec<f32>>();
            let mut data = resample_to_16k(&data, YOUTUBE_TS_SAMPLE_RATE as f64);

            // Prepend any leftover samples carried over from the last batch.
            if rb.len() > 0 {
                data.splice(0..0, rb.pop_iter().collect::<Vec<f32>>());
            }

            // `left` is the window-aligned portion to process now; `right`
            // (if any) is the remainder, saved for the next iteration.
            let (left, right) = split_audio_data_with_window_size(data);
            if let Some(d) = right {
                d.iter().for_each(|d| {
                    rb.push_overwrite(*d);
                })
            }

            if let Some(data) = left {
                let mut buf = vec![];

                // Run VAD window by window, collecting voice segments in `buf`.
                let running_calc = Instant::now();
                data.chunks(WINDOW_SIZE_SAMPLES).for_each(|data| {
                    let _ = vad::vad(&mut vad_state, data.to_vec(), &mut buf);
                });
                logger.verbose(format!("vad process time: {}s, detect {} segment", running_calc.elapsed().as_secs(), buf.len()));

                if !buf.is_empty() {
                    prod.push_iter(&mut buf.into_iter());
                    // Wake the whisper thread. A Full error just means a
                    // wakeup is already pending, so it is safely ignored;
                    // Disconnected means the consumer is gone, so stop.
                    if let Err(e) = tx.try_send(ThreadState::Sync) {
                        match e {
                            mpsc::TrySendError::Full(_) => (),
                            mpsc::TrySendError::Disconnected(_) => break,
                        }
                    }
                }
            }
        }
    })
}

/// Spawns the worker thread that runs whisper speech-to-text over the
/// voice segments produced by the VAD thread, printing each transcribed
/// line to stdout with a stream-relative timestamp.
///
/// `rx` carries control messages: each `ThreadState::Sync` means new
/// segments are waiting in `cons`; any other message (or a closed
/// channel) ends the loop and the thread exits.
fn evoke_whisper_thread(
    args: Args,
    rx: Receiver<ThreadState>,
    mut cons: SegmentConsumer,
) -> JoinHandle<()> {
    // Load the model on the caller's thread so a bad model path fails fast,
    // before any worker is spawned.
    let ctx = WhisperContext::new(&args.model).expect("failed to load model");
    let logger = Log::new(args.verbose);

    thread::spawn(move || {
        let mut state = ctx.create_state().expect("failed to create state");
        // Running total (seconds) of audio already transcribed; offsets
        // per-segment timestamps into overall stream time.
        let mut streaming_time = 0.0f64;

        while let Ok(ThreadState::Sync) = rx.recv() {
            if cons.is_empty() {
                logger.error(audio::Error::Empty.to_string());
                continue;
            }

            cons.pop_iter().for_each(|segment| {
                let config = SpeechConfig::new(args.threads as c_int, Some(&args.lang));
                let mut payload: WhisperPayload = WhisperPayload::new(&segment.data, config);
                let running_calc = Instant::now();

                // Base timestamp (ms) of this segment within the stream.
                let segment_time = (streaming_time * 1000.0) as i64;
                streaming_time += segment.duration as f64;

                // Print each transcription line prefixed with its absolute
                // stream timestamp (segment base + offset within segment).
                speech::process(&mut state, &mut payload, &mut |segment, start| {
                    println!(
                        "[{}] {}",
                        util::format_timestamp_to_time(segment_time + start),
                        segment
                    );
                });

                logger.verbose(format!(
                    "whisper process time: {}s",
                    running_calc.elapsed().as_secs()
                ));
            });
        }
    })
}
Loading

0 comments on commit 43e2299

Please sign in to comment.