Skip to content

Commit

Permalink
fix(moco): small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
theolepage committed Oct 20, 2021
1 parent db61d04 commit f056d87
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 26 deletions.
6 changes: 2 additions & 4 deletions configs/moco-base-kaldi.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
"weight_regularizer": 1e-4
},
"training": {
"run_eagerly": true,
"epochs": 150,
"optimizer": {
"type": "SGD",
"momentum": 0.9
},
"batch_size": 1,
"batch_size": 1024,
"learning_rate": {
"scheduler": "cosine",
"start": 0.1,
Expand All @@ -36,8 +35,7 @@
"scp": "./data/train/feats.scp",
"utt2spk": "./data/train/utt2spk",
"frames": {
"length": 64000,
"pairs": true,
"length": 300,
"extract_mfcc": true
}
}
Expand Down
20 changes: 8 additions & 12 deletions sslforslr/dataset/KaldiDatasetLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class KaldiDatasetGenerator(Sequence):
def __init__(self, batch_size, frames_config, rxfiles, labels, indices):
self.batch_size = batch_size
self.frame_length = frames_config['length']
self.pairs = frames_config.get('pairs', False)
self.extract_mfcc = frames_config.get('extract_mfcc', False)
self.rxfiles = rxfiles
self.labels = labels
Expand All @@ -29,29 +28,26 @@ def __len__(self):
return len(self.indices) // self.batch_size

def __getitem__(self, i):
X, y = [], [], []
X, y = [], []

for j in range(self.batch_size):
index = self.indices[i * self.batch_size + j]

sample, sr = sf.read(self.rxfiles[index])
data = sample.reshape((len(sample), 1))
label = self.labels[index]

assert len(sample) >= self.frame_length
sample = sample.reshape((self.frame_length, 1))

if self.extract_mfcc:
data = extract_mfcc(sample)
else:
offset = np.random.randint(0, len(sample) - self.frame_length + 1)
data = sample[offset:offset+self.frame_length]

assert len(data) >= self.frame_length
offset = np.random.randint(0, len(data) - self.frame_length + 1)
data = data[offset:offset+self.frame_length]

X.append(data)
y.append(label)

if self.pairs:
return (np.array(X), np.array(X)) np.array(Y)
return np.array(X), np.array(Y)

return np.array(X), np.array(y)

class KaldiDatasetLoader:
def __init__(self, seed, config):
Expand Down
20 changes: 10 additions & 10 deletions sslforslr/models/moco/MoCo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ def compile(self, optimizer, **kwargs):
self.optimizer = optimizer

def call(self, X):
return self.encoder_q(self.mlp(X))
return self.mlp(self.encoder_q(X))

def train_step(self, data):
X, _ = data # Discard labels provided by the dataset generator
# X shape: (batch_size, frame_length, 40, 1)
# X shape: (batch_size, 300, 30)

X_1_aug, X_2_aug = X
X_1_aug = X
X_2_aug = tf.identity(X)

with tf.GradientTape() as tape:
Z_q = self.encoder_q(X_1_aug, training=True)
Expand Down Expand Up @@ -93,7 +94,8 @@ def train_step(self, data):
def test_step(self, data):
X, _ = data # Discard labels provided by the dataset generator

X_1_aug, X_2_aug = X
X_1_aug = X
X_2_aug = tf.identity(X)

Z_q = self.encoder_q(X_1_aug, training=False)
Z_k = self.encoder_k(X_2_aug, training=False)
Expand Down Expand Up @@ -129,12 +131,10 @@ def info_nce_loss(anchor, pos, neg, temp):
loss = tf.reduce_mean(loss)

# Determine accuracy
logits_size = tf.shape(logits)[1]
logits_softmax = tf.nn.softmax(logits, axis=0)
pred_indices = tf.math.argmax(logits_softmax, axis=0, output_type=tf.int32)
preds_acc = tf.math.equal(pred_indices, tf.zeros(logits_size, dtype=tf.int32))
accuracy = tf.math.count_nonzero(preds_acc, dtype=tf.int32)
accuracy /= logits_size
logits_softmax = tf.nn.softmax(logits, axis=1)
pred_indices = tf.math.argmax(logits_softmax, axis=1, output_type=tf.int32)
preds_acc = tf.math.equal(pred_indices, labels)
accuracy = tf.math.count_nonzero(preds_acc, dtype=tf.int32) / batch_size

return loss, accuracy

Expand Down

0 comments on commit f056d87

Please sign in to comment.