-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathcreate_data.py
86 lines (74 loc) · 3.12 KB
/
create_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import os
import sys
from tqdm import tqdm
from utils.reader import load_audio
# Generate the data lists (train / enroll / trials).
def get_data_list(infodata_path, list_path, zhvoice_path, min_duration_ms=1300, test_speaker_interval=32):
    """Build the train, enroll and trial list files from a metadata file.

    Each non-blank line of ``infodata_path`` must be a JSON object with the
    keys ``duration_ms``, ``index`` (audio path relative to ``zhvoice_path``)
    and ``speaker``. Utterances shorter than ``min_duration_ms`` or whose
    audio file does not exist are ignored.

    Speakers are ordered by first appearance among the kept utterances.
    Every ``test_speaker_interval``-th speaker (positions 0, N, 2N, ...) is
    held out for testing; all others are training speakers.

    Three ``<audio_path>\\t<label>`` files are written (UTF-8) into
    ``list_path``:

    * ``train_list.txt``  - every kept utterance of a training speaker,
      labelled with the speaker's index among training speakers;
    * ``enroll_list.txt`` - the first kept utterance of each test speaker;
    * ``trials_list.txt`` - the remaining kept utterances of each test
      speaker.

    Test labels are indices among test speakers, numbered independently of
    the training labels. Backslashes in audio paths are written as ``/``.

    Args:
        infodata_path: Path of the JSON-lines metadata file.
        list_path: Directory that receives the three list files.
        zhvoice_path: Root directory that ``index`` paths are relative to.
        min_duration_ms: Minimum utterance duration to keep (default 1300).
        test_speaker_interval: Hold out one speaker in every this many
            (default 32). Must be >= 1.

    Raises:
        ValueError: If ``test_speaker_interval`` is smaller than 1, or a
            metadata line is not valid JSON.
        KeyError: If a metadata record lacks one of the required keys.
    """
    if test_speaker_interval < 1:
        raise ValueError('test_speaker_interval must be >= 1')

    with open(infodata_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Single pass: parse each record once and do the (slow) filesystem
    # check once, keeping (speaker, normalised path) in file order.
    utterances = []
    for line in tqdm(lines):
        line = line.strip()
        if not line:
            continue
        info = json.loads(line)
        if info['duration_ms'] < min_duration_ms:
            continue
        sound_path = os.path.join(zhvoice_path, info['index'])
        if not os.path.exists(sound_path):
            continue
        utterances.append((info['speaker'], sound_path.replace('\\', '/')))

    # dict.fromkeys de-duplicates while preserving first-appearance order.
    speakers_name = list(dict.fromkeys(speaker for speaker, _ in utterances))
    test_speaker_name = speakers_name[::test_speaker_interval]
    test_speaker_dict = {name: i for i, name in enumerate(test_speaker_name)}
    train_speaker_name = [name for name in speakers_name if name not in test_speaker_dict]
    train_speaker_dict = {name: i for i, name in enumerate(train_speaker_name)}
    print(f'训练集有{len(train_speaker_name)}个说话人,测试集有{len(test_speaker_name)}个说话人')

    # test_data[speaker_id] collects that test speaker's paths in file order.
    test_data = [[] for _ in test_speaker_name]
    with open(os.path.join(list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train:
        for speaker, sound_path in utterances:
            test_id = test_speaker_dict.get(speaker)
            if test_id is not None:
                test_data[test_id].append(sound_path)
            else:
                f_train.write('%s\t%d\n' % (sound_path, train_speaker_dict[speaker]))

    enroll_file = os.path.join(list_path, 'enroll_list.txt')
    trials_file = os.path.join(list_path, 'trials_list.txt')
    with open(enroll_file, 'w', encoding='utf-8') as f_enroll, \
            open(trials_file, 'w', encoding='utf-8') as f_trials:
        for speaker_id, paths in enumerate(test_data):
            # Every test speaker has >= 1 utterance: speakers come from utterances.
            f_enroll.write('%s\t%d\n' % (paths[0], speaker_id))
            f_trials.writelines('%s\t%d\n' % (path, speaker_id) for path in paths[1:])
# Remove entries whose audio fails to load.
def remove_error_audio(data_list_path):
    """Drop entries whose audio cannot be loaded from a data list file, in place.

    ``data_list_path`` is a UTF-8 text file of ``<audio_path>\\t<label>``
    lines. Each audio path is passed to ``load_audio``. If that call raises,
    the path and the exception are printed to stderr and the line is dropped.
    Blank lines are dropped silently. Lines that do not split into exactly
    two tab-separated fields are reported on stderr and dropped, instead of
    aborting the whole run.

    The surviving lines are written unchanged, in their original order, to a
    temporary file next to the list. That file is then moved over the
    original with ``os.replace``, so an interruption cannot leave a truncated
    list behind.
    """
    with open(data_list_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    valid_lines = []
    for line in tqdm(lines):
        if not line.strip():
            continue
        fields = line.rstrip('\r\n').split('\t')
        if len(fields) != 2:
            print(f'malformed line skipped: {line!r}', file=sys.stderr)
            continue
        audio_path = fields[0]
        try:
            load_audio(audio_path)
        except Exception as e:  # deliberate best-effort: any load failure drops the entry
            print(audio_path, file=sys.stderr)
            print(e, file=sys.stderr)
        else:
            valid_lines.append(line)

    tmp_path = data_list_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.writelines(valid_lines)
    os.replace(tmp_path, data_list_path)
if __name__ == '__main__':
    # Build the lists from the zhvoice metadata, then clean each list of
    # entries whose audio cannot be loaded.
    get_data_list('dataset/zhvoice/text/infodata.json', 'dataset', 'dataset/zhvoice')
    for list_file in ('dataset/enroll_list.txt',
                      'dataset/trials_list.txt',
                      'dataset/train_list.txt'):
        remove_error_audio(list_file)