From b547794648079b31149e624c1be371cb6281fffa Mon Sep 17 00:00:00 2001
From: jax
Date: Sun, 3 Mar 2024 23:17:14 +0800
Subject: [PATCH 1/6] add hotwords init

---
 faster_whisper/transcribe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index d3d5debc..c0970cb8 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -499,7 +499,7 @@ def generate_segments(
                     "Processing segment at %s", format_timestamp(time_offset)
                 )
 
-            previous_tokens = all_tokens[prompt_reset_since:]
+            previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
                 tokenizer,
                 previous_tokens,
@@ -899,9 +899,11 @@ def get_prompt(
         prefix: Optional[str] = None,
     ) -> List[int]:
         prompt = []
-
         if previous_tokens:
             prompt.append(tokenizer.sot_prev)
+            hotwords = "this video is about ComfyUI"
+            hotwords_token = tokenizer.encode(hotwords)
+            prompt.extend(hotwords_token)
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)

From ec78dd85ae211a319daca7c85cbe7c3b7606163d Mon Sep 17 00:00:00 2001
From: jax
Date: Sun, 3 Mar 2024 23:57:28 +0800
Subject: [PATCH 2/6] add hotword params

---
 faster_whisper/transcribe.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index c0970cb8..25b51d4c 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
+    hotwords: Optional[str]
 
 
 class TranscriptionInfo(NamedTuple):
@@ -220,6 +221,7 @@ def transcribe(
         chunk_length: Optional[int] = None,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
+        hotwords : Optional[str] = None
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
 
@@ -281,7 +283,8 @@ def transcribe(
           hallucination_silence_threshold: Optional[float]
             When word_timestamps is True, skip silent periods longer than this threshold
              (in seconds) when a possible hallucination is detected
-
+          hotwords: Optional[str]
+            Hotwords/hint phrases to provide the model with; has no effect if prefix is set.
         Returns:
 
           A tuple with:
@@ -399,6 +402,7 @@ def transcribe(
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
+            hotwords=hotwords
         )
 
         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -505,6 +509,7 @@ def generate_segments(
                 previous_tokens,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix if seek == 0 else None,
+                hotwords = options.hotwords
             )
 
             if seek > 0 or encoder_output is None:
@@ -897,13 +902,11 @@ def get_prompt(
         previous_tokens: List[int],
         without_timestamps: bool = False,
         prefix: Optional[str] = None,
+        hotwords:Optional[str] = None
     ) -> List[int]:
         prompt = []
         if previous_tokens:
             prompt.append(tokenizer.sot_prev)
-            hotwords = "this video is about ComfyUI"
-            hotwords_token = tokenizer.encode(hotwords)
-            prompt.extend(hotwords_token)
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)
@@ -919,6 +922,12 @@ def get_prompt(
             prompt.append(tokenizer.timestamp_begin)
             prompt.extend(prefix_tokens)
 
+        if hotwords and not prefix:
+            hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+            if len(hotwords_tokens) >= self.max_length // 2:
+                hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+            prompt.extend(hotwords_tokens)
+
         return prompt
 
     def add_word_timestamps(

From aa345f5667fb785d074ded4bdd5ecc5e25c63f96 Mon Sep 17 00:00:00 2001
From: jax
Date: Mon, 4 Mar 2024 00:09:00 +0800
Subject: [PATCH 3/6] change hotwords pos

---
 faster_whisper/transcribe.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 25b51d4c..7216373b 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -905,6 +905,12 @@ def get_prompt(
         hotwords:Optional[str] = None
     ) -> List[int]:
         prompt = []
+
+        if hotwords and not prefix:
+            hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+            if len(hotwords_tokens) >= self.max_length // 2:
+                hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+            prompt.extend(hotwords_tokens)
         if previous_tokens:
             prompt.append(tokenizer.sot_prev)
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)
@@ -922,12 +928,6 @@ def get_prompt(
             prompt.append(tokenizer.timestamp_begin)
             prompt.extend(prefix_tokens)
 
-        if hotwords and not prefix:
-            hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
-            if len(hotwords_tokens) >= self.max_length // 2:
-                hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
-            prompt.extend(hotwords_tokens)
-
         return prompt
 
     def add_word_timestamps(

From b0a29ce36e1cf66925f0a5ecb4649e26a9ca5904 Mon Sep 17 00:00:00 2001
From: jax
Date: Mon, 4 Mar 2024 00:16:28 +0800
Subject: [PATCH 4/6] change hotwords pos

---
 faster_whisper/transcribe.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 7216373b..db9b3bbe 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -906,13 +906,13 @@ def get_prompt(
     ) -> List[int]:
         prompt = []
 
-        if hotwords and not prefix:
-            hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
-            if len(hotwords_tokens) >= self.max_length // 2:
-                hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
-            prompt.extend(hotwords_tokens)
         if previous_tokens:
             prompt.append(tokenizer.sot_prev)
+            if hotwords and not prefix:
+                hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.max_length // 2:
+                    hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+                prompt.extend(hotwords_tokens)
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)

From 2858892baeb3c322a256fbaa9acf357ea2010a86 Mon Sep 17 00:00:00 2001
From: jax
Date: Mon, 4 Mar 2024 00:38:57 +0800
Subject: [PATCH 5/6] format

---
 faster_whisper/transcribe.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index db9b3bbe..c681c2a9 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -221,7 +221,7 @@ def transcribe(
         chunk_length: Optional[int] = None,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
-        hotwords : Optional[str] = None
+        hotwords: Optional[str] = None,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
 
@@ -402,7 +402,7 @@ def transcribe(
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
-            hotwords=hotwords
+            hotwords=hotwords,
         )
 
         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -503,13 +503,13 @@ def generate_segments(
                     "Processing segment at %s", format_timestamp(time_offset)
                 )
 
-            previous_tokens = all_tokens[prompt_reset_since:]
+            previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
                 tokenizer,
                 previous_tokens,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix if seek == 0 else None,
-                hotwords = options.hotwords
+                hotwords=options.hotwords,
             )
 
             if seek > 0 or encoder_output is None:
@@ -902,7 +902,7 @@ def get_prompt(
         previous_tokens: List[int],
         without_timestamps: bool = False,
         prefix: Optional[str] = None,
-        hotwords:Optional[str] = None
+        hotwords: Optional[str] = None,
     ) -> List[int]:
         prompt = []
 

From 78869dff53b7c51a18dd0641cce4789f0713a790 Mon Sep 17 00:00:00 2001
From: jax
Date: Fri, 8 Mar 2024 11:03:20 +0800
Subject: [PATCH 6/6] fix hotwords not applied to the first window

---
 faster_whisper/transcribe.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index c681c2a9..92d7a4c3 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -906,14 +906,15 @@ def get_prompt(
     ) -> List[int]:
         prompt = []
 
-        if previous_tokens:
+        if previous_tokens or (hotwords and not prefix):
             prompt.append(tokenizer.sot_prev)
             if hotwords and not prefix:
                 hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
                 if len(hotwords_tokens) >= self.max_length // 2:
                     hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
                 prompt.extend(hotwords_tokens)
-            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
+            if previous_tokens:
+                prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)
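
With the series applied, `hotwords` becomes a regular keyword argument of `WhisperModel.transcribe`, and the final patch makes the hint take effect even on the first window, where there are no previous tokens yet. Below is a minimal usage sketch, not part of the patches themselves: the model size, device settings, audio file name, and hotword strings are all placeholders.

```python
from faster_whisper import WhisperModel

# Any model size works; "large-v3" and the CPU/int8 settings are placeholders.
model = WhisperModel("large-v3", device="cpu", compute_type="int8")

# Domain terms passed as hotwords are encoded and placed in the prompt after
# the <|startofprev|> token, so they bias decoding without forcing the output
# the way `prefix` would. They are ignored when `prefix` is set.
segments, info = model.transcribe(
    "comfyui_tutorial.wav",  # placeholder audio file
    hotwords="ComfyUI ControlNet LoRA",
)

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```

Because the hotwords share the prompt budget with the rolling context, `get_prompt` truncates them to at most `self.max_length // 2 - 1` tokens, mirroring the existing handling of `prefix`.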