import os import sox import torch import yaml import shutil import operator import numpy as np from util.preprocess_functions import preprocess_dataset,normalize,set_type from util.timit_dataset import create_dataloader from util.functions import test_file from six.moves import cPickle # if this is removed, pytorch will output harmless warning messages that may be irritating import warnings warnings.filterwarnings("ignore") phn_occurrence = {} # load phoneme information with open('config/phn_occurrence.txt') as f: for line in f.readlines(): phn_occurrence[line.split()[0]] = int(line.split()[1]) phonemes = ['ih', 'n', 'iy', 'l', 's', 'r', 'ah', 'aa', 'er', 'k', 'm', 't', 'eh', 'ae', 'z', 'd', 'q', 'w', 'dh', 'p', 'dx', 'f', 'b', 'sh', 'ay', 'ey', 'ow', 'g', 'uw', 'hh', 'v', 'y', 'ng', 'jh', 'th', 'oy', 'ch', 'uh', 'aw'] strat_phn_count = {} def get_pred(path): data_type = 'float32' mean_val = np.loadtxt('config/mean_val.txt') std_val = np.loadtxt('config/std_val.txt') x, y = preprocess_dataset(path) x = normalize(x, mean_val, std_val) x = set_type(x, data_type) test_set = create_dataloader(x, y, **conf['model_parameter'], **conf['training_parameter'], shuffle=False) for batch_index,(batch_data,batch_label) in enumerate(test_set): pred,true = test_file(batch_data, batch_label, listener, speller, optimizer, **conf['model_parameter']) return pred # load LAS model config_path = 'config/las_example_config.yaml' conf = yaml.load(open(config_path,'r')) listener = torch.load(conf['training_parameter']['pretrained_listener_path'], map_location=lambda storage, loc: storage) speller = torch.load(conf['training_parameter']['pretrained_speller_path'], map_location=lambda storage, loc: storage) optimizer = torch.optim.Adam([{'params':listener.parameters()}, {'params':speller.parameters()}], lr=conf['training_parameter']['learning_rate']) for phn in phonemes: i = 0 for file in os.listdir(os.path.join('phoneme_set', phn)): print('Testing {} {} out of {}'.format(phn, str(i), str(phn_occurrence[phn])), end='\r') test = os.path.join('phoneme_set', phn, file) pred = get_pred(test) if phn in pred: if phn not in strat_phn_count: strat_phn_count[phn] = 1 else: strat_phn_count[phn] += 1 os.makedirs(os.path.join('strat_phoneme_set', phn), exist_ok=True) shutil.copy(test, os.path.join('strat_phoneme_set', phn, phn + str(strat_phn_count[phn]) + '.wav')) i += 1 print() sorted_phn = sorted(strat_phn_count.items(), key=operator.itemgetter(1), reverse=True) with open('config/strat_phn_occurrence.txt', 'w+') as f: [f.write(phn[0] + ' ' + str(phn[1]) + '\n') for phn in sorted_phn]