forked from castorini/hedwig
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from j-cahill/naotominakawa-patch-1
Add lyrics.py and lyrics_processor.py
- Loading branch information
Showing
2 changed files
with
139 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
|
||
from datasets.bert_processors.abstract_processor import BertProcessor, InputExample | ||
|
||
|
||
class LyricsProcessor(BertProcessor): | ||
NAME = 'Lyrics' | ||
NUM_CLASSES = 12 # Number of genre; len(df['genre'].unique()) = 12 | ||
IS_MULTILABEL = False | ||
|
||
def get_train_examples(self, data_dir): | ||
return self._create_examples( | ||
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'train.tsv')), 'train') | ||
|
||
def get_dev_examples(self, data_dir): | ||
return self._create_examples( | ||
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'dev.tsv')), 'dev') | ||
|
||
def get_test_examples(self, data_dir): | ||
return self._create_examples( | ||
self._read_tsv(os.path.join(data_dir, 'Lyrics', 'test.tsv')), 'test') | ||
|
||
def _create_examples(self, lines, set_type): | ||
"""Creates examples for the training and dev sets.""" | ||
examples = [] | ||
for (i, line) in enumerate(lines): | ||
if i == 0: | ||
continue | ||
guid = "%s-%s" % (set_type, i) | ||
text_a = line[1] | ||
label = line[0] | ||
examples.append( | ||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) | ||
return examples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
import re | ||
|
||
import numpy as np | ||
import torch | ||
from torchtext.data import NestedField, Field, TabularDataset | ||
from torchtext.data.iterator import BucketIterator | ||
from torchtext.vocab import Vectors | ||
|
||
|
||
def clean_string(string): | ||
""" | ||
Performs tokenization and string cleaning for the Lyrics dataset | ||
""" | ||
string = re.sub(r"[^A-Za-z0-9(),!?\'`]", " ", string) | ||
string = re.sub(r"\s{2,}", " ", string) | ||
return string.lower().strip().split() | ||
|
||
|
||
def split_sents(string): | ||
string = re.sub(r"[!?]"," ", string) | ||
return string.strip().split('.') | ||
|
||
|
||
def char_quantize(string, max_length=1000): | ||
identity = np.identity(len(LyricsCharQuantized.ALPHABET)) | ||
quantized_string = np.array([identity[LyricsCharQuantized.ALPHABET[char]] for char in list(string.lower()) if char in LyricsCharQuantized.ALPHABET], dtype=np.float32) | ||
if len(quantized_string) > max_length: | ||
return quantized_string[:max_length] | ||
else: | ||
return np.concatenate((quantized_string, np.zeros((max_length - len(quantized_string), len(LyricsCharQuantized.ALPHABET)), dtype=np.float32))) | ||
|
||
|
||
def process_labels(string): | ||
""" | ||
Returns the label string as a list of integers | ||
:param string: | ||
:return: | ||
""" | ||
return [float(x) for x in string] | ||
|
||
|
||
class Lyrics(TabularDataset): | ||
NAME = 'Lyrics' | ||
NUM_CLASSES = 12 | ||
IS_MULTILABEL = True | ||
|
||
TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True) | ||
LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels) | ||
|
||
@staticmethod | ||
def sort_key(ex): | ||
return len(ex.text) | ||
|
||
@classmethod | ||
def splits(cls, path, train=os.path.join('Lyrics', 'train.tsv'), | ||
validation=os.path.join('Lyrics', 'dev.tsv'), | ||
test=os.path.join('Lyrics', 'test.tsv'), **kwargs): | ||
return super(Lyrics, cls).splits( | ||
path, train=train, validation=validation, test=test, | ||
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)] | ||
) | ||
|
||
@classmethod | ||
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None, | ||
unk_init=torch.Tensor.zero_): | ||
""" | ||
:param path: directory containing train, test, dev files | ||
:param vectors_name: name of word vectors file | ||
:param vectors_cache: path to directory containing word vectors file | ||
:param batch_size: batch size | ||
:param device: GPU device | ||
:param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes | ||
:param unk_init: function used to generate vector for OOV words | ||
:return: | ||
""" | ||
if vectors is None: | ||
vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init) | ||
|
||
train, val, test = cls.splits(path) | ||
cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors) | ||
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, | ||
sort_within_batch=True, device=device) | ||
|
||
|
||
class LyricsCharQuantized(Lyrics): | ||
ALPHABET = dict(map(lambda t: (t[1], t[0]), enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}""")))) | ||
TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize) | ||
|
||
@classmethod | ||
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None, | ||
unk_init=torch.Tensor.zero_): | ||
""" | ||
:param path: directory containing train, test, dev files | ||
:param batch_size: batch size | ||
:param device: GPU device | ||
:return: | ||
""" | ||
train, val, test = cls.splits(path) | ||
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device) | ||
|
||
|
||
class LyricsHierarchical(Lyrics): | ||
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string) | ||
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents) |