Audio-to-Midi

Audio-to-Midi is a model for converting pure piano music into MIDI events. It is a Transformer-based model that works on raw audio samples, which are processed by a CNN before reaching the Transformer stack. The model is written in Jax and uses a bit of Rust code to achieve efficient dataset loading.

This model powers the Piano Transcriber iOS app and is trained on data generated by SwiftMidiBouncer.

Note that the model is trained only on piano samples, so I do not expect it to work well for other instruments.

Setup

  1. Make a new Python environment and install the necessary packages. I do not have a full dependency list at the moment; the main ones are jax and equinox.

  2. The data loading requires the Rust plugin from the rust-plugins directory to be installed (see the build sketch after this list).

  3. The data loading uses FFmpeg to read audio files and convert them to downsampled raw audio, so ffmpeg must be available on your system. Note that I have seen quite large performance gains from compiling my own version of FFmpeg.
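
How the Rust plugin is built depends on its tooling. Assuming it is a standard PyO3/maturin project (an assumption on my part; check the rust-plugins directory for the actual setup), installing it into the active Python environment typically looks like:

    pip install maturin
    cd rust-plugins
    maturin develop --release   # builds the crate and installs it into the current environment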

Inferring

This repository provides the audio_to_midi.py CLI to interact with the model. It supports the following operations:

Inferring MIDI events from an audio file

To generate MIDI events from an audio file, run the following command:

python audio_to_midi.py <input_audio_file> [<output_midi_file.mid>]
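
To sanity-check the generated file, any standard MIDI library can read it back. For example, with mido (not a dependency of this project, just an illustration; the file name is a placeholder):

    import mido

    # Load the MIDI file produced by audio_to_midi.py (placeholder path).
    midi = mido.MidiFile("output_midi_file.mid")

    # Print note-on events together with their delta times.
    for track in midi.tracks:
        for msg in track:
            if msg.type == "note_on" and msg.velocity > 0:
                print(msg.note, msg.velocity, msg.time)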

In addition to writing the MIDI file, the raw probability distributions output by the model are plotted with matplotlib:

[Figure: Audio-to-Midi output distribution example]
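
A plot of this kind can be reproduced along the following lines; the array shape and axis meanings below are illustrative assumptions, not the model's actual output format:

    import matplotlib.pyplot as plt
    import numpy as np

    # Hypothetical (num_frames, num_keys) array of per-frame note probabilities;
    # in practice these come from the model's output.
    probs = np.random.rand(500, 88)

    plt.imshow(probs.T, aspect="auto", origin="lower", cmap="magma")
    plt.xlabel("Frame")
    plt.ylabel("Piano key")
    plt.colorbar(label="Probability")
    plt.show()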

Computing the validation loss

For a dataset directory containing audio samples along with CSV files annotated with MIDI events, different loss metrics can be calculated. For example:

⇒  python audio_to_midi.py --validation /Volumes/git/ml/datasets/midi-to-sound/validation_set_only_yamaha
Restoring saved model at step 359000
Loaded test set
Finished evaluating test loss
Validation loss: 74.94306182861328
Hit rate: 0.7999402284622192
Eventized diff: 66.33872985839844

Pre-trained model

A pre-trained model has been released in three formats:

  • Jax-checkpoints: Unpack these in the directory of this repository, and run the inference CLI as described above.
  • CoreML: This is the model format used on iOS. The ModelManager class in the Piano Transcriber app has an example of how to use it.
  • Tensorflow: This model can be run using the infer_tf.py CLI. Example:
    python infer_tf.py <audio file>

Training

  1. Set up the dataset, e.g. by generating a new one using SwiftMidiBouncer.
  2. Adjust the parameters in model.py, and update train.py to point to the desired dataset and to set the batch size and learning rate.
  3. Run python train.py to start training.

Training has been tested on a TPU, using data sharding to parallelize across all available devices.
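
As a rough illustration of this data-parallel setup (a sketch only; the actual sharding logic lives in train.py and may differ), sharding the batch dimension across all devices in Jax can look like:

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # One mesh axis spanning every available device (for example all TPU cores).
    devices = mesh_utils.create_device_mesh((jax.device_count(),))
    mesh = Mesh(devices, axis_names=("batch",))

    # Shard the leading (batch) dimension across devices; remaining dims stay replicated.
    data_sharding = NamedSharding(mesh, P("batch"))

    # Dummy audio batch of shape (batch, samples); the real shapes come from the dataset.
    batch = jax.device_put(jnp.zeros((64, 16000)), data_sharding)

    @jax.jit
    def train_step(params, batch):
        # Loss and gradient computation would go here; jax.jit compiles one program
        # that runs in parallel over the shards of the batch.
        return params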

Exporting

Audio-to-Midi is written in Jax, which by itself is not very useful if you want to run inference in non-Python environments such as phones or web apps.

The export.py file supports the following conversions:

  1. The Audio-to-Midi Jax model to Tensorflow
  2. Tensorflow to TFLite
  3. Tensorflow to CoreML using coremltools.

XLA is not fully supported by TFLite and coremltools, so unfortunately we use the older way of exporting from Jax to Tensorflow. The exporting code is quite fragile and requires a fair amount of tuning to actually work. See the export.py file for all the details.
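
For orientation, the general shape of such a pipeline (a sketch under assumptions, not the actual contents of export.py; the placeholder model, parameters, and input shape are mine) is roughly:

    import numpy as np
    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # Placeholder standing in for the real Jax inference function and its
    # trained parameters (both are assumptions for this sketch).
    def model_fn(params, audio):
        return jnp.tanh(audio * params["scale"])

    params = {"scale": np.float32(0.5)}

    # Jax -> Tensorflow via jax2tf, using the older (non-native) serialization path.
    tf_fn = jax2tf.convert(model_fn, native_serialization=False)

    # Wrap with a fixed input signature so the function can be traced.
    tf_func = tf.function(
        lambda audio: tf_fn(params, audio),
        input_signature=[tf.TensorSpec(shape=(1, 16000), dtype=tf.float32)],
        autograph=False,
    )

    # Tensorflow -> TFLite.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_func.get_concrete_function()]
    )
    tflite_model = converter.convert()

    # The CoreML path goes through coremltools.convert on the Tensorflow model
    # (for example a SavedModel written with tf.saved_model.save).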
