Add transformer-based classifier
tma15 committed Nov 29, 2020
1 parent 9adaab3 commit f3aecda
Showing 7 changed files with 93 additions and 36 deletions.
2 changes: 2 additions & 0 deletions bunruija/classifiers/__init__.py
@@ -15,6 +15,7 @@
from ..registry import BUNRUIJA_REGISTRY
from ..tokenizers import build_tokenizer
from .lstm import LSTMClassifier
from .transformer import TransformerClassifier


BUNRUIJA_REGISTRY['svm'] = SVC
@@ -24,6 +25,7 @@
BUNRUIJA_REGISTRY['lstm'] = LSTMClassifier
BUNRUIJA_REGISTRY['pipeline'] = Pipeline
BUNRUIJA_REGISTRY['stacking'] = StackingClassifier
BUNRUIJA_REGISTRY['transformer'] = TransformerClassifier
BUNRUIJA_REGISTRY['voting'] = VotingClassifier


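For context, a minimal sketch of how this registry dispatch is typically used; the stub class and config shape below are illustrative assumptions, not definitions from this commit:

# Hypothetical sketch: the registry maps a config 'type' string to a class,
# which is then instantiated with the config's 'args'.
BUNRUIJA_REGISTRY = {}

class TransformerClassifier:  # stub standing in for the real class
    def __init__(self, **kwargs):
        self.kwargs = kwargs

BUNRUIJA_REGISTRY['transformer'] = TransformerClassifier

config = {'classifier': {'type': 'transformer',
                         'args': {'from_pretrained': 'bert-base-uncased'}}}
cls = BUNRUIJA_REGISTRY[config['classifier']['type']]
model = cls(**config['classifier']['args'])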
29 changes: 26 additions & 3 deletions bunruija/classifiers/classifier.py
@@ -67,6 +67,32 @@ def __init__(self, **kwargs):
    def init_layer(self, data):
        pass

    def convert_data(self, X, y=None):
        if len(X) == 2 and isinstance(X[1], list):
            indices = X[0]
            raw_words = X[1]
            has_raw_words = True
        else:
            has_raw_words = False
            indices = X
            raw_words = None

        data = []
        for i in range(len(indices.indptr) - 1):
            start = indices.indptr[i]
            end = indices.indptr[i + 1]
            data_i = {
                'inputs': indices.data[start:end],
            }

            if y is not None:
                data_i['label'] = y[i]

            if has_raw_words:
                data_i['raw_words'] = raw_words[start:end]
            data.append(data_i)
        return data

    def fit(self, X, y):
        data = self.convert_data(X, y)
        self.init_layer(data)
@@ -112,9 +138,6 @@ def reset_module(self, **kwargs):
    def classifier_args(self):
        raise NotImplementedError

    def convert_data(self, X, y=None):
        raise NotImplementedError

    def build_optimizer(self):
        lr = float(self.kwargs.get('lr', 0.001))
        weight_decay = self.kwargs.get('weight_decay', 0.)
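The convert_data method hoisted into the base class walks a scipy CSR matrix row by row via indptr. A small self-contained sketch of what it produces (example values assumed):

# Sketch: per-row slicing of a CSR matrix, as convert_data does above.
import numpy as np
from scipy.sparse import csr_matrix

# Two documents encoded as rows of token indices.
X = csr_matrix(np.array([[3, 7, 2], [5, 1, 4]]))
y = [0, 1]

data = []
for i in range(len(X.indptr) - 1):
    start, end = X.indptr[i], X.indptr[i + 1]
    data.append({'inputs': X.data[start:end], 'label': y[i]})

# data[0] -> {'inputs': array([3, 7, 2]), 'label': 0}
# data[1] -> {'inputs': array([5, 1, 4]), 'label': 1}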
26 changes: 0 additions & 26 deletions bunruija/classifiers/lstm.py
@@ -56,32 +56,6 @@ def init_layer(self, data):
            len(num_classes),
            bias=True)

    def convert_data(self, X, y=None):
        if len(X) == 2 and isinstance(X[1], list):
            indices = X[0]
            raw_words = X[1]
            has_raw_words = True
        else:
            has_raw_words = False
            indices = X
            raw_words = None

        data = []
        for i in range(len(indices.indptr) - 1):
            start = indices.indptr[i]
            end = indices.indptr[i + 1]
            data_i = {
                'inputs': indices.data[start:end],
            }

            if y is not None:
                data_i['label'] = y[i]

            if has_raw_words:
                data_i['raw_words'] = raw_words[start:end]
            data.append(data_i)
        return data

    def __call__(self, batch):
        x = batch['inputs']
        lengths = (x != self.pad).sum(dim=1)
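The __call__ above recovers true sequence lengths from a padded batch by counting non-pad positions; a short illustration (pad id assumed to be 0):

import torch

pad = 0
x = torch.tensor([[3, 7, 2],
                  [5, 1, pad]])   # padded batch of token ids
lengths = (x != pad).sum(dim=1)   # tensor([3, 2])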
38 changes: 38 additions & 0 deletions bunruija/classifiers/transformer.py
@@ -0,0 +1,38 @@
import numpy as np
import torch
from transformers import AutoModel
from transformers import AutoTokenizer

from bunruija.classifiers.classifier import NeuralBaseClassifier


class TransformerClassifier(NeuralBaseClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        from_pretrained = kwargs.pop('from_pretrained', None)
        self.model = AutoModel.from_pretrained(from_pretrained)
        self.dropout = torch.nn.Dropout(kwargs.get('dropout', 0.1))

        tokenizer = AutoTokenizer.from_pretrained(from_pretrained)
        self.pad = tokenizer.pad_token_id

    def init_layer(self, data):
        y = []
        max_input_idx = 0
        for data_i in data:
            y.append(data_i['label'])

        num_classes = np.unique(y)
        self.out = torch.nn.Linear(
            self.model.config.hidden_size,
            len(num_classes),
            bias=True)

    def __call__(self, batch):
        input_ids = batch['inputs']
        x = self.model(input_ids)
        pooled_output = x[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.out(pooled_output)
        return logits
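A sketch of the same forward path with a concrete checkpoint; the model name is illustrative, and indexing x[1] assumes a BERT-style encoder whose second output is the pooler output:

import torch
from transformers import AutoModel, AutoTokenizer

name = 'bert-base-uncased'  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

input_ids = torch.tensor([tokenizer.encode('a short example')])
outputs = model(input_ids)
pooled = outputs[1]                                   # (batch, hidden_size)
head = torch.nn.Linear(model.config.hidden_size, 2)   # 2 classes assumed
logits = head(pooled)                                 # (batch, 2)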
24 changes: 18 additions & 6 deletions bunruija/feature_extraction/sequence.py
@@ -1,6 +1,7 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import TransformerMixin
import transformers

from bunruija import tokenizers
from bunruija.data import Dictionary
@@ -81,17 +82,28 @@ def transform(self, raw_documents):
        tokenizer = self.build_tokenizer()
        for row_id, document in enumerate(raw_documents):
            elements = tokenizer(document)
            max_col = max(max_col, len(elements))

            for i, element in enumerate(elements):
                if element in self.dictionary:
                    if self.keep_raw_word:
                        raw_words.append(element)
                    index = self.dictionary.get_index(element)
            if isinstance(elements, transformers.tokenization_utils_base.BatchEncoding):
                input_ids = elements['input_ids']
                max_col = max(max_col, len(input_ids))

                for i, index in enumerate(input_ids):
                    data.append(index)
                    row.append(row_id)
                    col.append(i)

            else:
                max_col = max(max_col, len(elements))

                for i, element in enumerate(elements):
                    if element in self.dictionary:
                        if self.keep_raw_word:
                            raw_words.append(element)
                        index = self.dictionary.get_index(element)
                        data.append(index)
                        row.append(row_id)
                        col.append(i)

        data = np.array(data)
        row = np.array(row)
        col = np.array(col)
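The isinstance check works because Hugging Face tokenizers return a BatchEncoding holding integer ids, whereas the dictionary-based tokenizers return surface strings; a quick illustration (checkpoint name assumed):

import transformers
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('bert-base-uncased')
enc = tok('a short example')

isinstance(enc, transformers.tokenization_utils_base.BatchEncoding)  # True
enc['input_ids']  # integer ids, including special tokens such as [CLS]/[SEP]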
9 changes: 8 additions & 1 deletion bunruija/tokenizers/__init__.py
@@ -2,15 +2,22 @@
from .tokenizer import BaseTokenizer
from .mecab_tokenizer import MeCabTokenizer

from transformers import AutoTokenizer


BUNRUIJA_REGISTRY['mecab'] = MeCabTokenizer
BUNRUIJA_REGISTRY['auto'] = AutoTokenizer


def build_tokenizer(config):
    tokenizer_type = config.get('tokenizer', {}).get('type', 'mecab')
    tokenizer_args = config.get('tokenizer', {}).get('args', {})

    tokenizer = BUNRUIJA_REGISTRY[tokenizer_type](**tokenizer_args)
    if 'from_pretrained' in tokenizer_args:
        tokenizer = BUNRUIJA_REGISTRY[tokenizer_type].from_pretrained(
            tokenizer_args['from_pretrained'])
    else:
        tokenizer = BUNRUIJA_REGISTRY[tokenizer_type](**tokenizer_args)
    return tokenizer


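Under the new branch, a tokenizer config containing from_pretrained is resolved through the class's from_pretrained constructor instead of the plain call. A sketch of both config shapes, assuming bunruija (and MeCab, for the second case) is installed; the checkpoint name is illustrative:

from bunruija.tokenizers import build_tokenizer

# Hugging Face tokenizer via the new 'auto' entry.
config_auto = {'tokenizer': {'type': 'auto',
                             'args': {'from_pretrained': 'cl-tohoku/bert-base-japanese'}}}
tokenizer = build_tokenizer(config_auto)   # AutoTokenizer.from_pretrained(...)

# Existing path is unchanged.
config_mecab = {'tokenizer': {'type': 'mecab', 'args': {}}}
tokenizer = build_tokenizer(config_mecab)  # MeCabTokenizer()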
1 change: 1 addition & 0 deletions setup.py
@@ -18,6 +18,7 @@
        'mecab-python3==0.996.5',
        'pyyaml',
        'torch>=1.6.0',
        'transformers>=3.5.1',
        'scikit-learn>=0.23.2',
        'unidic-lite',
    ],
