diff --git a/meeteval/der/__main__.py b/meeteval/der/__main__.py index f7861ca0..1da32d18 100644 --- a/meeteval/der/__main__.py +++ b/meeteval/der/__main__.py @@ -20,7 +20,7 @@ def md_eval_22( regions=regions, uem=uem, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER') def cli(): diff --git a/meeteval/wer/__main__.py b/meeteval/wer/__main__.py index 714fefce..303beb9b 100644 --- a/meeteval/wer/__main__.py +++ b/meeteval/wer/__main__.py @@ -113,6 +113,7 @@ def _save_results( hypothesis_paths: 'list[Path]', per_reco_out: str, average_out: str, + wer_name: str = 'WER', ): """Saves the results. """ @@ -130,6 +131,23 @@ def _save_results( dataclasses.asdict(average), average_out.format(parent=parent, stem=stem), ) + if hasattr(average, 'scored_speaker_time'): + error_time = average.missed_speaker_time + average.falarm_speaker_time + average.speaker_error_time + logging.info( + f'%{wer_name}: {average.error_rate:.2%} ' + f'[ {error_time:.2f}s / {average.scored_speaker_time:.2f}s, ' + f'{average.missed_speaker_time:.2f}s missed, ' + f'{average.falarm_speaker_time:.2f}s falarm, ' + f'{average.speaker_error_time:.2f}s spk error ]' + ) + else: + logging.info( + f'%{wer_name}: {average.error_rate:.2%} ' + f'[ {average.errors} / {average.length}, ' + f'{average.insertions} ins, ' + f'{average.deletions} del, ' + f'{average.substitutions} sub ]' + ) return average @@ -156,7 +174,7 @@ def wer( hypothesis = KeyedText.load(hypothesis) from meeteval.wer.wer.siso import siso_word_error_rate_multifile results = siso_word_error_rate_multifile(reference, hypothesis) - _save_results(results, hypothesis_paths, per_reco_out, average_out) + _save_results(results, hypothesis_paths, per_reco_out, average_out, wer_name='WER') def orcwer( @@ -178,7 +196,7 @@ def orcwer( partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='ORC-WER') def greedy_orcwer( @@ -202,7 +220,7 @@ def greedy_orcwer( partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy ORC-WER') def cpwer( @@ -222,7 +240,7 @@ def cpwer( reference_sort=reference_sort, hypothesis_sort=hypothesis_sort, uem=uem, partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='cpWER') def mimower( @@ -242,7 +260,7 @@ def mimower( reference_sort=reference_sort, hypothesis_sort=hypothesis_sort, uem=uem, partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='MIMO-WER') def tcpwer( @@ -269,7 +287,7 @@ def tcpwer( hypothesis_sort=hypothesis_sort, uem=uem, normalizer=normalizer, partial=partial, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='tcpWER') def tcorcwer( @@ -297,7 +315,7 @@ def tcorcwer( uem=uem, partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='tcORC-WER') def greedy_tcorcwer( @@ -325,8 +343,7 @@ def greedy_tcorcwer( uem=uem, partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) - + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy-tcORC-WER') def greedy_dicpwer( @@ -348,7 +365,7 @@ def greedy_dicpwer( partial=partial, normalizer=normalizer, ) - _save_results(results, hypothesis, per_reco_out, average_out) + _save_results(results, hypothesis, per_reco_out, average_out, wer_name='greedy-DI-cpWER') def _merge( @@ -649,7 +666,7 @@ def run(self): args = self.parser.parse_args() # Logging - logging.basicConfig(level=args.log_level.upper(), format='%(levelname)s - %(message)s') + logging.basicConfig(level=args.log_level.upper(), format='%(levelname)s %(message)s', force=True) if hasattr(args, 'func'): kwargs = vars(args)