Skip to content

Commit

Permalink
Merge pull request #93 from fgnt/print_wer_summary
Browse files Browse the repository at this point in the history
Print wer summary
  • Loading branch information
thequilo authored Sep 19, 2024
2 parents 9d5125a + a3821d1 commit 8377855
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion meeteval/der/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def md_eval_22(
regions=regions,
uem=uem,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER')


def cli():
Expand Down
39 changes: 28 additions & 11 deletions meeteval/wer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _save_results(
hypothesis_paths: 'list[Path]',
per_reco_out: str,
average_out: str,
wer_name: str = 'WER',
):
"""Saves the results.
"""
Expand All @@ -130,6 +131,23 @@ def _save_results(
dataclasses.asdict(average),
average_out.format(parent=parent, stem=stem),
)
if hasattr(average, 'scored_speaker_time'):
error_time = average.missed_speaker_time + average.falarm_speaker_time + average.speaker_error_time
logging.info(
f'%{wer_name}: {average.error_rate:.2%} '
f'[ {error_time:.2f}s / {average.scored_speaker_time:.2f}s, '
f'{average.missed_speaker_time:.2f}s missed, '
f'{average.falarm_speaker_time:.2f}s falarm, '
f'{average.speaker_error_time:.2f}s spk error ]'
)
else:
logging.info(
f'%{wer_name}: {average.error_rate:.2%} '
f'[ {average.errors} / {average.length}, '
f'{average.insertions} ins, '
f'{average.deletions} del, '
f'{average.substitutions} sub ]'
)
return average


Expand All @@ -156,7 +174,7 @@ def wer(
hypothesis = KeyedText.load(hypothesis)
from meeteval.wer.wer.siso import siso_word_error_rate_multifile
results = siso_word_error_rate_multifile(reference, hypothesis)
_save_results(results, hypothesis_paths, per_reco_out, average_out)
_save_results(results, hypothesis_paths, per_reco_out, average_out, wer_name='WER')


def orcwer(
Expand All @@ -178,7 +196,7 @@ def orcwer(
partial=partial,
normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='ORC-WER')


def greedy_orcwer(
Expand All @@ -202,7 +220,7 @@ def greedy_orcwer(
partial=partial,
normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy ORC-WER')


def cpwer(
Expand All @@ -222,7 +240,7 @@ def cpwer(
reference_sort=reference_sort, hypothesis_sort=hypothesis_sort,
uem=uem, partial=partial, normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='cpWER')


def mimower(
Expand All @@ -242,7 +260,7 @@ def mimower(
reference_sort=reference_sort, hypothesis_sort=hypothesis_sort,
uem=uem, partial=partial, normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='MIMO-WER')


def tcpwer(
Expand All @@ -269,7 +287,7 @@ def tcpwer(
hypothesis_sort=hypothesis_sort,
uem=uem, normalizer=normalizer, partial=partial,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='tcpWER')


def tcorcwer(
Expand Down Expand Up @@ -297,7 +315,7 @@ def tcorcwer(
uem=uem, partial=partial,
normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='tcORC-WER')


def greedy_tcorcwer(
Expand Down Expand Up @@ -325,8 +343,7 @@ def greedy_tcorcwer(
uem=uem, partial=partial,
normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)

_save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy-tcORC-WER')


def greedy_dicpwer(
Expand All @@ -348,7 +365,7 @@ def greedy_dicpwer(
partial=partial,
normalizer=normalizer,
)
_save_results(results, hypothesis, per_reco_out, average_out)
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy-DI-cpWER')


def _merge(
Expand Down Expand Up @@ -649,7 +666,7 @@ def run(self):
args = self.parser.parse_args()

# Logging
logging.basicConfig(level=args.log_level.upper(), format='%(levelname)s - %(message)s')
logging.basicConfig(level=args.log_level.upper(), format='%(levelname)s %(message)s', force=True)

if hasattr(args, 'func'):
kwargs = vars(args)
Expand Down

0 comments on commit 8377855

Please sign in to comment.