Skip to content

Commit

Permalink
Merge pull request #3 from j-cahill/naotominakawa-patch-1
Browse files Browse the repository at this point in the history
Add lyrics.py and lyrics_processor.py
  • Loading branch information
j-cahill authored Jul 24, 2019
2 parents 43d0b98 + 4ee01c7 commit 15c6f65
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
34 changes: 34 additions & 0 deletions datasets/bert_processors/lyrics_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

from datasets.bert_processors.abstract_processor import BertProcessor, InputExample


class LyricsProcessor(BertProcessor):
NAME = 'Lyrics'
NUM_CLASSES = 12 # Number of genre; len(df['genre'].unique()) = 12
IS_MULTILABEL = False

def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'train.tsv')), 'train')

def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'dev.tsv')), 'dev')

def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'test.tsv')), 'test')

def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[1]
label = line[0]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
105 changes: 105 additions & 0 deletions datasets/lyrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import re

import numpy as np
import torch
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors


def clean_string(string):
"""
Performs tokenization and string cleaning for the Lyrics dataset
"""
string = re.sub(r"[^A-Za-z0-9(),!?\'`]", " ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.lower().strip().split()


def split_sents(string):
string = re.sub(r"[!?]"," ", string)
return string.strip().split('.')


def char_quantize(string, max_length=1000):
identity = np.identity(len(LyricsCharQuantized.ALPHABET))
quantized_string = np.array([identity[LyricsCharQuantized.ALPHABET[char]] for char in list(string.lower()) if char in LyricsCharQuantized.ALPHABET], dtype=np.float32)
if len(quantized_string) > max_length:
return quantized_string[:max_length]
else:
return np.concatenate((quantized_string, np.zeros((max_length - len(quantized_string), len(LyricsCharQuantized.ALPHABET)), dtype=np.float32)))


def process_labels(string):
"""
Returns the label string as a list of integers
:param string:
:return:
"""
return [float(x) for x in string]


class Lyrics(TabularDataset):
NAME = 'Lyrics'
NUM_CLASSES = 12
IS_MULTILABEL = True

TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

@staticmethod
def sort_key(ex):
return len(ex.text)

@classmethod
def splits(cls, path, train=os.path.join('Lyrics', 'train.tsv'),
validation=os.path.join('Lyrics', 'dev.tsv'),
test=os.path.join('Lyrics', 'test.tsv'), **kwargs):
return super(Lyrics, cls).splits(
path, train=train, validation=validation, test=test,
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
"""
:param path: directory containing train, test, dev files
:param vectors_name: name of word vectors file
:param vectors_cache: path to directory containing word vectors file
:param batch_size: batch size
:param device: GPU device
:param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
:param unk_init: function used to generate vector for OOV words
:return:
"""
if vectors is None:
vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

train, val, test = cls.splits(path)
cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)


class LyricsCharQuantized(Lyrics):
ALPHABET = dict(map(lambda t: (t[1], t[0]), enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"""))))
TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
"""
:param path: directory containing train, test, dev files
:param batch_size: batch size
:param device: GPU device
:return:
"""
train, val, test = cls.splits(path)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)


class LyricsHierarchical(Lyrics):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)

0 comments on commit 15c6f65

Please sign in to comment.