Skip to content

Commit

Permalink
added predict fn
Browse files Browse the repository at this point in the history
  • Loading branch information
teddykoker committed May 19, 2020
1 parent fe3e4bd commit 6f7ff0d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@ Ibanez TS9 Tube Screamer (all knobs at 12 o'clock).<br>

## Usage

**Prepare data**:

**Run effect on .wav file**:
Must be single channel, 44.1 kHz
```bash
python prepare_data.py data/in.wav data/out_ts9.wav
# must be same data used to train
python prepare_data.py data/in.wav data/out_ts9.wav

# specify input file and desired output file
python predict.py my_input_guitar.wav my_output.wav

# if you trained you own model you can pass --model flag
# with path to .ckpt
```

**Train**:
```bash
python prepare_data.py data/in.wav data/out_ts9.wav # or use your own!
python train.py
python train.py --gpus "0,1" # for multiple gpus
python train.py -h # help (see for other hyperparameters)
Expand Down
58 changes: 58 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pickle
import torch
from tqdm import tqdm
from scipy.io import wavfile
import argparse
import numpy as np

from model import PedalNet


def save(name, data):
wavfile.write(name, 44100, data.flatten().astype(np.int16))


@torch.no_grad()
def predict(args):
model = PedalNet.load_from_checkpoint(args.model)
model.eval()
train_data = pickle.load(open(args.train_data, "rb"))

mean, std = train_data["mean"], train_data["std"]

in_rate, in_data = wavfile.read(args.input)
assert in_rate == 44100, "input data needs to be 44.1 kHz"
sample_size = int(in_rate * args.sample_time)
length = len(in_data) - len(in_data) % sample_size

# split into samples
in_data = in_data[:length].reshape((-1, 1, sample_size)).astype(np.float32)

# standardize
in_data = (in_data - mean) / std

# pad each sample with previous sample
prev_sample = np.concatenate((np.zeros_like(in_data[0:1]), in_data[:-1]), axis=0)
pad_in_data = np.concatenate((prev_sample, in_data), axis=2)

pred = []
batches = pad_in_data.shape[0] // args.batch_size
for x in tqdm(np.array_split(pad_in_data, batches)):
pred.append(model(torch.from_numpy(x)).numpy())

pred = np.concatenate(pred)
pred = pred[:, :, -in_data.shape[2] :]

save(args.output, pred)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="models/pedalnet.ckpt")
parser.add_argument("--train_data", default="data.pickle")
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--sample_time", type=float, default=100e-3)
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args()
predict(args)

0 comments on commit 6f7ff0d

Please sign in to comment.