
Chapter 2A - Groove Transformer


Table of Contents

  1. Introduction
  2. BasicGrooveTransformer.GrooveTransformer
    1. Instantiation
  3. BasicGrooveTransformer.GrooveTransformerEncoder
    1. Instantiation
    2. Storing
    3. Loading
    4. Pretrained Versions
    5. Generation

1. Introduction

2. BasicGrooveTransformer.GrooveTransformer

This model is a full encoder/decoder transformer similar to the original transformer in the paper Attention is All You Need. It consists of a transformerEncoder and a transformerDecoder; the main difference is that this model is designed to work with piano-roll-like data.

We have not yet trained any versions of this model, but the implementation is complete and ready to be trained.

2.i Instantiation

A groove transformer similar to the original transformer, consisting of a transformerEncoder and a transformerDecoder.

Source code available here

```python
# Instantiating a model
import torch

from model.Base.BasicGrooveTransformer import GrooveTransformer

params = {
    "d_model": 128,
    "nhead": 4,
    "dim_feedforward": 256,
    "dropout": 0.1,
    "num_encoder_layers": 6,
    "num_decoder_layers": 6,
    "max_len": 32,
    "N": 64,  # batch size
    "embedding_size_src": 16,  # input dimensionality at each timestep
    "embedding_size_tgt": 27   # output dimensionality at each timestep
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

TM = GrooveTransformer(params["d_model"], params["embedding_size_src"], params["embedding_size_tgt"],
                       params["nhead"], params["dim_feedforward"], params["dropout"],
                       params["num_encoder_layers"], params["num_decoder_layers"], params["max_len"], device)
```

3. BasicGrooveTransformer.GrooveTransformerEncoder

This model takes in a piano-roll-like monotonic groove and outputs a piano-roll-like drum pattern. It consists of only the encoder part of the original transformer model.

Several pretrained versions of this model are available (see Pretrained Versions). We suggest reading this document to better understand the training/evaluation process.

Moreover, a real-time plugin has already been developed using this model. The plugin is available here.

If you use this model, please cite the following paper:

```
@inproceedings{hakireal,
  title={Real-Time Drum Accompaniment Using Transformer Architecture},
  author={Haki, Behzad and Nieto, Marina and Pelinski, Teresa and Jord{\`a}, Sergi},
  booktitle={Proceedings of the 3rd Conference on AI Music Creativity, AIMC},
  year={2022}
}
```

3.i Instantiation

A groove transformer consisting of only the transformerEncoder section of the original transformer.

Source code available here

```python
# Instantiating a model
import torch

from model.Base.BasicGrooveTransformer import GrooveTransformerEncoder

params = {
    'd_model': 512,
    'embedding_size_src': 27,
    'embedding_size_tgt': 27,
    'nhead': 1,
    'dim_feedforward': 64,
    'dropout': 0.25542373735391866,
    'num_encoder_layers': 10,
    'max_len': 32
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

TEM = GrooveTransformerEncoder(params["d_model"], params["embedding_size_src"], params["embedding_size_tgt"],
                               params["nhead"], params["dim_feedforward"], params["dropout"],
                               params["num_encoder_layers"], params["max_len"], device)
```

3.ii Storing

The models have a save method which can be used to store the model parameters. The save method takes a *.pth file path where the model attributes are to be stored. Both the model parameters (constructor arguments) and the model state dictionary are saved to this file.

```python
TEM.save("model/misc/rand_model.pth")
```

Using this method, a *.json file is also created which stores the model parameters. The data in this json file duplicates the dictionary stored in the .pth file; the json file is created purely for conveniently inspecting the model parameters.
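
For example, the sidecar json can be inspected with the standard library. The .json path below is an assumption derived from the .pth path above; adjust it to wherever the file is actually written.

```python
import json

# Hypothetical sidecar path, assumed to mirror the .pth filename
with open("model/misc/rand_model.json", "r") as f:
    stored_params = json.load(f)
print(stored_params)
```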

3.iii Loading


Source code available here

```python
from model.modelLoadesSamplers import load_mgt_model
from model.saved.monotonic_groove_transformer_v1.params import model_params
import torch
import numpy as np

# Model path and model_param dictionary
model_name = "colorful_sweep_41"
model_path = f"model/saved/monotonic_groove_transformer_v1/{model_name}.model"
model_param = model_params[model_name]

# 1. LOAD MODEL
GrooveTransformer = load_mgt_model(model_path, model_param)
checkpoint = torch.load(model_path, map_location=model_param['device'])
```
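
Before generating with a loaded model, it is standard PyTorch practice to switch it to evaluation mode so that dropout is disabled; this step is not shown in the snippet above.

```python
# Disable dropout and other training-only behavior before inference
GrooveTransformer.eval()
```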

3.iv Pretrained Versions

Four pretrained versions of this model are available. The models were trained according to the documents referenced in the introduction section above, and are available in the model/saved/monotonic_groove_transformer_v1 directory.

To load a model, use the load_mgt_model method from the modelLoadesSamplers module as discussed above. For example, to load the misunderstood_bush_246 model, use the following code:

```python
from model.modelLoadesSamplers import load_mgt_model

model_path = f"model/saved/monotonic_groove_transformer_v1/latest/misunderstood_bush_246.pth"
GrooveTransformer = load_mgt_model(model_path)
```

3.v Generation

Source code available here

Create an input groove (create an HVO_Sequence instance, load a MIDI file, or grab one from the HVO_Sequence datasets as below):

```python
import torch
import numpy as np

from data.dataLoaders import load_gmd_hvo_sequences

# Load a split of the GMD dataset as a list of HVO_Sequence objects
test_set = load_gmd_hvo_sequences(
    "data/gmd/resources/storedDicts/groove_2bar-midionly.bz2pickle", "gmd",
    "data/dataset_json_settings/4_4_Beats_gmd.json", [4],
    "ROLAND_REDUCED_MAPPING", "train")

# Pick a random sequence and flatten it into a monotonic groove
input_hvo_seq = test_set[np.random.randint(0, len(test_set))]
input_groove_hvo = torch.tensor(input_hvo_seq.flatten_voices(), dtype=torch.float32)
```

Pass the groove to the model and sample a drum pattern:

```python
from model.modelLoadesSamplers import predict_using_mgt

voice_thresholds = [0.5] * 9          # per-voice sampling thresholds
voice_max_count_allowed = [32] * 9    # per-voice max number of hits allowed
output_hvo = predict_using_mgt(GrooveTransformer, input_groove_hvo, voice_thresholds,
                               voice_max_count_allowed, return_concatenated=True)
```
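
To make the two sampling controls concrete, the sketch below shows one plausible reading of them: per voice, a hit is kept only if its predicted probability clears that voice's threshold, and at most voice_max_count_allowed[v] of the highest-probability steps survive. This is an illustration of what the parameters mean, not the library's actual implementation.

```python
# Illustrative only: how per-voice thresholds and max hit counts could gate
# predicted hit probabilities of shape (timesteps, voices).
probs = torch.rand(32, 9)  # hypothetical per-step hit probabilities
hits = torch.zeros_like(probs)
for v in range(9):
    above = probs[:, v] >= voice_thresholds[v]                    # clears the threshold?
    top = probs[:, v].argsort(descending=True)[:voice_max_count_allowed[v]]
    keep = torch.zeros(32, dtype=torch.bool)
    keep[top] = True                                              # among the N most confident steps
    hits[:, v] = (above & keep).float()
```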

Inspect the generations by synthesizing them to audio [link to documentation], storing them to MIDI [link to documentation], or plotting piano rolls [link to documentation]:

```python
from hvo_sequence.hvo_seq import zero_like

input = input_hvo_seq
groove = zero_like(input_hvo_seq)                         # create template for groove hvo_sequence object
groove.hvo = input_groove_hvo.cpu().detach().numpy()      # add score
output = zero_like(input_hvo_seq)                         # create template for output hvo_sequence object
output.hvo = output_hvo[0, :, :].cpu().detach().numpy()   # add score

input.to_html_plot("in.html", show_figure=True)
groove.to_html_plot("groove.html", show_figure=True)
output.to_html_plot("output.html", show_figure=True)
```
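
If the HVO_Sequence API in your checkout exposes MIDI export, the generated pattern could be written out along these lines. The method name save_hvo_to_midi is an assumption based on the hvo_sequence package and may differ in your version; check the documentation linked above.

```python
# Hypothetical export call; the method name is an assumption and may differ.
output.save_hvo_to_midi("output.mid")
```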