diff --git a/meeteval/wer/wer/orc.py b/meeteval/wer/wer/orc.py index 9b50662..779073b 100644 --- a/meeteval/wer/wer/orc.py +++ b/meeteval/wer/wer/orc.py @@ -90,6 +90,7 @@ def _orc_error_rate( total_num_segments = len(set(reference.T['segment_index'])) reference = reference.filter(lambda x: x['words'] != '') hypothesis = hypothesis.filter(lambda x: x['words'] != '') + num_filtered_segments = len(reference.unique('segment_index')) # Group by stream. For ORC-WER, only hypothesis must be grouped hypothesis = hypothesis.groupby('speaker') @@ -106,6 +107,7 @@ def _orc_error_rate( # Compute the ORC distance distance, assignment = matching_fn(reference, hypothesis) + assert len(assignment) == num_filtered_segments, (len(assignment), num_filtered_segments) # Translate the assignment from hypothesis index to stream id # Fill with a dummy stream if hypothesis is empty @@ -132,7 +134,7 @@ def _orc_error_rate( assert er.errors == distance, (distance, er, assignment) # Get the assignment in the order of segment_index - assignment = [r['speaker'] for r in reference_new.sorted('segment_index')] + assignment = [a for _, a in sorted(zip(reference.groupby('segment_index').keys(), assignment))] # Insert labels for empty segments that got removed if len(assignment) != total_num_segments: diff --git a/meeteval/wer/wer/time_constrained.py b/meeteval/wer/wer/time_constrained.py index 84953f4..b3ca128 100644 --- a/meeteval/wer/wer/time_constrained.py +++ b/meeteval/wer/wer/time_constrained.py @@ -632,6 +632,7 @@ def time_constrained_siso_word_error_rate( hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing, collar=collar, segment_representation='word', + remove_empty_segments=True, ) er = _time_constrained_siso_error_rate(reference, hypothesis) diff --git a/tests/test_time_constrained_orc_matching.py b/tests/test_time_constrained_orc_matching.py index 5c49a47..15a8d73 100644 --- a/tests/test_time_constrained_orc_matching.py +++ b/tests/test_time_constrained_orc_matching.py @@ -3,7 +3,9 @@ import pytest from hypothesis import assume, settings, given, strategies as st, reproduce_failure +import meeteval.io from meeteval.io import SegLST +from meeteval.wer import combine_error_rates # Limit alphabet to ensure a few correct matches @@ -49,13 +51,33 @@ def seglst(draw, min_segments=0, max_segments=10, max_speakers=2): @settings(deadline=None) # The tests take longer on the GitHub actions test servers def test_tcorc_burn(reference, hypothesis): from meeteval.wer.wer.time_constrained_orc import time_constrained_orc_wer + from meeteval.wer.wer.time_constrained import time_constrained_siso_word_error_rate - tcorc = time_constrained_orc_wer(reference, hypothesis, collar=1000, reference_sort=False, hypothesis_sort=False) + tcorc = time_constrained_orc_wer(reference, hypothesis, collar=5, reference_sort=False, hypothesis_sort=False) assert len(tcorc.assignment) == len(reference) assert isinstance(tcorc.errors, int) assert tcorc.errors >= 0 assigned_reference, assigned_hypothesis = tcorc.apply_assignment(reference, hypothesis) + assigned_reference = assigned_reference.groupby('speaker') + assigned_hypothesis = assigned_hypothesis.groupby('speaker') + er = combine_error_rates( + *[ + time_constrained_siso_word_error_rate( + assigned_reference.get(k, meeteval.io.SegLST([])), + assigned_hypothesis.get(k, meeteval.io.SegLST([])), + reference_sort=False, + hypothesis_sort=False, + collar=5, + ) + for k in set(assigned_reference.keys()) | set(assigned_hypothesis.keys()) + ] + ) + assert er.errors == tcorc.errors + assert er.length == tcorc.length + assert er.insertions == tcorc.insertions + assert er.deletions == tcorc.deletions + assert er.substitutions == tcorc.substitutions @given(