diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 337cc428..73e37566 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -5,7 +5,17 @@ import zlib from inspect import signature -from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import ctranslate2 import numpy as np @@ -30,7 +40,7 @@ class Word(NamedTuple): probability: float -class Segment(NamedTuple): +class NamedTupleSegment(NamedTuple): id: int seek: int start: float @@ -44,6 +54,26 @@ class Segment(NamedTuple): words: Optional[List[Word]] +class Segment(NamedTupleSegment): + def _asdict(self) -> Dict[str, Any]: + words: Optional[List[Dict[str, Any]]] = None + if self.words: + words = [word._asdict() for word in self.words] + return { + "id": self.id, + "seek": self.seek, + "start": self.start, + "end": self.end, + "text": self.text, + "tokens": self.tokens, + "temperature": self.temperature, + "avg_logprob": self.avg_logprob, + "compression_ratio": self.compression_ratio, + "no_speech_prob": self.no_speech_prob, + "words": words, + } + + class TranscriptionOptions(NamedTuple): beam_size: int best_of: int diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index d30a0fb6..936e930e 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -1,6 +1,7 @@ import os from faster_whisper import WhisperModel, decode_audio +from faster_whisper.transcribe import Segment, Word def test_supported_languages(): @@ -97,3 +98,48 @@ def test_stereo_diarization(data_dir): segments, _ = model.transcribe(right) transcription = "".join(segment.text for segment in segments).strip() assert transcription == "The horizon seems extremely distant." + + +def test_segment_as_dict(): + segment = Segment( + id=1, + seek=398, + start=0.0, + end=3.04, + text=" Hello world.", + tokens=[50364, 2425, 1002, 13, 50516], + temperature=0.9, + avg_logprob=-1.1086603999137878, + compression_ratio=0.6, + no_speech_prob=0.10812531411647797, + words=[ + Word(start=0.0, end=1.82, word=" Hello", probability=0.8265082836151123), + Word(start=1.82, end=3.04, word=" world.", probability=0.31053248047828674), + ], + ) + assert segment._asdict() == { + "id": 1, + "seek": 398, + "start": 0.0, + "end": 3.04, + "text": " Hello world.", + "tokens": [50364, 2425, 1002, 13, 50516], + "temperature": 0.9, + "avg_logprob": -1.1086603999137878, + "compression_ratio": 0.6, + "no_speech_prob": 0.10812531411647797, + "words": [ + { + "start": 0.0, + "end": 1.82, + "word": " Hello", + "probability": 0.8265082836151123, + }, + { + "start": 1.82, + "end": 3.04, + "word": " world.", + "probability": 0.31053248047828674, + }, + ], + }