Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Aligned Words List Output #93

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions jiwer/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from jiwer.process import CharacterOutput, WordOutput, AlignmentChunk

__all__ = ["visualize_alignment"]
__all__ = ["visualize_alignment", "get_alignment_words"]


def visualize_alignment(
Expand Down Expand Up @@ -108,9 +108,7 @@ def visualize_alignment(
continue

final_str += f"sentence {idx+1}\n"
final_str += _construct_comparison_string(
gt, hp, chunks, include_space_seperator=not is_cer
)
final_str += _construct_comparison_string(gt, hp, chunks, include_space_separator=not is_cer)
final_str += "\n"

if show_measures:
Expand Down Expand Up @@ -138,48 +136,73 @@ def _construct_comparison_string(
reference: List[str],
hypothesis: List[str],
ops: List[AlignmentChunk],
include_space_seperator: bool = False,
include_space_separator: bool = False,
) -> str:
ref_str = "REF: "
hyp_str = "HYP: "
op_str = " "

for op in ops:
if op.type == "equal" or op.type == "substitute":
ref = reference[op.ref_start_idx : op.ref_end_idx]
hyp = hypothesis[op.hyp_start_idx : op.hyp_end_idx]
op_char = " " if op.type == "equal" else "s"
elif op.type == "delete":
ref = reference[op.ref_start_idx : op.ref_end_idx]
hyp = ["*" for _ in range(len(ref))]
op_char = "d"
elif op.type == "insert":
hyp = hypothesis[op.hyp_start_idx : op.hyp_end_idx]
ref = ["*" for _ in range(len(hyp))]
op_char = "i"
else:
raise ValueError(f"unparseable op name={op.type}")
reference_str = "REF: "
hypothesis_str = "HYP: "
operation_str = " "

op_chars = [op_char for _ in range(len(ref))]
for rf, hp, c in zip(ref, hyp, op_chars):
str_len = max(len(rf), len(hp), len(c))
reference_words, hypothesis_words, operation_chars = get_alignment_words(reference, hypothesis, ops)

if rf == "*":
rf = "".join(["*"] * str_len)
elif hp == "*":
hp = "".join(["*"] * str_len)
for reference_word, hypothesis_word, operation_char in zip(reference_words, hypothesis_words, operation_chars):
word_len = max(len(reference_word), len(hypothesis_word), len(operation_char))

ref_str += f"{rf:>{str_len}}"
hyp_str += f"{hp:>{str_len}}"
op_str += f"{c.upper():>{str_len}}"
reference_str += f"{reference_word:>{word_len}}"
hypothesis_str += f"{hypothesis_word:>{word_len}}"
operation_str += f"{operation_char:>{word_len}}"

if include_space_seperator:
ref_str += " "
hyp_str += " "
op_str += " "
if include_space_separator:
reference_str += " "
hypothesis_str += " "
operation_str += " "

if include_space_seperator:
if include_space_separator:
# remove last space
return f"{ref_str[:-1]}\n{hyp_str[:-1]}\n{op_str[:-1]}\n"
return f"{reference_str[:-1]}\n{hypothesis_str[:-1]}\n{operation_str[:-1]}\n"
else:
return f"{ref_str}\n{hyp_str}\n{op_str}\n"
return f"{reference_str}\n{hypothesis_str}\n{operation_str}\n"


def get_alignment_words(
    reference: List[str], hypothesis: List[str], operations: List[AlignmentChunk]
) -> Tuple[List[str], List[str], List[str]]:
    """Expand alignment chunks into three parallel, word-level lists.

    Every chunk in ``operations`` is resolved against ``reference`` and
    ``hypothesis`` using its start/end indices. Where one side has no word
    (a deletion or an insertion), a placeholder made of asterisks is emitted
    whose length equals that of the word on the other side.

    Args:
        reference (List[str]): The list of reference words.
        hypothesis (List[str]): The list of hypothesis words.
        operations (List[AlignmentChunk]): The alignment chunks to expand.

    Returns:
        Tuple[List[str], List[str], List[str]]: Three lists, in this order:
            - the aligned reference words,
            - the aligned hypothesis words,
            - one operation character per aligned reference word
              (' ' for equal, 'S' for substitute, 'D' for delete,
              'I' for insert).

    Raises:
        ValueError: If a chunk has a type other than "equal", "substitute",
            "delete" or "insert".
    """
    aligned_ref: List[str] = []
    aligned_hyp: List[str] = []
    markers: List[str] = []

    for chunk in operations:
        kind = chunk.type

        if kind in ("equal", "substitute"):
            ref_part = reference[chunk.ref_start_idx : chunk.ref_end_idx]
            hyp_part = hypothesis[chunk.hyp_start_idx : chunk.hyp_end_idx]
            marker = "S" if kind == "substitute" else " "
        elif kind == "delete":
            ref_part = reference[chunk.ref_start_idx : chunk.ref_end_idx]
            # Nothing on the hypothesis side: mirror each reference word with asterisks.
            hyp_part = ["*" * len(token) for token in ref_part]
            marker = "D"
        elif kind == "insert":
            hyp_part = hypothesis[chunk.hyp_start_idx : chunk.hyp_end_idx]
            # Nothing on the reference side: mirror each hypothesis word with asterisks.
            ref_part = ["*" * len(token) for token in hyp_part]
            marker = "I"
        else:
            raise ValueError(f"Unparsable operation: {chunk.type}")

        aligned_ref += ref_part
        aligned_hyp += hyp_part
        # One marker per reference-side entry of this chunk.
        markers += [marker] * len(ref_part)

    return aligned_ref, aligned_hyp, markers
21 changes: 20 additions & 1 deletion tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ def test_skip_correct(self):
)
self.assertEqual(alignment, correct_alignment)

def test_get_alignment_words(self):
    """Check the aligned word lists produced for an exact match and for a mixed-error sentence."""
    # Sentence 0 is identical on both sides; sentence 1 differs by an extra
    # "also", a missing "the", and "wordz!" in place of "words".
    output = jiwer.process_words(
        ["this is a test of alignment words", "this is a test of the alignment words"],
        ["this is a test of alignment words", "this is also a test of alignment wordz!"],
    )

    matching_words = ["this", "is", "a", "test", "of", "alignment", "words"]
    expectations = [
        (
            matching_words,
            matching_words,
            [" "] * len(matching_words),
        ),
        (
            ["this", "is", "****", "a", "test", "of", "the", "alignment", "words"],
            ["this", "is", "also", "a", "test", "of", "***", "alignment", "wordz!"],
            [" ", " ", "I", " ", " ", " ", "D", " ", "S"],
        ),
    ]

    for idx, (want_ref, want_hyp, want_ops) in enumerate(expectations):
        got_ref, got_hyp, got_ops = jiwer.get_alignment_words(
            output.references[idx], output.hypotheses[idx], output.alignments[idx]
        )
        self.assertEqual(got_ref, want_ref)
        self.assertEqual(got_hyp, want_hyp)
        self.assertEqual(got_ops, want_ops)


class TestAlignmentVisualizationCharacters(unittest.TestCase):
def test_insertion(self):
Expand Down Expand Up @@ -222,4 +241,4 @@ def test_multiple_sentences(self):
jiwer.process_characters(["one", "two"], ["1", "2"]),
show_measures=False,
)
self.assertEqual(alignment, correct_alignment)
self.assertEqual(alignment, correct_alignment)