-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathpcc.py
69 lines (59 loc) · 2.68 KB
/
pcc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PAM import PAM
from dataset import ExampleDatasetFiles
import os
import glob
import pandas as pd
import numpy as np
def evaluate_pam(dataloader, pam):
    """Compute PAM scores for every batch yielded by ``dataloader``.

    Each batch is expected to unpack as ``(_, audios, sample_index)``;
    ``pam.evaluate`` returns a list of per-file scores and a list of
    per-segment scores, which are accumulated across all batches.

    Returns:
        (all_scores, all_segment_scores): two flat lists covering the
        whole dataloader, in iteration order.
    """
    all_scores = []
    all_segment_scores = []
    for _, audios, sample_index in tqdm(dataloader):
        scores, segment_scores = pam.evaluate(audios, sample_index)
        all_scores.extend(scores)
        all_segment_scores.extend(segment_scores)
    return all_scores, all_segment_scores
def load_task_dataframe(task, model):
"""Load and return human listening scores"""
df = pd.read_csv(os.path.join(task, "scores.csv"))
model_df = df[df["Model"] == model.split(os.path.sep)[-1]]
files = [os.path.join(model, x) + ".wav" for x in list(model_df["File Name"])]
OVLs, RELs = model_df["OVL"], model_df["REL"]
return files, OVLs, RELs
def evaluate_task(task, model, pam):
    """Evaluate the files generated by ``model`` for ``task`` and print stats.

    Prints the average PAM score and its Pearson correlation (PCC) with the
    human OVL (overall quality) and REL (relevance) ratings.

    NOTE(review): reads ``args.repro``, ``args.batch_size`` and
    ``args.num_workers`` from the module-level ``args`` parsed in the
    ``__main__`` block — this function must only be called after argparse
    has run.
    """
    print(f"\nTask: {task}, Model: {model}")
    # Load human listening scores
    files, OVLs, RELs = load_task_dataframe(task, model)
    # BUG FIX: an unmatched model name yields an empty file list, which would
    # crash below with ZeroDivisionError (and np.corrcoef on empty input).
    if not files:
        print(f"No scored files found for model '{model}' in task '{task}'; skipping.")
        return
    # Create Dataset and Dataloader
    dataset = ExampleDatasetFiles(
        src=files,
        repro=args.repro,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=False,
        drop_last=False,
        collate_fn=dataset.collate,
    )
    # Evaluate and print PAM score
    collect_pam, _ = evaluate_pam(dataloader, pam)
    print(f"Average PAM Score: {sum(collect_pam)/len(collect_pam)}")
    # [0,1] picks the off-diagonal entry of the 2x2 correlation matrix
    print(f"PCC PAM / OVL: {np.corrcoef(collect_pam, OVLs)[0,1]}")
    print(f"PCC PAM / REL: {np.corrcoef(collect_pam, RELs)[0,1]}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description = "PAM")
parser.add_argument('--folder', type=str, default="human_eval", help='Folder path to evaluate')
parser.add_argument('--batch_size', type=int, default=10, help='Number of examples per batch')
parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for dataloader')
parser.add_argument('--repro', type=bool, default=True, help='Reproduce paper setup and evaluation')
args = parser.parse_args()
# initialize PAM
pam = PAM(use_cuda=torch.cuda.is_available())
# Run evaluation on tasks
tasks = glob.glob(os.path.join(args.folder,"**"))
for task in tasks:
models = glob.glob(os.path.join(task,"**"))
models = [m for m in models if ".csv" not in m]
for model in models:
evaluate_task(task, model, pam)