Skip to content

Commit

Permalink
Feature/add hotwords (#731)
Browse files Browse the repository at this point in the history
* add hotword params

---------

Co-authored-by: jax <[email protected]>
  • Loading branch information
jax-explorer and jax authored May 4, 2024
1 parent 46080e5 commit 847fec4
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]
hotwords: Optional[str]


class TranscriptionInfo(NamedTuple):
Expand Down Expand Up @@ -220,6 +221,7 @@ def transcribe(
chunk_length: Optional[int] = None,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = None,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
Expand Down Expand Up @@ -284,10 +286,11 @@ def transcribe(
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
hotwords:Optional text
add hotwords if set prefix it invalid
language_detection_threshold: If the maximum probability of the language tokens is higher
than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
Expand Down Expand Up @@ -441,6 +444,7 @@ def transcribe(
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
hotwords=hotwords,
)

segments = self.generate_segments(features, tokenizer, options, encoder_output)
Expand Down Expand Up @@ -547,6 +551,7 @@ def generate_segments(
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix if seek == 0 else None,
hotwords=options.hotwords,
)

if seek > 0 or encoder_output is None:
Expand Down Expand Up @@ -939,12 +944,19 @@ def get_prompt(
previous_tokens: List[int],
without_timestamps: bool = False,
prefix: Optional[str] = None,
hotwords: Optional[str] = None,
) -> List[int]:
prompt = []

if previous_tokens:
if previous_tokens or (hotwords and not prefix):
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
if hotwords and not prefix:
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
if len(hotwords_tokens) >= self.max_length // 2:
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
prompt.extend(hotwords_tokens)
if previous_tokens:
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

prompt.extend(tokenizer.sot_sequence)

Expand Down

0 comments on commit 847fec4

Please sign in to comment.