feat: optimize alignment processing and remove vocabulary size limitations #100

Merged: 6 commits, Jan 31, 2025

99 changes: 57 additions & 42 deletions jiwer/process.py
@@ -30,6 +30,8 @@

from rapidfuzz.distance import Opcodes

from collections import defaultdict

from jiwer import transforms as tr
from jiwer.transformations import wer_default, cer_default

@@ -149,6 +151,10 @@ def process_words(

Returns:
(WordOutput): The processed reference and hypothesis sentences

Raises:
ValueError: If one or more references are empty strings
ValueError: If after applying transforms, reference and hypothesis lengths don't match
"""
# validate input type
if isinstance(reference, str):
@@ -174,9 +180,9 @@
f"{len(hyp_transformed)} hypothesis sentences."
)

# Change each word into a unique character in order to compute
# Map each word into a unique integer in order to compute
# word-level levenshtein distance
ref_as_chars, hyp_as_chars = _word2char(ref_transformed, hyp_transformed)
ref_as_ints, hyp_as_ints = _word2int(ref_transformed, hyp_transformed)

# keep track of total hits, substitutions, deletions and insertions
# across all input sentences
@@ -188,37 +194,45 @@
# and finally, keep track of the alignment between each reference and hypothesis
alignments = []

for reference_sentence, hypothesis_sentence in zip(ref_as_chars, hyp_as_chars):
# Get the required edit operations to transform reference into hypothesis
edit_ops = rapidfuzz.distance.Levenshtein.editops(
for reference_sentence, hypothesis_sentence in zip(ref_as_ints, hyp_as_ints):
# Get the opcodes directly
opcodes = rapidfuzz.distance.Levenshtein.opcodes(
reference_sentence, hypothesis_sentence
)

# count the number of edits of each type
substitutions = sum(1 if op.tag == "replace" else 0 for op in edit_ops)
deletions = sum(1 if op.tag == "delete" else 0 for op in edit_ops)
insertions = sum(1 if op.tag == "insert" else 0 for op in edit_ops)
hits = len(reference_sentence) - (substitutions + deletions)
subs = dels = ins = hits = 0
sentence_op_chunks = []

for tag, i1, i2, j1, j2 in opcodes:
# Create alignment chunk
sentence_op_chunks.append(
AlignmentChunk(
type=tag,
ref_start_idx=i1,
ref_end_idx=i2,
hyp_start_idx=j1,
hyp_end_idx=j2,
)
)

# update state
# Update counts
if tag == "equal":
hits += i2 - i1
elif tag == "replace":
subs += i2 - i1
elif tag == "delete":
dels += i2 - i1
elif tag == "insert":
ins += j2 - j1

# Update global counts
num_hits += hits
num_substitutions += substitutions
num_deletions += deletions
num_insertions += insertions
num_substitutions += subs
num_deletions += dels
num_insertions += ins
num_rf_words += len(reference_sentence)
num_hp_words += len(hypothesis_sentence)
alignments.append(
[
AlignmentChunk(
type=op.tag,
ref_start_idx=op.src_start,
ref_end_idx=op.src_end,
hyp_start_idx=op.dest_start,
hyp_end_idx=op.dest_end,
)
for op in Opcodes.from_editops(edit_ops)
]
)
alignments.append(sentence_op_chunks)

# Compute all measures
S, D, I, H = num_substitutions, num_deletions, num_insertions, num_hits
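
For reference, the opcode-based counting can be exercised on its own with a small made-up pair of sentences. This is a standalone sketch, not part of the diff; the integer lists below simply stand in for _word2int output:

    from rapidfuzz.distance import Levenshtein

    # e.g. reference "the cat sat down" vs. hypothesis "the dog sat",
    # after each word has been mapped to a unique integer
    ref = [0, 1, 2, 3]
    hyp = [0, 4, 2]

    hits = subs = dels = ins = 0
    for tag, i1, i2, j1, j2 in Levenshtein.opcodes(ref, hyp):
        if tag == "equal":
            hits += i2 - i1
        elif tag == "replace":
            subs += i2 - i1
        elif tag == "delete":
            dels += i2 - i1
        elif tag == "insert":
            ins += j2 - j1

    print(hits, subs, dels, ins)  # 2 1 1 0 -> WER = (1 + 1 + 0) / 4 = 0.5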
@@ -385,23 +399,24 @@ def _is_list_of_list_of_strings(x: Any, require_non_empty_lists: bool):
return True


def _word2char(reference: List[List[str]], hypothesis: List[List[str]]):
# tokenize each word into an integer
vocabulary = set(chain(*reference, *hypothesis))
def _word2int(reference: List[List[str]], hypothesis: List[List[str]]):
"""
Maps each unique word in the reference and hypothesis sentences to a unique integer
for Levenshtein distance calculation.

if "" in vocabulary:
raise ValueError(
"Empty strings cannot be a word. "
"Please ensure that the given transform removes empty strings."
)
Args:
reference: List of reference sentences, where each sentence is a list of words
hypothesis: List of hypothesis sentences, where each sentence is a list of words

word2char = dict(zip(vocabulary, range(len(vocabulary))))
Returns:
Tuple[List[List[int]], List[List[int]]]: The reference and hypothesis sentences
with words mapped to unique integers
"""
word2int = defaultdict()
word2int.default_factory = word2int.__len__ # Auto-incrementing IDs

reference_chars = [
"".join([chr(word2char[w]) for w in sentence]) for sentence in reference
]
hypothesis_chars = [
"".join([chr(word2char[w]) for w in sentence]) for sentence in hypothesis
]
# Single pass through all words using list comprehensions
ref_ints = [[word2int[word] for word in sentence] for sentence in reference]
hyp_ints = [[word2int[word] for word in sentence] for sentence in hypothesis]

return reference_chars, hypothesis_chars
return ref_ints, hyp_ints
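
The auto-incrementing defaultdict above is what removes the old vocabulary cap: integer IDs can grow without bound, while the previous _word2char had to squeeze every ID through chr(). A minimal standalone sketch of the mapping (standard library only; word2int_demo is just an illustrative name, not part of the diff):

    from collections import defaultdict

    def word2int_demo(sentences):
        # default_factory returns the current dict size, so each previously
        # unseen word receives the next free integer ID on first lookup
        word2int = defaultdict()
        word2int.default_factory = word2int.__len__
        return [[word2int[word] for word in sentence] for sentence in sentences]

    print(word2int_demo([["the", "cat", "sat"], ["the", "dog", "sat"]]))
    # [[0, 1, 2], [0, 3, 2]]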
62 changes: 62 additions & 0 deletions tests/test_large_vocab.py
@@ -0,0 +1,62 @@
import pytest
from jiwer import process_words, wer


def test_basic_word_mapping():
"""Test that basic word mapping works correctly."""
# Create a reference with 100 words where 50 are the same in hypothesis
reference = ["same"] * 50 + ["ref_only"] * 50
hypothesis = ["same"] * 50 + ["hyp_only"] * 50

result = process_words(reference=reference, hypothesis=hypothesis)
assert isinstance(result.wer, float)
assert result.wer == 0.5 # 50% of words are different
assert result.hits == 50 # 50 "same" matches


def test_vocabulary_size_limit():
"""Test processing with very large vocabulary (no size limit now)."""
# Create a vocabulary large enough to fill the entire old chr() range
vocab_size = 0x110000  # 1,114,112 unique words

# Split into reference and hypothesis
reference = [f"word{i}" for i in range(vocab_size // 2)]
hypothesis = [f"word{i}" for i in range(vocab_size // 2, vocab_size)]

try:
result = process_words(reference=reference, hypothesis=hypothesis)
assert isinstance(result.wer, float)
assert result.wer == 1.0 # All words are different
except Exception as e:
pytest.fail(f"Large vocabulary processing failed: {e}")


def test_wer_large_vocabulary():
"""Test WER calculation with very large vocabulary."""
vocab_size = 0x110000  # 1,114,112 unique words, the full size of the chr() range

reference = " ".join(f"word{i}" for i in range(vocab_size // 2))
hypothesis = " ".join(f"word{i}" for i in range(vocab_size // 2, vocab_size))

try:
error_rate = wer(reference=reference, hypothesis=hypothesis)
assert isinstance(error_rate, float)
assert error_rate == 1.0 # All words are different
except Exception as e:
pytest.fail(f"WER calculation failed with large vocabulary: {e}")


def test_hash_collision_handling():
"""Test that hash collisions don't affect results."""
# Create words that might have hash collisions
reference = ["a" * i for i in range(1, 1001)] # Start from 1 to avoid empty strings
hypothesis = [
"b" * i for i in range(1, 1001)
] # Start from 1 to avoid empty strings

try:
result = process_words(reference=reference, hypothesis=hypothesis)
assert isinstance(result.wer, float)
assert result.wer > 0 # Should detect differences
except Exception as e:
pytest.fail(f"Hash collision test failed: {e}")