diff --git a/jiwer/alignment.py b/jiwer/alignment.py index 3944b7c..98c8811 100644 --- a/jiwer/alignment.py +++ b/jiwer/alignment.py @@ -26,7 +26,7 @@ from jiwer.process import CharacterOutput, WordOutput, AlignmentChunk -__all__ = ["visualize_alignment"] +__all__ = ["visualize_alignment", "get_alignment_words"] def visualize_alignment( @@ -108,9 +108,7 @@ def visualize_alignment( continue final_str += f"sentence {idx+1}\n" - final_str += _construct_comparison_string( - gt, hp, chunks, include_space_seperator=not is_cer - ) + final_str += _construct_comparison_string(gt, hp, chunks, include_space_separator=not is_cer) final_str += "\n" if show_measures: @@ -138,48 +136,73 @@ def _construct_comparison_string( reference: List[str], hypothesis: List[str], ops: List[AlignmentChunk], - include_space_seperator: bool = False, + include_space_separator: bool = False, ) -> str: - ref_str = "REF: " - hyp_str = "HYP: " - op_str = " " - - for op in ops: - if op.type == "equal" or op.type == "substitute": - ref = reference[op.ref_start_idx : op.ref_end_idx] - hyp = hypothesis[op.hyp_start_idx : op.hyp_end_idx] - op_char = " " if op.type == "equal" else "s" - elif op.type == "delete": - ref = reference[op.ref_start_idx : op.ref_end_idx] - hyp = ["*" for _ in range(len(ref))] - op_char = "d" - elif op.type == "insert": - hyp = hypothesis[op.hyp_start_idx : op.hyp_end_idx] - ref = ["*" for _ in range(len(hyp))] - op_char = "i" - else: - raise ValueError(f"unparseable op name={op.type}") + reference_str = "REF: " + hypothesis_str = "HYP: " + operation_str = " " - op_chars = [op_char for _ in range(len(ref))] - for rf, hp, c in zip(ref, hyp, op_chars): - str_len = max(len(rf), len(hp), len(c)) + reference_words, hypothesis_words, operation_chars = get_alignment_words(reference, hypothesis, ops) - if rf == "*": - rf = "".join(["*"] * str_len) - elif hp == "*": - hp = "".join(["*"] * str_len) + for reference_word, hypothesis_word, operation_char in zip(reference_words, hypothesis_words, operation_chars): + word_len = max(len(reference_word), len(hypothesis_word), len(operation_char)) - ref_str += f"{rf:>{str_len}}" - hyp_str += f"{hp:>{str_len}}" - op_str += f"{c.upper():>{str_len}}" + reference_str += f"{reference_word:>{word_len}}" + hypothesis_str += f"{hypothesis_word:>{word_len}}" + operation_str += f"{operation_char:>{word_len}}" - if include_space_seperator: - ref_str += " " - hyp_str += " " - op_str += " " + if include_space_separator: + reference_str += " " + hypothesis_str += " " + operation_str += " " - if include_space_seperator: + if include_space_separator: # remove last space - return f"{ref_str[:-1]}\n{hyp_str[:-1]}\n{op_str[:-1]}\n" + return f"{reference_str[:-1]}\n{hypothesis_str[:-1]}\n{operation_str[:-1]}\n" else: - return f"{ref_str}\n{hyp_str}\n{op_str}\n" + return f"{reference_str}\n{hypothesis_str}\n{operation_str}\n" + + +def get_alignment_words( + reference: List[str], hypothesis: List[str], operations: List[AlignmentChunk] +) -> Tuple[List[str], List[str], List[str]]: + """Generate aligned words and operation characters based on reference, hypothesis, and alignment operations. + + Args: + reference (List[str]): The list of reference words. + hypothesis (List[str]): The list of hypothesis words. + operations (List[AlignmentChunk]): The list of alignment operations. + + Returns: + Tuple[List[str], List[str], List[str]]: A tuple containing three lists: + - reference_words: The aligned reference words. + - hypothesis_words: The aligned hypothesis words. + - operation_chars: The operation characters for each alignment (' ' for equal, 'S' for substitute, + 'D' for delete, 'I' for insert). + + Raises: + ValueError: If an unparsable operation type is encountered. + """ + reference_words, hypothesis_words, operation_chars = [], [], [] + + for operation in operations: + if operation.type == "equal" or operation.type == "substitute": + ref_chunk_words = reference[operation.ref_start_idx : operation.ref_end_idx] + hyp_chunk_words = hypothesis[operation.hyp_start_idx : operation.hyp_end_idx] + operation_char = " " if operation.type == "equal" else "S" + elif operation.type == "delete": + ref_chunk_words = reference[operation.ref_start_idx : operation.ref_end_idx] + hyp_chunk_words = ["*" * len(word) for word in ref_chunk_words] + operation_char = "D" + elif operation.type == "insert": + hyp_chunk_words = hypothesis[operation.hyp_start_idx : operation.hyp_end_idx] + ref_chunk_words = ["*" * len(word) for word in hyp_chunk_words] + operation_char = "I" + else: + raise ValueError(f"Unparsable operation: {operation.type}") + + operation_chars.extend([operation_char] * len(ref_chunk_words)) + reference_words.extend(ref_chunk_words) + hypothesis_words.extend(hyp_chunk_words) + + return reference_words, hypothesis_words, operation_chars diff --git a/tests/test_alignment.py b/tests/test_alignment.py index e1f52cd..356eee6 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -123,6 +123,25 @@ def test_skip_correct(self): ) self.assertEqual(alignment, correct_alignment) + def test_get_alignment_words(self): + reference = ["this is a test of alignment words", "this is a test of the alignment words"] + hypothesis = ["this is a test of alignment words", "this is also a test of alignment wordz!"] + output = jiwer.process_words(reference, hypothesis) + + reference_words, hypothesis_words, operation_chars = jiwer.get_alignment_words( + output.references[0], output.hypotheses[0], output.alignments[0] + ) + self.assertEqual(reference_words, ["this", "is", "a", "test", "of", "alignment", "words"]) + self.assertEqual(hypothesis_words, ["this", "is", "a", "test", "of", "alignment", "words"]) + self.assertEqual(operation_chars, [" "] * len(reference_words)) + + reference_words, hypothesis_words, operation_chars = jiwer.get_alignment_words( + output.references[1], output.hypotheses[1], output.alignments[1] + ) + self.assertEqual(reference_words, ["this", "is", "****", "a", "test", "of", "the", "alignment", "words"]) + self.assertEqual(hypothesis_words, ["this", "is", "also", "a", "test", "of", "***", "alignment", "wordz!"]) + self.assertEqual(operation_chars, [" ", " ", "I", " ", " ", " ", "D", " ", "S"]) + class TestAlignmentVisualizationCharacters(unittest.TestCase): def test_insertion(self): @@ -222,4 +241,4 @@ def test_multiple_sentences(self): jiwer.process_characters(["one", "two"], ["1", "2"]), show_measures=False, ) - self.assertEqual(alignment, correct_alignment) + self.assertEqual(alignment, correct_alignment) \ No newline at end of file