-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunModel.py
executable file
·92 lines (67 loc) · 2.68 KB
/
runModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Import required modules
import os
import tensorflow as tf
from tensorflow.keras import losses
import model
import ds
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import spiegelib
# Load the train, test and validation splits, each read from the "data" directory.
print("Loading Data...")
train_data, test_data, validation_data = (
    ds.melParamData(split_name, "data")
    for split_name in ("train", "test", "validation")
)
print("Done!")
# Checkpoint filename template; presumably the same template the training
# script used (TODO confirm). Only its directory part is used here, to locate
# the most recent checkpoint.
checkpoint_path = "new_models4/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Most recent checkpoint in that directory; None when the directory has none.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is None:
    # Fail early with an actionable message rather than handing None to
    # load_weights(), which produces a confusing, unrelated-looking error.
    raise FileNotFoundError(
        f"No checkpoint found in {checkpoint_dir!r}; "
        "train the model first or fix checkpoint_path."
    )
# Build the autoencoder. The shapes passed are full array shapes, so they
# include a leading sample dimension (10 for the mel slice) -- presumably
# autoencoder3 strips it; TODO confirm. A trailing channel axis is added to the
# mels via np.newaxis. The meaning of 64 (presumably the latent size) is not
# visible here -- TODO confirm against model.autoencoder3.
autoencoder = model.autoencoder3(
    64,
    train_data.get_mels()[:10, ..., np.newaxis].shape,
    train_data.get_params().shape,
)
# Restore the trained weights from the latest checkpoint.
autoencoder.load_weights(latest)
# Compile so evaluate() can compute losses; the single MSE loss is applied to
# every model output.
autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
# Report the model loss on the training split and then on the test split.
# Each target list is [spectrogram, synth parameters]; evaluate() returns the
# total loss followed by the loss of each of those two outputs.
# NOTE(review): the model's input shape was built with a trailing channel axis
# (np.newaxis), but the mels are passed here without one -- confirm the model
# accepts this.
for split_name, split in (("training", train_data), ("test", test_data)):
    mels = split.get_mels()
    loss, loss1, loss2 = autoencoder.evaluate(mels, [mels, split.get_params()])
    print(split_name + " model loss = " + str(loss)
          + "\n " + split_name + " model spectrogram loss = " + str(loss1)
          + "\n " + split_name + " model synth_param loss = " + str(loss2))
# Index of the training example used for the qualitative check below. A single
# constant keeps the spectrogram comparison and the synth rendering on the same
# sample.
SAMPLE_INDEX = 2000

# Predict on that one training example; indexing with a one-element list keeps
# a leading batch axis of size 1.
spectogram, params = autoencoder.predict(train_data.get_mels()[[SAMPLE_INDEX]])

# Plot the ground-truth mel spectrogram (top) against the autoencoder's
# reconstruction (bottom) for the same training example.
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True)
img = librosa.display.specshow(train_data.get_mels()[SAMPLE_INDEX], y_axis='mel', x_axis='time', ax=ax[0])
ax[0].set(title='Mel-Frequency Spectrogram Reconstruction')
ax[0].label_outer()
librosa.display.specshow(np.squeeze(spectogram), y_axis='mel', x_axis='time', ax=ax[1])
# One colorbar for both axes, built from the ground-truth image.
fig.colorbar(img, ax=ax, format="%+2.f dB")
plt.show()


def _render_and_save(synth, patch, wav_path):
    """Render *patch* on *synth*, plot and save the audio, then show the plot.

    Args:
        synth: synthesizer object exposing set_patch/render_patch/get_audio.
        patch: parameter values passed unchanged to ``synth.set_patch``.
        wav_path: file path the rendered audio is saved to.

    Returns:
        The audio object returned by ``synth.get_audio()``.
    """
    synth.set_patch(patch)
    synth.render_patch()
    audio = synth.get_audio()
    audio.plot_spectrogram()
    audio.save(wav_path)
    plt.show()
    return audio


# Ground-truth synth parameters for the same training example.
test_synth = train_data.get_params()[SAMPLE_INDEX]
# Serum synthesizer loaded through spiegelib (path appears to be the macOS
# Audio Unit location).
synth = spiegelib.synth.SynthVST("/Library/Audio/Plug-Ins/Components/Serum.component")
# Render audio from the ground-truth parameters, then from the predicted ones,
# saving each to its own file for comparison.
audio = _render_and_save(synth, test_synth, "test_audio.wav")
audio = _render_and_save(synth, np.squeeze(params), "predict_audio.wav")