From dd954a9ede59302596c297cb21b6bc1d8b48cffa Mon Sep 17 00:00:00 2001 From: Aleks Date: Mon, 25 Mar 2024 19:21:30 -0400 Subject: [PATCH] Further refinement of TensorRT-LLM backend based on WhisperS2T --- .../engine_builder/create_trt_model.py | 174 ++++++-- .../engines/tensorrt_llm/model.py | 24 +- .../engines/tensorrt_llm/segmenter.py | 397 ++++++++++++++++++ .../engines/tensorrt_llm/tokenizer.py | 6 +- .../engines/tensorrt_llm/trt_model.py | 6 + .../engines/tensorrt_llm/whisper_model.py | 2 +- .../services/transcribe_service.py | 8 +- 7 files changed, 561 insertions(+), 56 deletions(-) create mode 100644 src/wordcab_transcribe/engines/tensorrt_llm/segmenter.py diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py index 827c21c..c9d22a4 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py @@ -1,7 +1,10 @@ +import hashlib import os import subprocess import requests +from loguru import logger +from tqdm import tqdm _MODELS = { "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", @@ -18,8 +21,45 @@ "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", } +_TOKENIZERS = { + "tiny.en": ( + "https://huggingface.co/Systran/faster-whisper-tiny.en/raw/main/tokenizer.json" + ), + "tiny": ( + "https://huggingface.co/Systran/faster-whisper-tiny/raw/main/tokenizer.json" + ), + "small.en": ( + "https://huggingface.co/Systran/faster-whisper-small.en/raw/main/tokenizer.json" + ), + "small": ( + "https://huggingface.co/Systran/faster-whisper-small/raw/main/tokenizer.json" + ), + "base.en": ( + "https://huggingface.co/Systran/faster-whisper-base.en/raw/main/tokenizer.json" + ), + "base": ( + "https://huggingface.co/Systran/faster-whisper-base/raw/main/tokenizer.json" + ), + "medium.en": "https://huggingface.co/Systran/faster-whisper-medium.en/raw/main/tokenizer.json", + "medium": ( + "https://huggingface.co/Systran/faster-whisper-medium/raw/main/tokenizer.json" + ), + "large-v1": ( + "https://huggingface.co/Systran/faster-whisper-large-v1/raw/main/tokenizer.json" + ), + "large-v2": ( + "https://huggingface.co/Systran/faster-whisper-large-v2/raw/main/tokenizer.json" + ), + "large-v3": ( + "https://huggingface.co/Systran/faster-whisper-large-v3/raw/main/tokenizer.json" + ), + "large": ( + "https://huggingface.co/Systran/faster-whisper-large-v3/raw/main/tokenizer.json" + ), +} + -def build_whisper_model( +def build_whisper_trt_model( output_dir, use_gpt_attention_plugin=True, use_gemm_plugin=True, @@ -44,41 +84,97 @@ def build_whisper_model( None """ model_url = _MODELS[model_name] - model_path = f"assets/{model_name}.pt" - - # Download the model if it doesn't exist - if not os.path.exists(model_path): - os.makedirs("assets", exist_ok=True) - - print(f"Downloading model '{model_name}' from {model_url}...") - response = requests.get(model_url) - - if response.status_code == 200: - with open(model_path, "wb") as file: - file.write(response.content) - print(f"Model '{model_name}' downloaded successfully.") - else: - print( - f"Failed to download model '{model_name}'. 
Status code:" - f" {response.status_code}" - ) - return - - command = ["python3", "build.py", "--output_dir", output_dir] - - if use_gpt_attention_plugin: - command.append("--use_gpt_attention_plugin") - if use_gemm_plugin: - command.append("--use_gemm_plugin") - if use_bert_attention_plugin: - command.append("--use_bert_attention_plugin") - if enable_context_fmha: - command.append("--enable_context_fmha") - if use_weight_only: - command.append("--use_weight_only") - - try: - subprocess.run(command, check=True) - except subprocess.CalledProcessError as e: - print(f"Error occurred while building the model: {e}") - raise + expected_sha256 = model_url.split("/")[-2] + model_ckpt_path = f"../assets/{model_name}.pt" + tokenizer_path = f"{output_dir}/tokenizer.json" + + if not os.path.exists(model_ckpt_path): + os.makedirs("../assets", exist_ok=True) + + logger.info(f"Downloading model '{model_name}' from {model_url}...") + + response = requests.get(model_url, stream=True) + total_size = int(response.headers.get("Content-Length", 0)) + + with open(model_ckpt_path, "wb") as output: + with tqdm( + total=total_size, + ncols=80, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + for data in response.iter_content(chunk_size=8192): + size = output.write(data) + pbar.update(size) + + with open(model_ckpt_path, "rb") as f: + model_bytes = f.read() + if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not" + " match. Please retry loading the model." + ) + + if not os.path.exists(output_dir): + logger.info("Building the model...") + command = [ + "python3", + "build.py", + "--output_dir", + output_dir, + "--model_name", + model_name, + ] + + if use_gpt_attention_plugin: + command.append("--use_gpt_attention_plugin") + if use_gemm_plugin: + command.append("--use_gemm_plugin") + if use_bert_attention_plugin: + command.append("--use_bert_attention_plugin") + if enable_context_fmha: + command.append("--enable_context_fmha") + if use_weight_only: + command.append("--use_weight_only") + + try: + subprocess.run(command, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"Error occurred while building the model: {e}") + raise + logger.info("Model has been built successfully.") + + if not os.path.exists(tokenizer_path): + logger.info(f"Downloading tokenizer for model '{model_name}'...") + response = requests.get(_TOKENIZERS[model_name], stream=True) + total_size = int(response.headers.get("Content-Length", 0)) + + with open(tokenizer_path, "wb") as output: + with tqdm( + total=total_size, + ncols=80, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + for data in response.iter_content(chunk_size=8192): + size = output.write(data) + pbar.update(size) + logger.info("Tokenizer has been downloaded successfully.") + + for filename in os.listdir(output_dir): + if "encoder" in filename and filename.endswith(".engine"): + new_filename = "encoder.engine" + old_path = os.path.join(output_dir, filename) + new_path = os.path.join(output_dir, new_filename) + os.rename(old_path, new_path) + logger.info(f"Renamed '{filename}' to '{new_filename}'") + elif "decoder" in filename and filename.endswith(".engine"): + new_filename = "decoder.engine" + old_path = os.path.join(output_dir, filename) + new_path = os.path.join(output_dir, new_filename) + os.rename(old_path, new_path) + logger.info(f"Renamed '{filename}' to '{new_filename}'") + + return output_dir diff --git 
a/src/wordcab_transcribe/engines/tensorrt_llm/model.py b/src/wordcab_transcribe/engines/tensorrt_llm/model.py index 0d28819..ba07437 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/model.py @@ -2,10 +2,12 @@ import ctranslate2 import numpy as np -import tokenizers +from wordcab_transcribe.engines.tensorrt_llm.engine_builder.create_trt_model import ( + build_whisper_trt_model, +) from wordcab_transcribe.engines.tensorrt_llm.hf_utils import download_model -from wordcab_transcribe.engines.tensorrt_llm.tokenizer import Tokenizer +from wordcab_transcribe.engines.tensorrt_llm.tokenizers import Tokenizer from wordcab_transcribe.engines.tensorrt_llm.trt_model import WhisperTRT from wordcab_transcribe.engines.tensorrt_llm.whisper_model import WhisperModel @@ -77,11 +79,11 @@ def exact_div(x, y): class WhisperModelTRT(WhisperModel): - """TensorRT implementation of the Whisper model.""" + """TensorRT-LLM implementation of the Whisper model.""" def __init__( self, - model_name_or_path: str, + model_name: str, asr_options: dict, cpu_threads=4, num_workers=1, @@ -91,20 +93,20 @@ def __init__( max_text_token_len=15, **model_kwargs ): - # ASR Options self.asr_options = FAST_ASR_OPTIONS self.asr_options.update(asr_options) - self.model_path = model_name_or_path - # # TODO build engine if not exists + self.model_name = model_name + self.model_path = os.path.join("models", self.model_name) - # Load model + if not os.path.exists(self.model_path): + self.model_path = build_whisper_trt_model( + self.model_path, model_name=self.model_name + ) self.model = WhisperTRT(self.model_path) - # Load tokenizer - # TODO: Have this downloaded as well tokenizer_file = os.path.join(self.model_path, "tokenizer.json") tokenizer = Tokenizer( - tokenizers.Tokenizer.from_file(tokenizer_file), self.model.is_multilingual + Tokenizer.from_file(tokenizer_file), self.model.is_multilingual ) if self.asr_options["word_timestamps"]: diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/segmenter.py b/src/wordcab_transcribe/engines/tensorrt_llm/segmenter.py new file mode 100644 index 0000000..629b32b --- /dev/null +++ b/src/wordcab_transcribe/engines/tensorrt_llm/segmenter.py @@ -0,0 +1,397 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch + + +class VADBaseClass(ABC): + def __init__(self, sampling_rate=16000): + self.sampling_rate = sampling_rate + + @abstractmethod + def update_params(self, params): + pass + + @abstractmethod + def __call__(self, audio_signal, batch_size=4): + pass + + +class FrameVAD(VADBaseClass): + def __init__( + self, + device=None, + chunk_size=15.0, + margin_size=1.0, + frame_size=0.02, + batch_size=4, + sampling_rate=16000, + ): + super().__init__(sampling_rate=sampling_rate) + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.device = device + + if self.device == "cpu": + # This is a JIT Scripted model of Nvidia's NeMo Framewise Marblenet Model: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_multilingual_frame_marblenet + self.vad_pp = torch.jit.load("assets/vad_pp_cpu.ts").to(self.device) + self.vad_model = torch.jit.load("assets/frame_vad_model_cpu.ts").to( + self.device + ) + else: + self.vad_pp = torch.jit.load("assets/vad_pp_gpu.ts").to(self.device) + self.vad_model = torch.jit.load("assets/frame_vad_model_gpu.ts").to( + self.device + ) + + self.vad_pp.eval() + self.vad_model.eval() + + self.batch_size = batch_size + self.frame_size = frame_size + 
self.chunk_size = chunk_size + self.margin_size = margin_size + + self._init_params() + + def _init_params(self): + self.signal_chunk_len = int(self.chunk_size * self.sampling_rate) + self.signal_stride = int( + self.signal_chunk_len - 2 * int(self.margin_size * self.sampling_rate) + ) + + self.margin_logit_len = int(self.margin_size / self.frame_size) + self.signal_to_logit_len = int(self.frame_size * self.sampling_rate) + + self.vad_pp.to(self.device) + self.vad_model.to(self.device) + + def update_params(self, params): + for key, value in params.items(): + setattr(self, key, value) + + self._init_params() + + def prepare_input_batch(self, audio_signal): + input_signal = [] + input_signal_length = [] + for s_idx in range(0, len(audio_signal), self.signal_stride): + _signal = audio_signal[s_idx : s_idx + self.signal_chunk_len] + _signal_len = len(_signal) + input_signal.append(_signal) + input_signal_length.append(_signal_len) + + if _signal_len < self.signal_chunk_len: + input_signal[-1] = np.pad( + input_signal[-1], (0, self.signal_chunk_len - _signal_len) + ) + break + + return input_signal, input_signal_length + + @torch.cuda.amp.autocast() + @torch.no_grad() + def forward(self, input_signal, input_signal_length): + all_logits = [] + for s_idx in range(0, len(input_signal), self.batch_size): + input_signal_pt = torch.stack( + [ + torch.tensor(_, device=self.device) + for _ in input_signal[s_idx : s_idx + self.batch_size] + ] + ) + input_signal_length_pt = torch.tensor( + input_signal_length[s_idx : s_idx + self.batch_size], device=self.device + ) + + x, x_len = self.vad_pp(input_signal_pt, input_signal_length_pt) + logits = self.vad_model(x, x_len) + + for _logits, _len in zip(logits, input_signal_length_pt): + all_logits.append(_logits[: int(_len / self.signal_to_logit_len)]) + + if len(all_logits) > 1 and self.margin_logit_len > 0: + all_logits[0] = all_logits[0][: -self.margin_logit_len] + all_logits[-1] = all_logits[-1][self.margin_logit_len :] + + for i in range(1, len(all_logits) - 1): + all_logits[i] = all_logits[i][ + self.margin_logit_len : -self.margin_logit_len + ] + + all_logits = torch.concatenate(all_logits) + all_logits = torch.softmax(all_logits, dim=-1) + + return all_logits[:, 1].detach().cpu().numpy() + + def __call__(self, audio_signal): + audio_duration = len(audio_signal) / self.sampling_rate + + input_signal, input_signal_length = self.prepare_input_batch(audio_signal) + speech_probs = self.forward(input_signal, input_signal_length) + + vad_times = [] + for idx, prob in enumerate(speech_probs): + s_time = idx * self.frame_size + e_time = min(audio_duration, (idx + 1) * self.frame_size) + + if s_time >= e_time: + break + + vad_times.append([prob, s_time, e_time]) + + return np.array(vad_times) + + +class SegmentVAD(VADBaseClass): + def __init__( + self, + device=None, + win_len=0.32, + win_step=0.08, + batch_size=512, + sampling_rate=16000, + ): + super().__init__(sampling_rate=sampling_rate) + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.device = device + + if self.device == "cpu": + # This is a JIT Scripted model of Nvidia's NeMo Marblenet Model: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_multilingual_marblenet + self.vad_pp = torch.jit.load("assets/vad_pp_cpu.ts").to(self.device) + self.vad_model = torch.jit.load("assets/seg_vad_model_cpu.ts").to( + self.device + ) + else: + self.vad_pp = torch.jit.load("assets/vad_pp_gpu.ts").to(self.device) + self.vad_model = 
torch.jit.load("assets/seg_vad_model_gpu.ts").to( + self.device + ) + + self.vad_pp = torch.jit.load("assets/vad_pp.ts") + self.vad_model = torch.jit.load("assets/segment_vad_model.ts") + + self.vad_model.eval() + self.vad_model.to(self.device) + + self.vad_pp.eval() + self.vad_pp.to(self.device) + + self.batch_size = batch_size + self.win_len = win_len + self.win_step = win_step + + self._init_params() + + def _init_params(self): + self.signal_win_len = int(self.win_len * self.sampling_rate) + self.signal_win_step = int(self.win_step * self.sampling_rate) + + def update_params(self, params: dict): + for key, value in params.items(): + setattr(self, key, value) + + self._init_params() + + def prepare_input_batch(self, audio_signal): + num_chunks = ( + self.signal_win_len // 2 + len(audio_signal) + ) // self.signal_win_step + if ( + num_chunks + < (self.signal_win_len // 2 + len(audio_signal)) / self.signal_win_step + ): + num_chunks += 1 + + input_signal = np.zeros((num_chunks, self.signal_win_len), dtype=np.float32) + input_signal_length = np.zeros(num_chunks, dtype=np.int64) + + chunk_idx = 0 + for idx in range( + -1 * self.signal_win_len // 2, len(audio_signal), self.signal_win_step + ): + s_idx = max(idx, 0) + e_idx = min(idx + self.signal_win_len, len(audio_signal)) + input_signal[chunk_idx][: e_idx - s_idx] = audio_signal[s_idx:e_idx] + input_signal_length[chunk_idx] = e_idx - s_idx + chunk_idx += 1 + + return input_signal, input_signal_length + + @torch.cuda.amp.autocast() + @torch.no_grad() + def forward(self, input_signal, input_signal_length): + x, x_len = self.vad_pp( + torch.Tensor(input_signal).to(self.device), + torch.Tensor(input_signal_length).to(self.device), + ) + logits = self.vad_model(x, x_len) + logits = torch.softmax(logits, dim=-1) + return logits[:, 1].detach().cpu().numpy() + + def __call__(self, audio_signal): + audio_duration = len(audio_signal) / self.sampling_rate + + input_signal, input_signal_length = self.prepare_input_batch(audio_signal) + + speech_probs = np.zeros(len(input_signal)) + for s_idx in range(0, len(input_signal), self.batch_size): + speech_probs[s_idx : s_idx + self.batch_size] = self.forward( + input_signal=input_signal[s_idx : s_idx + self.batch_size], + input_signal_length=input_signal_length[ + s_idx : s_idx + self.batch_size + ], + ) + + vad_times = [] + for idx, prob in enumerate(speech_probs): + s_time = max(0, (idx - 0.5) * self.win_step) + e_time = min(audio_duration, (idx + 0.5) * self.win_step) + vad_times.append([prob, s_time, e_time]) + + return np.array(vad_times) + + +class SpeechSegmenter: + def __init__( + self, + vad_model=None, + device=None, + frame_size=0.02, + min_seg_len=0.08, + max_seg_len=29.0, + max_silent_region=0.6, + padding=0.2, + eos_thresh=0.3, + bos_thresh=0.3, + cut_factor=2, + sampling_rate=16000, + ): + if vad_model is None: + vad_model = FrameVAD(device=device) + + self.vad_model = vad_model + + self.sampling_rate = sampling_rate + self.padding = padding + self.frame_size = frame_size + self.min_seg_len = min_seg_len + self.max_seg_len = max_seg_len + self.max_silent_region = max_silent_region + + self.eos_thresh = eos_thresh + self.bos_thresh = bos_thresh + + self.cut_factor = cut_factor + self.cut_idx = int(self.max_seg_len / (self.cut_factor * self.frame_size)) + self.max_idx_in_seg = self.cut_factor * self.cut_idx + + def update_params(self, params): + for key, value in params.items(): + setattr(self, key, value) + + self.cut_idx = int(self.max_seg_len / (self.cut_factor * self.frame_size)) + 
self.max_idx_in_seg = self.cut_factor * self.cut_idx + + def update_vad_model_params(self, params): + self.vad_model.update_params(params=params) + + def okay_to_merge(self, speech_probs, last_seg, curr_seg): + conditions = [ + (speech_probs[curr_seg["start"]][1] - speech_probs[last_seg["end"]][2]) + < self.max_silent_region, + (speech_probs[curr_seg["end"]][2] - speech_probs[last_seg["start"]][1]) + <= self.max_seg_len, + ] + + return all(conditions) + + def get_speech_segments(self, speech_probs): + speech_flag, start_idx = False, 0 + speech_segments = [] + for idx, (speech_prob, _st, _et) in enumerate(speech_probs): + if speech_flag: + if speech_prob < self.eos_thresh: + speech_flag = False + curr_seg = {"start": start_idx, "end": idx - 1} + + if len(speech_segments) and self.okay_to_merge( + speech_probs, speech_segments[-1], curr_seg + ): + speech_segments[-1]["end"] = curr_seg["end"] + else: + speech_segments.append(curr_seg) + + elif speech_prob >= self.bos_thresh: + speech_flag = True + start_idx = idx + + if speech_flag: + curr_seg = {"start": start_idx, "end": len(speech_probs) - 1} + + if len(speech_segments) and self.okay_to_merge( + speech_probs, speech_segments[-1], curr_seg + ): + speech_segments[-1]["end"] = curr_seg["end"] + else: + speech_segments.append(curr_seg) + + speech_segments = [ + _ + for _ in speech_segments + if (speech_probs[_["end"]][2] - speech_probs[_["start"]][1]) + > self.min_seg_len + ] + + start_ends = [] + for _ in speech_segments: + first_idx = len(start_ends) + start_idx, end_idx = _["start"], _["end"] + while (end_idx - start_idx) > self.max_idx_in_seg: + _start_idx = int(start_idx + self.cut_idx) + _end_idx = int(min(end_idx, start_idx + self.max_idx_in_seg)) + + new_end_idx = _start_idx + np.argmin( + speech_probs[_start_idx:_end_idx, 0] + ) + start_ends.append( + [speech_probs[start_idx][1], speech_probs[new_end_idx][2]] + ) + start_idx = new_end_idx + 1 + + start_ends.append( + [speech_probs[start_idx][1], speech_probs[end_idx][2] + self.padding] + ) + start_ends[first_idx][0] = start_ends[first_idx][0] - self.padding + + return start_ends + + def __call__(self, audio_data=None): + if audio_data is not None: + if isinstance(audio_data, np.ndarray): + audio_signal = audio_data + audio_duration = len(audio_signal) / self.sampling_rate + elif isinstance(audio_data, torch.Tensor): + audio_tensor = audio_data + audio_signal = audio_tensor.squeeze().cpu().numpy() + audio_duration = len(audio_signal) / self.sampling_rate + else: + raise ValueError("`audio_data` must be a numpy array or torch tensor.") + else: + raise ValueError("`audio_data` must be a numpy array or torch tensor.") + + speech_probs = self.vad_model(audio_signal) + start_ends = self.get_speech_segments(speech_probs) + + if len(start_ends) == 0: + start_ends = [[0.0, self.max_seg_len]] # Quick fix for silent audio. 
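+        # `start_ends` is a list of [start_sec, end_sec] speech spans: segments
+        # longer than max_seg_len were already split at the lowest-probability
+        # frame inside get_speech_segments, and padding was applied there, so
+        # only the outer edges still need clamping to the real signal bounds below.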
+ + start_ends[0][0] = max(0.0, start_ends[0][0]) # fix edges + start_ends[-1][1] = min(audio_duration, start_ends[-1][1]) # fix edges + + return start_ends, audio_signal diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/tokenizer.py b/src/wordcab_transcribe/engines/tensorrt_llm/tokenizer.py index 1311d5c..fbb325e 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/tokenizer.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/tokenizer.py @@ -1,15 +1,13 @@ -import os from functools import cached_property -from __init__ import BASE_PATH - _TASKS = ( "transcribe", "translate", ) + # TODO: add language dict -with open(os.path.join(BASE_PATH, "assets/lang_codes.txt"), "r") as f: +with open("assets/lang_codes.txt", "r") as f: _LANGUAGE_CODES = [_ for _ in f.read().split("\n") if _] diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py b/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py index 1438b2e..e977176 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py @@ -10,6 +10,8 @@ class WhisperEncoding: + """Class for encoding audio to features using TensorRT.""" + def __init__(self, engine_dir): self.session = self.get_session(engine_dir) @@ -34,6 +36,7 @@ def get_session(self, engine_dir): return session def get_audio_features(self, mel): + """Get audio features from mel spectrogram.""" input_lengths = torch.tensor( [mel.shape[2] // 2 for _ in range(mel.shape[0])], dtype=torch.int32, @@ -66,6 +69,8 @@ def get_audio_features(self, mel): class WhisperDecoding: + """Class for decoding features to text using TensorRT.""" + def __init__(self, engine_dir, runtime_mapping, debug_mode=False): self.decoder_config = self.get_config(engine_dir) self.decoder_generation_session = self.get_session( @@ -87,6 +92,7 @@ def get_session(self, engine_dir, runtime_mapping, debug_mode=False): with open(serialize_path, "rb") as f: decoder_engine_buffer = f.read() + # TODO: Make dynamic max_batch_size and max_beam_width decoder_model_config = ModelConfig( max_batch_size=8, max_beam_width=1, diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/whisper_model.py b/src/wordcab_transcribe/engines/tensorrt_llm/whisper_model.py index 7ebd3b6..9dcd872 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/whisper_model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/whisper_model.py @@ -4,7 +4,7 @@ from wordcab_transcribe.engines.tensorrt_llm.audio import LogMelSpectogram from wordcab_transcribe.engines.tensorrt_llm.data import WhisperTRTDataLoader -from wordcab_transcribe.engines.tensorrt_llm.speech_segmenter import SpeechSegmenter +from wordcab_transcribe.engines.tensorrt_llm.segmenter import SpeechSegmenter class NoneTokenizer: diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 530f35b..47f433e 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -26,6 +26,7 @@ from loguru import logger from tensorshare import Backend, TensorShare +from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT from wordcab_transcribe.models import ( MultiChannelSegment, MultiChannelTranscriptionOutput, @@ -88,7 +89,12 @@ def __init__( compute_type=self.compute_type, ) elif model_engine == "tensorrt-llm": - pass + self.model = WhisperModelTRT( + self.model_path, + device=self.device, + device_index=device_index, + compute_type=self.compute_type, + ) 
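+            # WhisperModelTRT treats its first argument as the Whisper model name,
+            # builds the TensorRT engines under models/<name> on first use via
+            # engine_builder.create_trt_model.build_whisper_trt_model, and loads
+            # tokenizer.json from that same directory.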
else: self.model = WhisperModel( self.model_path,
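For reviewers who want to exercise the new backend end to end, here is a minimal sketch of how the pieces introduced in this patch are meant to fit together. It is not part of the patch: the model name "large-v2", the silent test waveform, and the presence of the JIT VAD assets under assets/ are assumptions, and any constructor arguments beyond model_name/asr_options come from the WhisperS2T-derived base class and may need adjusting.

    import numpy as np

    from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT
    from wordcab_transcribe.engines.tensorrt_llm.segmenter import SpeechSegmenter

    # First use downloads the checkpoint, verifies its SHA-256 (taken from the
    # second-to-last URL path component), builds the TensorRT engines under
    # models/large-v2/, and fetches tokenizer.json next to them.
    model = WhisperModelTRT("large-v2", asr_options={"word_timestamps": False})

    # Frame-level VAD turns the waveform into speech spans of at most ~29 s.
    # Assumes assets/vad_pp_*.ts and assets/frame_vad_model_*.ts are present.
    segmenter = SpeechSegmenter()
    audio = np.zeros(16000 * 30, dtype=np.float32)  # placeholder: 30 s of 16 kHz mono
    start_ends, signal = segmenter(audio_data=audio)
    print(start_ends)  # [[0.0, 29.0]] for silence, via the quick fix in segmenter.py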