Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
beyondguo committed Nov 16, 2022
1 parent f50a627 commit 5bd666c
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions sega_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jieba, jieba.analyse
import time
import re
from nltk.tokenize import word_tokenize


from nltk.corpus import stopwords
Expand Down Expand Up @@ -33,16 +34,24 @@ def get_stopwords():

class SketchExtractor:
def __init__(self, model='yake'):
assert model in ['yake', 'bert','jieba'], '`model` only support `yake`, `bert` or `jieba`'
assert model in ['random','yake', 'bert','jieba'], '`model` only support `yake`, `bert` or `jieba`'
self.model = model
if model == 'random':
self.extractor = None
if model == 'yake': # for English
self.extractor = None
if model == 'bert': # for English
self.extractor = AspectKeyBERT(model='all-MiniLM-L6-v2') # paraphrase-MiniLM-L3-v2 (the fastest LM) all-MiniLM-L6-v2
if model == 'jieba': # for Chinese
self.extractor = jieba.analyse


def get_kws(self, s, max_ngram=3, top=10, aspect_keywords=None, use_aspect_as_doc_embedding=False,lang='en'):
if self.model == 'random':
words = list(set(word_tokenize(s)))
random.shuffle(words)
random_words = words[:top]
return [],random_words
if self.model == 'yake':
self.extractor = yake.KeywordExtractor(n=max_ngram,top=top, windowsSize=1,lan=lang)
kws_pairs = self.extractor.extract_keywords(s)
Expand Down Expand Up @@ -91,7 +100,7 @@ def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
all_ids = []
for w in kws: # 找出每个词的位置
try:
for m in list(re.finditer(w.translate(table),s)):
for m in list(re.finditer(re.escape(w.translate(table)),s)):
all_ids += list(range(m.start(),m.end()))
except Exception as e:
print(e)
Expand All @@ -111,7 +120,7 @@ def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
all_ids = []
for w in kws: # 找出每个词的位置
try:
for m in list(re.finditer(w.translate(table),s)):
for m in list(re.finditer(re.escape(w.translate(table)),s)):
all_ids += list(range(m.start(),m.end()))
except Exception as e:
print(e)
Expand All @@ -125,7 +134,9 @@ def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
for i,id in enumerate(all_ids):
if i == 0 and id != 0: # 开头补mask
masked_text.append(f'{mask}{sep}')
if id - all_ids[i-1] > 1: # 说明中间有东西
if sep == ' ' and id - all_ids[i-1] == 2 and s[id-1] == ' ': # 中间是空格
masked_text.append(' ')
if id - all_ids[i-1] > 2: # 说明中间有东西
masked_text.append(f'{sep}{mask}{sep}')
masked_text.append(s[id])
if i == len(all_ids)-1 and id != len(s)-1: # 最后补mask
Expand Down

0 comments on commit 5bd666c

Please sign in to comment.