diff --git a/meeteval/der/md_eval.py b/meeteval/der/md_eval.py index d08a2f9c..4e2901e2 100644 --- a/meeteval/der/md_eval.py +++ b/meeteval/der/md_eval.py @@ -77,24 +77,16 @@ def __add__(self, other: 'DiaErrorRate'): ) -def _md_eval_22( - reference, - hypothesis, - average_out='{parent}/{stem}_md_eval_22.json', - per_reco_out='{parent}/{stem}_md_eval_22_per_reco.json', - collar=0, - regex=None, -): - from meeteval.wer.__main__ import _load_texts - - r, _, h, hypothesis_paths = _load_texts( - reference, hypothesis, regex) +def md_eval_22_multifile(reference, hypothesis, collar=0): + from meeteval.io.rttm import RTTM + reference = RTTM.new(reference) + hypothesis = RTTM.new(hypothesis) - r = _fix_channel(r.to_rttm()) - h = _fix_channel(h.to_rttm()) + reference = _fix_channel(reference) + hypothesis = _fix_channel(hypothesis) - r = r.grouped_by_filename() - h = h.grouped_by_filename() + r = reference.grouped_by_filename() + h = hypothesis.grouped_by_filename() keys = set(r.keys()) & set(h.keys()) missing = set(r.keys()) ^ set(h.keys()) @@ -135,7 +127,8 @@ def get_details(r, h, key, tmpdir): # SPEAKER ERROR TIME =0.000000 secs # OVERALL SPEAKER DIARIZATION ERROR = 100.00 percent of scored speaker time `(ALL) - error_rate, = re.findall(r'OVERALL SPEAKER DIARIZATION ERROR = ([\d.]+) percent of scored speaker time', cp.stdout) + error_rate, = re.findall(r'OVERALL SPEAKER DIARIZATION ERROR = ([\d.]+) percent of scored speaker time', + cp.stdout) length, = re.findall(r'SCORED SPEAKER TIME =([\d.]+) secs', cp.stdout) deletions, = re.findall(r'MISSED SPEAKER TIME =([\d.]+) secs', cp.stdout) insertions, = re.findall(r'FALARM SPEAKER TIME =([\d.]+) secs', cp.stdout) @@ -173,6 +166,35 @@ def convert(string): f'does not match the average error rate of md-eval-22.pl ' f'applied to each recording ({md_eval.error_rate}).' ) + return per_reco + + +def md_eval_22(reference, hypothesis, collar=0): + from meeteval.io.rttm import RTTM + reference = RTTM.new(reference, filename='dummy') + hypothesis = RTTM.new(hypothesis, filename='dummy') + + assert len(reference.filenames()) == 1, reference.filenames() + assert len(hypothesis.filenames()) == 1, hypothesis.filenames() + assert reference.filenames() == hypothesis.filenames(), (reference.filenames(), hypothesis.filenames()) + + return md_eval_22_multifile(reference, hypothesis, collar)[reference.filenames()[0]] + + +def _md_eval_22( + reference, + hypothesis, + average_out='{parent}/{stem}_md_eval_22.json', + per_reco_out='{parent}/{stem}_md_eval_22_per_reco.json', + collar=0, + regex=None, +): + from meeteval.wer.__main__ import _load_texts + + r, _, h, hypothesis_paths = _load_texts( + reference, hypothesis, regex) + + per_reco = md_eval_22_multifile(r, h, collar) from meeteval.wer.__main__ import _save_results _save_results(per_reco, hypothesis_paths, per_reco_out, average_out)