Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the ORC assignment again #92

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion meeteval/wer/wer/orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _orc_error_rate(
total_num_segments = len(set(reference.T['segment_index']))
reference = reference.filter(lambda x: x['words'] != '')
hypothesis = hypothesis.filter(lambda x: x['words'] != '')
num_filtered_segments = len(reference.unique('segment_index'))

# Group by stream. For ORC-WER, only hypothesis must be grouped
hypothesis = hypothesis.groupby('speaker')
Expand All @@ -106,6 +107,7 @@ def _orc_error_rate(

# Compute the ORC distance
distance, assignment = matching_fn(reference, hypothesis)
assert len(assignment) == num_filtered_segments, (len(assignment), num_filtered_segments)

# Translate the assignment from hypothesis index to stream id
# Fill with a dummy stream if hypothesis is empty
Expand All @@ -132,7 +134,7 @@ def _orc_error_rate(
assert er.errors == distance, (distance, er, assignment)

# Get the assignment in the order of segment_index
assignment = [r['speaker'] for r in reference_new.sorted('segment_index')]
assignment = [a for _, a in sorted(zip(reference.groupby('segment_index').keys(), assignment))]

# Insert labels for empty segments that got removed
if len(assignment) != total_num_segments:
Expand Down
1 change: 1 addition & 0 deletions meeteval/wer/wer/time_constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def time_constrained_siso_word_error_rate(
hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
collar=collar,
segment_representation='word',
remove_empty_segments=True,
)

er = _time_constrained_siso_error_rate(reference, hypothesis)
Expand Down
24 changes: 23 additions & 1 deletion tests/test_time_constrained_orc_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import pytest
from hypothesis import assume, settings, given, strategies as st, reproduce_failure

import meeteval.io
from meeteval.io import SegLST
from meeteval.wer import combine_error_rates


# Limit alphabet to ensure a few correct matches
Expand Down Expand Up @@ -49,13 +51,33 @@ def seglst(draw, min_segments=0, max_segments=10, max_speakers=2):
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
def test_tcorc_burn(reference, hypothesis):
from meeteval.wer.wer.time_constrained_orc import time_constrained_orc_wer
from meeteval.wer.wer.time_constrained import time_constrained_siso_word_error_rate

tcorc = time_constrained_orc_wer(reference, hypothesis, collar=1000, reference_sort=False, hypothesis_sort=False)
tcorc = time_constrained_orc_wer(reference, hypothesis, collar=5, reference_sort=False, hypothesis_sort=False)

assert len(tcorc.assignment) == len(reference)
assert isinstance(tcorc.errors, int)
assert tcorc.errors >= 0
assigned_reference, assigned_hypothesis = tcorc.apply_assignment(reference, hypothesis)
assigned_reference = assigned_reference.groupby('speaker')
assigned_hypothesis = assigned_hypothesis.groupby('speaker')
er = combine_error_rates(
*[
time_constrained_siso_word_error_rate(
assigned_reference.get(k, meeteval.io.SegLST([])),
assigned_hypothesis.get(k, meeteval.io.SegLST([])),
reference_sort=False,
hypothesis_sort=False,
collar=5,
)
for k in set(assigned_reference.keys()) | set(assigned_hypothesis.keys())
]
)
assert er.errors == tcorc.errors
assert er.length == tcorc.length
assert er.insertions == tcorc.insertions
assert er.deletions == tcorc.deletions
assert er.substitutions == tcorc.substitutions


@given(
Expand Down
Loading