Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tts]For mixed Chinese and English speech synthesis, add SSML support for Chinese #2830

Merged
merged 2 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*.egg-info
build
*output/
.history

audio/dist/
audio/fc_patch/
Expand Down
57 changes: 51 additions & 6 deletions paddlespeech/t2s/frontend/mix_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from typing import List

import paddle

from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor


class MixFrontend():
Expand Down Expand Up @@ -107,7 +109,40 @@ def get_input_ids(self,
add_sp: bool=True,
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:

segments = self.get_segment(sentence)
''' 1. 添加SSML支持,先列出 文字 和 <say-as>标签内容,
然后添加到tmpSegments数组里
'''
d_inputs = MixTextProcessor.get_dom_split(sentence)
tmpSegments = []
for instr in d_inputs:
''' 暂时只支持 say-as '''
if instr.lower().startswith("<say-as"):
tmpSegments.append((instr, "zh"))
else:
tmpSegments.extend(self.get_segment(instr))

''' 2. 把zh的merge到一起,避免合成结果中间停顿
'''
segments = []
currentSeg = ["", ""]
for seg in tmpSegments:
if seg[1] == "en" or seg[1] == "other":
if currentSeg[0] == '':
segments.append(seg)
else:
currentSeg[0] = "<speak>" + currentSeg[0] + "</speak>"
segments.append(tuple(currentSeg))
segments.append(seg)
currentSeg = ["", ""]
else:
if currentSeg[0] == '':
currentSeg[0] = seg[0]
currentSeg[1] = seg[1]
else:
currentSeg[0] = currentSeg[0] + seg[0]
if currentSeg[0] != '':
currentSeg[0] = "<speak>" + currentSeg[0] + "</speak>"
segments.append(tuple(currentSeg))

phones_list = []
result = {}
Expand All @@ -120,11 +155,21 @@ def get_input_ids(self,
input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=False, to_tensor=to_tensor)
else:
input_ids = self.zh_frontend.get_input_ids(
content,
merge_sentences=False,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
''' 3. 把带speak tag的中文和普通文字分开处理
'''
if content.strip() != "" and \
re.match(r".*?<speak>.*?</speak>.*", content, re.DOTALL):
input_ids = self.zh_frontend.get_input_ids_ssml(
content,
merge_sentences=False,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
else:
input_ids = self.zh_frontend.get_input_ids(
content,
merge_sentences=False,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
if add_sp:
input_ids["phone_ids"][-1] = paddle.concat(
[input_ids["phone_ids"][-1], self.sp_id_tensor])
Expand Down
34 changes: 34 additions & 0 deletions paddlespeech/t2s/ssml/xml_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def get_pinyin_split(self, mixstr):
ctlist.append([mixstr, []])
return ctlist

@classmethod
def get_dom_split(self, mixstr):
''' 文本分解,顺序加了列表中,返回文本和say-as标签
'''
ctlist = []
patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
mat = re.match(patn, mixstr)
if mat:
pre_xml = mat.group(1)
in_xml = mat.group(2)
after_xml = mat.group(3)

ctlist.append(pre_xml)
dom = DomXml(in_xml)
tags = dom.get_text_and_sayas_tags()
ctlist.extend(tags)

ctlist.append(after_xml)
return ctlist
else:
ctlist.append(mixstr)
return ctlist

class DomXml():
def __init__(self, xmlstr):
Expand Down Expand Up @@ -156,3 +178,15 @@ def get_all_tags(self, tag_name):
if x.hasAttribute('pinyin'): # pinyin
print(x.tagName, 'pinyin',
x.getAttribute('pinyin'), x.firstChild.data)

def get_text_and_sayas_tags(self):
'''返回 xml 内容的列表,包括所有文本内容和<say-as> tag'''
res = []

for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
res.append(x1.value)
else:
for x2 in x1.childNodes:
res.append(x2.toxml())
return res