Skip to content

Commit

Permalink
Merge pull request #95 from fgnt/greedy_ditcp
Browse files Browse the repository at this point in the history
Add greedy DI-tcpWER
  • Loading branch information
thequilo authored Sep 20, 2024
2 parents af222cd + e8beb1e commit f2d133a
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 10 deletions.
29 changes: 29 additions & 0 deletions meeteval/wer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,34 @@ def greedy_dicpwer(
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy-DI-cpWER')


def greedy_ditcpwer(
        reference, hypothesis,
        average_out='{parent}/{stem}_greedy_ditcpwer.json',
        per_reco_out='{parent}/{stem}_greedy_ditcpwer_per_reco.json',
        regex=None,
        collar=0,
        hyp_pseudo_word_timing='character_based_points',
        ref_pseudo_word_timing='character_based',
        hypothesis_sort='segment',
        reference_sort='segment',
        uem=None,
        partial=False,
        normalizer=None,
):
    """Computes the time-constrained diarization-invariant cpWER (greedy DI-tcpWER)"""
    # NOTE(review): the docstring above is left untouched on purpose; it may
    # be displayed as the sub-command help text -- confirm in the CLI helper.

    # Everything except the two output templates is forwarded unchanged to
    # the API function of the same name.
    api_options = dict(
        regex=regex,
        collar=collar,
        hyp_pseudo_word_timing=hyp_pseudo_word_timing,
        ref_pseudo_word_timing=ref_pseudo_word_timing,
        hypothesis_sort=hypothesis_sort,
        reference_sort=reference_sort,
        uem=uem,
        partial=partial,
        normalizer=normalizer,
    )
    per_recording = meeteval.wer.api.greedy_ditcpwer(
        reference, hypothesis, **api_options
    )

    # `hypothesis` is handed over together with the output templates;
    # presumably it is used to fill the `{parent}`/`{stem}` placeholders --
    # verify against `_save_results`.
    _save_results(
        per_recording, hypothesis, per_reco_out, average_out,
        wer_name='greedy-DI-tcpWER',
    )


def _merge(
files: 'list[str]',
out: str = '-',
Expand Down Expand Up @@ -695,6 +723,7 @@ def cli():
cli.add_command(tcpwer)
cli.add_command(tcorcwer)
cli.add_command(greedy_dicpwer)
cli.add_command(greedy_ditcpwer)
cli.add_command(greedy_tcorcwer)
cli.add_command(merge)
cli.add_command(average)
Expand Down
38 changes: 38 additions & 0 deletions meeteval/wer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'tcorcwer',
'greedy_tcorcwer',
'greedy_dicpwer',
'greedy_ditcpwer',
]


Expand Down Expand Up @@ -361,6 +362,43 @@ def greedy_dicpwer(
return results


def greedy_ditcpwer(
        reference, hypothesis,
        regex=None,
        collar=0,
        hyp_pseudo_word_timing='character_based_points',
        ref_pseudo_word_timing='character_based',
        hypothesis_sort='segment',
        reference_sort='segment',
        uem=None,
        partial=False,
        normalizer=None,
):
    """Computes the time-constrained Diarization Invariant cpWER (DI-tcpWER)
    with a greedy algorithm.

    `reference` and `hypothesis` are loaded with `_load_texts` (together with
    `regex`, `uem` and `normalizer`). The remaining arguments are forwarded
    to `greedy_di_tcp_word_error_rate_multifile`, whose result is returned
    unchanged.

    As a side effect, the combined error rate over all results is computed
    and, if it carries self-overlap information for the hypothesis or the
    reference, the corresponding `warn` method is called.
    """
    # Imported locally rather than at module level.
    from meeteval.wer.wer.di_cp import greedy_di_tcp_word_error_rate_multifile
    reference, hypothesis = _load_texts(
        reference, hypothesis, regex=regex,
        uem=uem, normalizer=normalizer,
    )
    results = greedy_di_tcp_word_error_rate_multifile(
        reference, hypothesis,
        reference_pseudo_word_level_timing=ref_pseudo_word_timing,
        hypothesis_pseudo_word_level_timing=hyp_pseudo_word_timing,
        collar=collar,
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
        partial=partial,
    )
    from meeteval.wer import combine_error_rates
    # The combined error rate is only used to trigger the self-overlap
    # warnings below; it is not part of the return value.
    average: ErrorRate = combine_error_rates(results)
    if average.hypothesis_self_overlap is not None:
        average.hypothesis_self_overlap.warn('hypothesis')
    if average.reference_self_overlap is not None:
        average.reference_self_overlap.warn('reference')
    return results


def greedy_tcorcwer(
reference, hypothesis,
regex=None,
Expand Down
5 changes: 3 additions & 2 deletions meeteval/wer/matching/greedy_combination_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def initialize_assignment(
>>> initialize_assignment(segments, streams, initialization='cp')
[0, 0]
"""
if initialization == 'cp':
if initialization in ('cp', 'tcp'):
# Special case when no streams are present
if len(streams) == 0:
return initialize_assignment(segments, streams, 'constant')
Expand All @@ -192,13 +192,14 @@ def initialize_assignment(
# Compute cpWER to get a good starting point
from meeteval.wer.wer.cp import _minimum_permutation_assignment
from meeteval.wer.wer.siso import siso_levenshtein_distance
from meeteval.wer.wer.time_constrained import time_constrained_siso_levenshtein_distance
if isinstance(streams, SegLST):
streams = streams.groupby('speaker')
speaker_grouped_segments = segments.groupby('speaker')
assignment, _, cost_matrix = _minimum_permutation_assignment(
speaker_grouped_segments,
streams,
distance_fn=siso_levenshtein_distance,
distance_fn=siso_levenshtein_distance if initialization == 'cp' else time_constrained_siso_levenshtein_distance,
)
# Use integers for the assignment labels.
counter = iter(itertools.count())
Expand Down
3 changes: 2 additions & 1 deletion meeteval/wer/wer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .siso import siso_word_error_rate, siso_character_error_rate, siso_word_error_rate_multifile
from .error_rate import ErrorRate, combine_error_rates
from .time_constrained import time_constrained_minimum_permutation_word_error_rate, time_constrained_siso_word_error_rate, tcp_word_error_rate_multifile
from .di_cp import greedy_di_cp_word_error_rate, DICPErrorRate, greedy_di_cp_word_error_rate_multifile
from .di_cp import greedy_di_cp_word_error_rate, DICPErrorRate, greedy_di_cp_word_error_rate_multifile, greedy_di_tcp_word_error_rate, greedy_di_tcp_word_error_rate_multifile
from .time_constrained_orc import time_constrained_orc_wer, time_constrained_orc_wer_multifile, greedy_time_constrained_orc_wer, greedy_time_constrained_orc_wer_multifile
82 changes: 82 additions & 0 deletions meeteval/wer/wer/di_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
'DICPErrorRate',
'greedy_di_cp_word_error_rate',
'greedy_di_cp_word_error_rate_multifile',
'greedy_di_tcp_word_error_rate',
'greedy_di_tcp_word_error_rate_multifile',
'apply_dicp_assignment',
]

Expand Down Expand Up @@ -111,6 +113,86 @@ def greedy_di_cp_word_error_rate_multifile(
)


def greedy_di_tcp_word_error_rate(
        reference,
        hypothesis,
        reference_pseudo_word_level_timing='character_based',
        hypothesis_pseudo_word_level_timing='character_based_points',
        collar: int = 0,
        reference_sort='segment',
        hypothesis_sort='segment',
):
    """
    Computes the DI-tcpWER (time-constrained DI-cpWER) with a greedy algorithm

    The work is delegated to the greedy tcORC-WER with the roles of
    reference and hypothesis exchanged; the result is then translated back
    (insertions <-> deletions, self-overlap statistics exchanged) and the
    length is set to the number of whitespace-separated words in
    `reference`.

    >>> reference = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a', 'start_time': 0.0, 'end_time': 1.0},
    ...     {'segment_index': 1, 'speaker': 'A', 'words': 'b', 'start_time': 1.0, 'end_time': 2.0},
    ...     {'segment_index': 2, 'speaker': 'B', 'words': 'c', 'start_time': 2.0, 'end_time': 3.0},
    ...     {'segment_index': 3, 'speaker': 'B', 'words': 'd', 'start_time': 3.0, 'end_time': 4.0},
    ... ])
    >>> greedy_di_tcp_word_error_rate(reference, reference)
    DICPErrorRate(error_rate=0.0, errors=0, length=4, insertions=0, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=6.0), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=6.0), assignment=('A', 'A', 'B', 'B'))
    >>> hypothesis = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a b', 'start_time': 0.0, 'end_time': 2.0},
    ...     {'segment_index': 2, 'speaker': 'A', 'words': 'b c d', 'start_time': 1.0, 'end_time': 4.0},
    ... ])
    >>> greedy_di_tcp_word_error_rate(reference, hypothesis)
    DICPErrorRate(error_rate=0.25, errors=1, length=4, insertions=1, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=6.0), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.25, overlap_time=1.0, total_time=4.0), assignment=('A', 'B'))
    """

    # DI-tcpWER and tcORC-WER share the same assignment once reference and
    # hypothesis trade places, so every reference-/hypothesis-specific
    # argument is passed in exchanged position here.
    swapped = meeteval.wer.wer.time_constrained_orc.greedy_time_constrained_orc_wer(
        hypothesis, reference,
        hypothesis_pseudo_word_level_timing, reference_pseudo_word_level_timing,
        collar,
        hypothesis_sort, reference_sort,
    )

    # The length must be counted on the actual reference; the tcORC result
    # was computed with the hypothesis in the reference role.
    reference_length = sum(
        len(segment['words'].split()) for segment in reference
    )

    # Undo the exchange: what the tcORC-WER saw as deletions are insertions
    # from the DI-tcpWER point of view (and vice versa); the self-overlap
    # statistics are exchanged back as well.
    return DICPErrorRate(
        swapped.errors, reference_length,
        insertions=swapped.deletions,
        deletions=swapped.insertions,
        substitutions=swapped.substitutions,
        assignment=swapped.assignment,
        reference_self_overlap=swapped.hypothesis_self_overlap,
        hypothesis_self_overlap=swapped.reference_self_overlap,
    )


def greedy_di_tcp_word_error_rate_multifile(
        reference,
        hypothesis,
        partial=False,
        reference_pseudo_word_level_timing='character_based',
        hypothesis_pseudo_word_level_timing='character_based_points',
        collar: int = 0,
        reference_sort='segment',
        hypothesis_sort='segment',
) -> 'dict[str, DICPErrorRate]':
    """
    Computes the (Greedy) DI-tcpWER for each example in the reference and
    hypothesis files.

    `greedy_di_tcp_word_error_rate` is bound to the given timing, collar and
    sort options and handed to `apply_multi_file` together with `reference`,
    `hypothesis` and `partial`.

    To compute the overall WER, use
    `sum(greedy_di_tcp_word_error_rate_multifile(r, h).values())`.
    """
    from meeteval.io.seglst import apply_multi_file

    # Freeze all per-example options so that only (reference, hypothesis)
    # remain to be supplied. `functools.partial` is reached through the
    # module because the name `partial` is taken by the boolean argument.
    per_example_wer = functools.partial(
        greedy_di_tcp_word_error_rate,
        reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
        hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
        collar=collar,
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
    )
    return apply_multi_file(
        per_example_wer, reference, hypothesis, partial=partial,
    )


def apply_dicp_assignment(
assignment: 'list[int | str] | tuple[int | str]',
reference: 'list[list[str]] | dict[str, list[str]] | SegLST',
Expand Down
7 changes: 7 additions & 0 deletions meeteval/wer/wer/mimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ class MimoErrorRate(ErrorRate):
"""
assignment: 'tuple[int, ...]'

def apply_assignment(self, reference, hypothesis):
    """Apply the assignment stored on this error rate object.

    Forwards `self.assignment` together with `reference` and `hypothesis`
    to `apply_mimo_assignment` and returns its result unchanged.
    """
    stored_assignment = self.assignment
    return apply_mimo_assignment(
        stored_assignment, reference=reference, hypothesis=hypothesis,
    )


def mimo_error_rate(
reference: 'list[list[Iterable]] | dict[Any, list[Iterable]]',
Expand Down
18 changes: 16 additions & 2 deletions meeteval/wer/wer/time_constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,10 @@ def apply_collar(s: SegLST, collar: float):
return s.map(
lambda s: {
**s,
'start_time': [t - collar for t in s['start_time']] if isinstance(s['start_time'], list) else s['start_time'] - collar,
'end_time': [t + collar for t in s['end_time']] if isinstance(s['end_time'], list) else s['end_time'] + collar,
'start_time': [t - collar for t in s['start_time']] if isinstance(s['start_time'], list) else s[
'start_time'] - collar,
'end_time': [t + collar for t in s['end_time']] if isinstance(s['end_time'], list) else s[
'end_time'] + collar,
}
)

Expand Down Expand Up @@ -549,6 +551,18 @@ def time_constrained_siso_levenshtein_distance(
) -> int:
from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance

# Flatten words when segment_representation is 'segment'
def flatten_words(s):
    """Split a segment with word lists into one entry per word.

    If `s['words']` is a list, one dict with the keys `words`, `start_time`
    and `end_time` is produced per position of the zipped `words`,
    `start_time` and `end_time` lists (other keys of `s` are not copied).
    Otherwise the segment itself is returned as the only list element.
    """
    words = s['words']
    if isinstance(words, list):
        return [
            {'words': word, 'start_time': begin, 'end_time': end}
            for word, begin, end in zip(words, s['start_time'], s['end_time'])
        ]
    # Not a word list: hand the segment back untouched.
    return [s]

reference = reference.flatmap(flatten_words)
hypothesis = hypothesis.flatmap(flatten_words)

# Ignore empty segments
reference = reference.filter(lambda s: s['words'])
hypothesis = hypothesis.filter(lambda s: s['words'])
Expand Down
2 changes: 1 addition & 1 deletion meeteval/wer/wer/time_constrained_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def matching(reference, hypothesis):
[list(zip(*r)) for r in reference.T['words', 'start_time', 'end_time']],
[[w for words in stream.T['words', 'start_time', 'end_time'] for w in zip(*words)] for stream in
hypothesis.values()],
initial_assignment=initialize_assignment(reference, hypothesis, initialization='cp'),
initial_assignment=initialize_assignment(reference, hypothesis, initialization='tcp'),
)
return distance, assignment

Expand Down
6 changes: 6 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def test_burn_greedy_dicp():
run(f'meeteval-wer greedy_dicpwer -h hyp.stm -r ref.stm')


def test_burn_greedy_ditcp():
    """Smoke test: run the `greedy_ditcpwer` sub-command via both entry points."""
    # Normal test with stm files
    for entry_point in ('python -m meeteval.wer', 'meeteval-wer'):
        run(f'{entry_point} greedy_ditcpwer -h hyp.stm -r ref.stm')


def test_burn_mimo():
run(f'python -m meeteval.wer mimower -h hyp.stm -r ref.stm')
run(f"python -m meeteval.wer mimower -h 'hyp?.stm' -r 'ref?.stm'")
Expand Down
42 changes: 38 additions & 4 deletions tests/test_di_cp.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from hypothesis import given, strategies as st, assume, settings
import meeteval


# Hypothesis strategy that generates a `SegLST` from a list of random
# segments, all belonging to the session 'session1'.
seglst = st.builds(
    meeteval.io.SegLST,
    st.lists(
        st.builds(
            # Only one build target may be given. Previously `dict` was
            # listed in addition to this lambda, which would have called
            # `dict(<lambda>, speaker=..., ...)` and raised a TypeError.
            #
            # `end_time` is derived from the drawn `start_time` and
            # `duration` (which is >= 0), so it never precedes `start_time`.
            # `duration` is popped and does not end up in the segment.
            lambda **x: {
                'end_time': x['start_time'] + x.pop('duration'),
                **x,
            },
            speaker=st.sampled_from(['spkA', 'spkB', 'spkC', 'spkD', 'spkE', 'spkF', 'spkG', 'spkH']),
            words=st.text(),
            session_id=st.just('session1'),
            start_time=st.floats(min_value=0, max_value=100),
            duration=st.floats(min_value=0, max_value=10),
        ),
    )
)


@given(seglst, seglst)
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
def test_greedy_di_cp_bound_by_cp(ref, hyp):
cp = meeteval.wer.wer.cp.cp_word_error_rate(ref, hyp)
dicp = meeteval.wer.wer.di_cp.greedy_di_cp_word_error_rate(ref, hyp)
Expand All @@ -24,7 +29,7 @@ def test_greedy_di_cp_bound_by_cp(ref, hyp):


@given(seglst, seglst)
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
def test_greedy_di_cp_vs_greedy_orc(ref, hyp):
"""
Test that the total distance of the greedy di-cp algorithm is equal to the
Expand All @@ -37,3 +42,32 @@ def test_greedy_di_cp_vs_greedy_orc(ref, hyp):
assert dicp.substitutions == orc.substitutions
assert dicp.insertions == orc.deletions
assert dicp.deletions == orc.insertions


@given(seglst, seglst)
@settings(deadline=None)  # The tests take longer on the GitHub actions test servers
def test_greedy_di_tcp_bound_by_tcp(ref, hyp):
    """
    Test that the greedy DI-tcpWER never exceeds the tcpWER computed on the
    same reference and hypothesis.
    """
    tcp = meeteval.wer.wer.time_constrained.tcp_word_error_rate(ref, hyp)
    di_tcp = meeteval.wer.wer.di_cp.greedy_di_tcp_word_error_rate(ref, hyp)

    # Nothing to compare when both error rates are undefined.
    if tcp.error_rate is None and di_tcp.error_rate is None:
        return

    assert tcp.error_rate >= di_tcp.error_rate


@given(seglst, seglst)
@settings(deadline=None)  # The tests take longer on the GitHub actions test servers
def test_greedy_di_tcp_vs_greedy_torc(ref, hyp):
    """
    Test that the error counts of the greedy DI-tcpWER match those of the
    greedy tcORC-WER called with swapped arguments (insertions and deletions
    trade places).
    """
    di_tcp = meeteval.wer.wer.di_cp.greedy_di_tcp_word_error_rate(ref, hyp)

    # With reference and hypothesis exchanged, the pseudo-word-level timing
    # strategies have to be exchanged as well so that they mirror the
    # defaults of `greedy_di_tcp_word_error_rate`.
    swapped_timing = dict(
        reference_pseudo_word_level_timing='character_based_points',
        hypothesis_pseudo_word_level_timing='character_based',
    )
    swapped_tcorc = meeteval.wer.wer.time_constrained_orc.greedy_time_constrained_orc_wer(
        hyp, ref, **swapped_timing
    )

    assert di_tcp.errors == swapped_tcorc.errors
    assert di_tcp.substitutions == swapped_tcorc.substitutions
    # Insertions and deletions are mirrored by the argument swap.
    assert di_tcp.insertions == swapped_tcorc.deletions
    assert di_tcp.deletions == swapped_tcorc.insertions

0 comments on commit f2d133a

Please sign in to comment.