Skip to content

Commit

Permalink
refactor call args
Browse files (browse the repository at this point in the history)
  • Loading branch information
Mddct committed Nov 1, 2024
1 parent ab2f6f8 commit b7c5044
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
10 changes: 4 additions & 6 deletions wenet/cli/punc_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import List

import jieba
import torch
from wenet.cli.hub import Hub
from wenet.paraformer.search import _isAllAlpha
Expand All @@ -26,8 +27,6 @@ def split_words(self, text: str):
self.use_jieba = True
import logging

import jieba

# Disable jieba's logger
logging.getLogger('jieba').disabled = True
jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))
Expand Down Expand Up @@ -95,12 +94,11 @@ def add_punc_batch(self, texts: List[str]):
result.append(sentence.replace('▁', ' '))
return result

def __call__(self, result):
text = result['text']
def __call__(self, text: str):
if text != '':
r = self.add_punc_batch([text])[0]
result['text_with_punc'] = r
return result
return r
return ''


def load_model(model_dir: str = None,
Expand Down
2 changes: 1 addition & 1 deletion wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def main():
result = model.transcribe(args.audio_file, args.show_tokens_info)
if args.punc:
assert punc_model is not None
result = punc_model(result)
result['text_with_punc'] = punc_model(result['text'])
print(result)


Expand Down

0 comments on commit b7c5044

Please sign in to comment.