Skip to content

Commit

Permalink
feat(evaluate): add evaluate notebook (det curves, t-sne)
Browse files Browse the repository at this point in the history
  • Loading branch information
theolepage committed Nov 24, 2021
1 parent 59cc30e commit 0731e81
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 50 deletions.
14 changes: 2 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,7 @@ Then, you can evaluate model on speaker verification (EER, minDCF) with `python

## To-Do

- [ ] Experiment with different architectures (VICReg)
- [ ] loss = nce + 0.1 * vic
- [ ] loss = vic
- [ ] Ablation study on VICReg hyper-params
- [ ] Train other models (MoCo/XVectorEncoder, CPC/CPCEncoder, LIM/SincEncoder, Wav2Spk)
- [ ] PyTorch implementation of the best model
- [ ] Provide script to download and prepare data

---

- [ ] Make sure other models work (MoCo/XVectorEncoder, CPC/CPCEncoder, LIM/SincEncoder, Wav2Spk)
- [ ] CPC/LIM: @tf.function warning when doing tensor[1, :]
- [ ] Fix warning when loading: some weights are not used
- [ ] Allow restoring the optimizer state
- [ ] PyTorch implementation
- [ ] Allow restoring the optimizer state
24 changes: 11 additions & 13 deletions evaluate.py → extract_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import argparse
import pickle
import tensorflow as tf

from sslforslr.utils.helpers import load_config, load_dataset, load_model
from sslforslr.utils.evaluate import speaker_verification_evaluate
from sslforslr.utils.evaluate import extract_embeddings

def evaluate(config_path):
config, checkpoint_dir = load_config(config_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('config', help='Path to model config file.')
args = parser.parse_args()

config, checkpoint_dir = load_config(args.config)
(train_gen, val_gen), input_shape = load_dataset(config)
model = load_model(config, input_shape)

Expand All @@ -21,14 +26,7 @@ def evaluate(config_path):
else:
raise Exception('%s has not been trained.' % config['name'])

eer, min_dcf_001, min_dcf_005 = speaker_verification_evaluate(model, config)
print('EER (%):', eer)
print('minDCF (p=0.01):', min_dcf_001)
print('minDCF (p=0.05):', min_dcf_005)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('config', help='Path to model config file.')
args = parser.parse_args()
embeddings = extract_embeddings(model, config.dataset.test, config.dataset)

evaluate(args.config)
with open(checkpoint_dir + '/embeddings.pkl', 'wb') as f:
pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
335 changes: 335 additions & 0 deletions notebooks/evaluate.ipynb

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
numpy
sklearn
pandas
seaborn
matplotlib

tensorflow
tensorflow-addons
sklearn
soundfile
tqdm
torch
torchaudio
kaldiio

soundfile
ruamel.yaml
dacite
prettyprinter
git+git://github.com/DemisEom/SpecAugment.git
tqdm
6 changes: 1 addition & 5 deletions sslforslr/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
import torchaudio
import soundfile as sf

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from SpecAugment.spec_augment_pytorch import spec_augment

def load_wav(path, frame_length, num_frames=1, min_length=None):
audio, sr = sf.read(path)

Expand Down Expand Up @@ -51,7 +47,7 @@ def extract_mfcc(audio, enable_spec_augment=False):
n_mels=40)(audio) # mfcc: (N, C, T)

if enable_spec_augment:
mfcc = spec_augment(mfcc)
raise Exception('SpecAugment not supported')

mfcc = mfcc.numpy().transpose(0, 2, 1) # (N, T, C)

Expand Down
20 changes: 12 additions & 8 deletions sslforslr/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tensorflow.keras.callbacks import Callback

from sslforslr.utils.evaluate import speaker_verification_evaluate
from sslforslr.utils.evaluate import extract_embeddings, evaluate

class SVMetricsCallback(Callback):
def __init__(self, config):
Expand All @@ -9,17 +9,21 @@ def __init__(self, config):
self.config = config

def on_epoch_end(self, epoch, logs):
eer, min_dcf_001, min_dcf_005 = speaker_verification_evaluate(
self.model,
self.config
embeddings = extract_embeddings(
model,
self.config.dataset.test,
self.config.dataset
)


eer, min_dcf_001, _, _ = evaluate(
embeddings,
self.config.dataset.trials
)

print('EER (%):', eer)
print('minDCF (p=0.01):', min_dcf_001)
print('minDCF (p=0.05):', min_dcf_005)

logs.update({
'test_eer': eer,
'test_min_dcf_001': min_dcf_001,
'test_min_dcf_005': min_dcf_005
'test_min_dcf_001': min_dcf_001
})
13 changes: 6 additions & 7 deletions sslforslr/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,12 @@ def compute_min_dcf(fnrs, fprs, p_target=0.01, c_miss=1, c_fa=1):

return min_dcf

def speaker_verification_evaluate(model, config, round_val=5):
embeddings = extract_embeddings(model, config.dataset.test, config.dataset)
scores, labels = score_trials(config.dataset.trials, embeddings)
def evaluate(embeddings, trials):
scores, labels = score_trials(trials, embeddings)

eer = compute_eer(scores, labels)

eer = round(compute_eer(scores, labels), round_val)
fnrs, fprs = compute_error_rates(scores, labels)
min_dcf_001 = round(compute_min_dcf(fnrs, fprs, p_target=0.01), round_val)
min_dcf_005 = round(compute_min_dcf(fnrs, fprs, p_target=0.05), round_val)
min_dcf_001 = compute_min_dcf(fnrs, fprs, p_target=0.01)

return eer, min_dcf_001, min_dcf_005
return eer, min_dcf_001, fnrs, fprs

0 comments on commit 0731e81

Please sign in to comment.