From 847fec449286997eee21e66d2cd45a3f589c29ac Mon Sep 17 00:00:00 2001 From: jax <119055910+jax-explorer@users.noreply.github.com> Date: Sat, 4 May 2024 16:11:52 +0800 Subject: [PATCH] Feature/add hotwords (#731) * add hotword params --------- Co-authored-by: jax --- faster_whisper/transcribe.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 337cc428..cf5ece06 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple): max_new_tokens: Optional[int] clip_timestamps: Union[str, List[float]] hallucination_silence_threshold: Optional[float] + hotwords: Optional[str] class TranscriptionInfo(NamedTuple): @@ -220,6 +221,7 @@ def transcribe( chunk_length: Optional[int] = None, clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + hotwords: Optional[str] = None, language_detection_threshold: Optional[float] = None, language_detection_segments: int = 1, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: @@ -284,10 +286,11 @@ def transcribe( hallucination_silence_threshold: Optional[float] When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + hotwords: Optional[str] + Hotwords/hint phrases added to the prompt for the model. Has no effect if prefix is set. language_detection_threshold: If the maximum probability of the language tokens is higher than this value, the language is detected. language_detection_segments: Number of segments to consider for the language detection. 
- Returns: A tuple with: @@ -441,6 +444,7 @@ def transcribe( max_new_tokens=max_new_tokens, clip_timestamps=clip_timestamps, hallucination_silence_threshold=hallucination_silence_threshold, + hotwords=hotwords, ) segments = self.generate_segments(features, tokenizer, options, encoder_output) @@ -547,6 +551,7 @@ def generate_segments( previous_tokens, without_timestamps=options.without_timestamps, prefix=options.prefix if seek == 0 else None, + hotwords=options.hotwords, ) if seek > 0 or encoder_output is None: @@ -939,12 +944,19 @@ def get_prompt( previous_tokens: List[int], without_timestamps: bool = False, prefix: Optional[str] = None, + hotwords: Optional[str] = None, ) -> List[int]: prompt = [] - if previous_tokens: + if previous_tokens or (hotwords and not prefix): prompt.append(tokenizer.sot_prev) - prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) + if hotwords and not prefix: + hotwords_tokens = tokenizer.encode(" " + hotwords.strip()) + if len(hotwords_tokens) >= self.max_length // 2: + hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1] + prompt.extend(hotwords_tokens) + if previous_tokens: + prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) prompt.extend(tokenizer.sot_sequence)