Skip to content

Commit

Permalink
viz: cli: sync with other clis
Browse files Browse the repository at this point in the history
  • Loading branch information
boeddeker committed Feb 5, 2024
1 parent 1e7b728 commit b0b507e
Showing 1 changed file with 73 additions and 27 deletions.
100 changes: 73 additions & 27 deletions meeteval/viz/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,40 @@

import meeteval
from meeteval.viz.visualize import AlignmentVisualization
from meeteval.wer.api import _load_texts
import tqdm


# Write one alignment-visualization HTML page per (session, system) pair
# into the directory `out`.
#
# reference:   reference input, handed to `_load_texts`.
# hypothesiss: mapping whose keys name the systems and whose values are the
#              hypothesis inputs (iterated with `.items()`).
# out:         output directory; created with parents when it is missing.
# alignment:   passed through to `AlignmentVisualization`.
# regex:       passed through to `_load_texts`.
#
# NOTE(review): this text is a rendered diff, so lines of the previous and
# the new implementation appear side by side (e.g. the manual
# `meeteval.io.load(...)` loading next to the `_load_texts` call, and two
# `session_ids = ...` assignments) — confirm against the real file which
# of them are live.
def create_viz_folder(
reference,
hypothesiss,
out,
alignment='tcp',
regex=None,
):
out = Path(out)
if isinstance(reference, (str, Path)):
reference = meeteval.io.load(reference).to_seglst()

reference = reference.groupby('session_id')
out.mkdir(parents=True, exist_ok=True)

# Collected visualizations, indexed as avs[system][session_id].
avs = {}
for i, hypothesis in tqdm.tqdm(hypothesiss.items()):
if isinstance(hypothesis, (str, Path)):
hypothesis = meeteval.io.load(hypothesis).to_seglst()

hypothesis = hypothesis.groupby('session_id')
r, h = _load_texts(
reference, hypothesis, regex=regex,
reference_sort='segment',
hypothesis_sort='segment',
)

r = r.groupby('session_id')
h = h.groupby('session_id')

# Only sessions that occur on both sides are visualized.
session_ids = set(reference.keys()) & set(hypothesis.keys())
session_ids = set(r.keys()) & set(h.keys())
# Sessions that occur on exactly one side are reported and skipped.
xor = set(r.keys()) ^ set(h.keys())
if xor:
print(f'Ignore {xor}, because they are not available in reference and hypothesis.')

for session_id in session_ids:
av = AlignmentVisualization(reference[session_id],
hypothesis[session_id],
av = AlignmentVisualization(r[session_id],
h[session_id],
alignment=alignment)
# The file name combines the session id and the system key.
av.dump(out / f'{session_id}_{i}.html')
avs.setdefault(i, {})[session_id] = av
Expand Down Expand Up @@ -89,13 +96,7 @@ def create_viz_folder(
with tag('body'):
with tag('table', klass='tablesorter', id='myTable'):
with tag('thead'), tag('tr'):

# meeteval.wer.combine_error_rates([
# meeteval.wer.ErrorRate.from_dict(av.data['info']['wer'] for av in avs.values())
#
# ])
def get_wer(v):
print(list(v.values())[0].data['info']['wer']['hypothesis'])
error_rate = meeteval.wer.combine_error_rates(*[
meeteval.wer.ErrorRate.from_dict(av.data['info']['wer']['hypothesis'])
for av in v.values()
Expand Down Expand Up @@ -145,21 +146,66 @@ def get_wer(v):
print(f'Open {(out / "index.html").absolute()}')


# Prints every received argument together with a label.
# `out` is keyword-only and has no default; any additional keyword
# arguments are collected in `kwargs`.
# NOTE(review): only the debug prints are visible for this function; in this
# diff rendering the remainder of its body appears to be interleaved with
# the following definition — confirm against the real file.
def main(ref, *, out, alignment='tcp', **kwargs):
"""
"""
print('kwargs', kwargs)
print('alignment', alignment)
print('out', out)
print('ref', ref)
def html(
        reference,
        hypothesis,
        alignment='tcp',
        regex=None,
        out='viz',
):
    """Create the HTML alignment visualization for one or more hypotheses.

    Args:
        reference: Sized collection with exactly one reference entry. It is
            forwarded unchanged to `create_viz_folder`.
        hypothesis: Iterable of hypothesis specifications (strings). A
            specification is either a plain path, which gets the system
            name ``sys<index>``, or ``name:path``.
        alignment: Forwarded to `create_viz_folder`.
        regex: Forwarded to `create_viz_folder`.
        out: Forwarded to `create_viz_folder` as the output location.

    Raises:
        ValueError: If `reference` does not hold exactly one entry, or if
            two hypotheses resolve to the same system name.
    """
    def prepare(i: int, h: str):
        # An existing path takes precedence over the `name:path` reading,
        # so file names that contain a colon are still usable.
        if ':' in h and not Path(h).exists():
            # inspired by tensorboard from the --logdir_spec argument.
            name, path = h.split(':', maxsplit=1)
            return name, path
        return f'sys{i}', h

    # Explicit exception instead of `assert`: assertions vanish under `-O`.
    if len(reference) != 1:
        raise ValueError(
            f'Got {len(reference)} references. '
            'At the moment only shared reference is supported.'
        )

    systems = {}
    for i, h in enumerate(hypothesis):
        name, path = prepare(i, h)
        # A repeated name would silently replace the earlier hypothesis.
        if name in systems:
            raise ValueError(
                f'Duplicate system name {name!r} for hypotheses '
                f'{systems[name]!r} and {path!r}.'
            )
        systems[name] = path

    create_viz_folder(
        reference=reference,
        hypothesiss=systems,
        out=out,
        alignment=alignment,
        regex=regex,
    )


def cli():
    """Run the command-line interface of the visualization tool.

    Builds a `CLI` subclass with visualization-specific definitions for
    the `alignment` and `hypothesis` arguments, registers `html` as a
    command and runs the interface.
    """
    from meeteval.wer.__main__ import CLI

    class VizCLI(CLI):
        """`CLI` with custom `--alignment` and `--hypothesis` options."""

        def add_argument(self, command_parser, name, p):
            # NOTE(review): presumably called once per parameter name of the
            # registered command — confirm against `CLI.add_command`.
            if name == 'alignment':
                command_parser.add_argument(
                    '--alignment',
                    choices=['tcp', 'cp'],
                    help='Specifies which alignment is used.\n'
                         '- cp: Find the permutation that minimizes the cpWER and use the "classical" alignment.\n'
                         '- tcp: Find the permutation that minimizes the tcpWER and use a time constraint alignment.'
                )
            elif name == 'hypothesis':
                # Registered under its own name so the option is defined
                # exactly once and not as a side effect of `alignment`.
                command_parser.add_argument(
                    '-h', '--hypothesis',
                    help='Hypothesis file(s) in SegLST, STM or CTM format. '
                         'Multiple files can be provided for a side by side view. '
                         'Optionally prefixed with system name, e.g. mysystem:/path/to/hyp.stm',
                    nargs='+', action=self.extend_action,
                    required=True,
                )
            else:
                return super().add_argument(command_parser, name, p)

    # Distinct local name: `cli` would shadow this function inside itself.
    viz_cli = VizCLI()
    viz_cli.add_command(html)
    viz_cli.run()


if __name__ == '__main__':
    # Single entry point: the argparse-based interface. Running a second
    # command-line parser on the same arguments beforehand would consume or
    # reject them before this call is reached.
    cli()

0 comments on commit b0b507e

Please sign in to comment.