-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathinfer_folder.py
61 lines (47 loc) · 2.37 KB
/
infer_folder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Enhance a folder of noisy WAV files with a trained DPCRN model and score the enhanced outputs with DNSMOS (P.808). SISNR, PESQ and ESTOI are not computed by this script.
"""
import os
import subprocess
import sys

import toml
import torch
from tqdm import tqdm
import soundfile as sf
from omegaconf import OmegaConf

from model import DPCRN
def _list_wavs(folder):
    """Return the sorted names of the ``.wav`` files directly inside *folder*."""
    # Compare the real extension (not a bare "wav" suffix) so e.g. "foo_wav" is
    # not picked up; sort so the processing order is deterministic.
    return sorted(
        name for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() == '.wav'
    )


def _load_model(cfg_yaml, cfg_toml, device):
    """Build a DPCRN from ``cfg_toml['network_config']``, load checkpoint weights, and freeze it."""
    model = DPCRN(**cfg_toml['network_config']).to(device)
    checkpoint = torch.load(cfg_yaml.network.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


def _run_dnsmos(enhanced_folder, dnsmos_dir):
    """Run ``dnsmos_local_p808.py`` (from *dnsmos_dir*) on *enhanced_folder*.

    The CSV ``dnsmos_enhanced_p808.csv`` is written into *enhanced_folder*.
    Raises ``subprocess.CalledProcessError`` if the DNSMOS script fails.
    """
    enhanced_folder = os.path.abspath(enhanced_folder)
    csv_path = os.path.join(enhanced_folder, 'dnsmos_enhanced_p808.csv')
    # Argument list (no shell) so paths with spaces/metacharacters are safe;
    # sys.executable runs DNSMOS with the same interpreter/environment;
    # cwd= scopes the directory change to the child instead of os.chdir-ing
    # the whole process.
    subprocess.run(
        [sys.executable, 'dnsmos_local_p808.py',
         '-t', enhanced_folder, '-o', csv_path],
        cwd=dnsmos_dir,
        check=True,
    )


def infer_folder(cfg_yaml, test_folder, dnsmos_dir='DNSMOS'):
    """Enhance every WAV file in *test_folder* with DPCRN, then score the outputs with DNSMOS.

    Args:
        cfg_yaml: OmegaConf config. Reads ``network.cfg_toml`` (path to a TOML
            file with ``network_config`` and ``FFT`` tables),
            ``network.checkpoint`` (a checkpoint whose ``'model'`` entry is a
            state dict) and ``path.exp_folder`` (output directory).
        test_folder: directory containing the noisy ``.wav`` files.
        dnsmos_dir: directory containing ``dnsmos_local_p808.py``. A relative
            path is resolved against the current working directory at call time.

    Side effects:
        Creates ``path.exp_folder`` if needed, writes ``<stem>_enh.wav`` for
        each input, and writes ``dnsmos_enhanced_p808.csv`` there. The process
        working directory is left unchanged.
    """
    test_wavnames = _list_wavs(test_folder)
    cfg_toml = toml.load(cfg_yaml.network.cfg_toml)
    fft_cfg = cfg_toml['FFT']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resolve both paths up front so they mean the same thing for this process
    # and for the DNSMOS child process (which runs from dnsmos_dir).
    netout_folder = os.path.abspath(str(cfg_yaml.path.exp_folder))
    dnsmos_dir = os.path.abspath(dnsmos_dir)
    os.makedirs(netout_folder, exist_ok=True)

    model = _load_model(cfg_yaml, cfg_toml, device)

    # Square-root Hann window, built once: the analysis copy stays on the CPU
    # (STFT runs before the move to device), the synthesis copy lives on device.
    window = torch.hann_window(fft_cfg['win_length']).pow(0.5)
    window_device = window.to(device)

    with torch.no_grad():
        for name in tqdm(test_wavnames):
            noisy, fs = sf.read(os.path.join(test_folder, name), dtype="float32")
            # NOTE(review): recent PyTorch requires `return_complex` for real
            # inputs; presumably cfg_toml['FFT'] supplies it (or an older torch
            # is pinned) such that the model receives a real (F, T, 2) spectrum — confirm.
            spec = torch.stft(torch.from_numpy(noisy), **fft_cfg, window=window)
            spec = spec.to(device)
            estimate = model(spec[None, ...])  # (B,F,T,2)
            # Recombine the real/imag channels into a complex spectrum for iSTFT.
            enhanced = torch.istft(estimate[..., 0] + 1j * estimate[..., 1],
                                   **fft_cfg, window=window_device)
            out = enhanced.cpu().numpy().squeeze()
            stem = os.path.splitext(name)[0]
            sf.write(os.path.join(netout_folder, stem + '_enh.wav'), out, fs)

    _run_dnsmos(netout_folder, dnsmos_dir)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Enhance a folder of noisy WAV files with DPCRN and score them with DNSMOS.")
    parser.add_argument('-C', '--config', default='cfg_yaml.yaml')
    parser.add_argument('-D', '--device', default='0', help='The index of the available device, only single GPU supported')
    # Previously hard-coded; exposed as an option with the same default.
    parser.add_argument(
        '-T', '--test_folder',
        default='/data/ssd0/xiaobin.rong/Datasets/DNS3/blind_test_set/dns-challenge-3-final-evaluation/wideband_16kHz/noisy_clips_wb_16kHz/',
        help='Folder of noisy WAV files to enhance')
    args = parser.parse_args()

    # Set before any CUDA call (the first one happens inside infer_folder) so
    # device selection takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    cfg_yaml = OmegaConf.load(args.config)
    infer_folder(cfg_yaml, args.test_folder)