From f3aecda75c235357440abf730a8828fc01ffb38a Mon Sep 17 00:00:00 2001
From: Takuya Makino
Date: Sun, 29 Nov 2020 15:21:53 +0900
Subject: [PATCH] Add transformer-based classifier

---
 bunruija/classifiers/__init__.py        |  2 ++
 bunruija/classifiers/classifier.py      | 29 +++++++++++++++++--
 bunruija/classifiers/lstm.py            | 26 -----------------
 bunruija/classifiers/transformer.py     | 38 +++++++++++++++++++++++++
 bunruija/feature_extraction/sequence.py | 24 ++++++++++++----
 bunruija/tokenizers/__init__.py         |  9 +++++-
 setup.py                                |  1 +
 7 files changed, 93 insertions(+), 36 deletions(-)
 create mode 100644 bunruija/classifiers/transformer.py

diff --git a/bunruija/classifiers/__init__.py b/bunruija/classifiers/__init__.py
index 2213046..6d22b9a 100644
--- a/bunruija/classifiers/__init__.py
+++ b/bunruija/classifiers/__init__.py
@@ -15,6 +15,7 @@ from ..registry import BUNRUIJA_REGISTRY
 from ..tokenizers import build_tokenizer
 
 from .lstm import LSTMClassifier
+from .transformer import TransformerClassifier
 
 
 BUNRUIJA_REGISTRY['svm'] = SVC
@@ -24,6 +25,7 @@ BUNRUIJA_REGISTRY['lstm'] = LSTMClassifier
 BUNRUIJA_REGISTRY['pipeline'] = Pipeline
 BUNRUIJA_REGISTRY['stacking'] = StackingClassifier
+BUNRUIJA_REGISTRY['transformer'] = TransformerClassifier
 BUNRUIJA_REGISTRY['voting'] = VotingClassifier
 
 
diff --git a/bunruija/classifiers/classifier.py b/bunruija/classifiers/classifier.py
index 64e1152..09310a5 100644
--- a/bunruija/classifiers/classifier.py
+++ b/bunruija/classifiers/classifier.py
@@ -67,6 +67,32 @@ def __init__(self, **kwargs):
     def init_layer(self, data):
         pass
 
+    def convert_data(self, X, y=None):
+        if len(X) == 2 and isinstance(X[1], list):
+            indices = X[0]
+            raw_words = X[1]
+            has_raw_words = True
+        else:
+            has_raw_words = False
+            indices = X
+            raw_words = None
+
+        data = []
+        for i in range(len(indices.indptr) - 1):
+            start = indices.indptr[i]
+            end = indices.indptr[i + 1]
+            data_i = {
+                'inputs': indices.data[start: end],
+            }
+
+            if y is not None:
+                data_i['label'] = y[i]
+
+            if has_raw_words:
+                data_i['raw_words'] = raw_words[start: end]
+            data.append(data_i)
+        return data
+
     def fit(self, X, y):
         data = self.convert_data(X, y)
         self.init_layer(data)
@@ -112,9 +138,6 @@ def reset_module(self, **kwargs):
     def classifier_args(self):
         raise NotImplementedError
 
-    def convert_data(self, X, y=None):
-        raise NotImplementedError
-
     def build_optimizer(self):
         lr = float(self.kwargs.get('lr', 0.001))
         weight_decay = self.kwargs.get('weight_decay', 0.)
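Note on the convert_data hoist above: the method now lives on the shared base
classifier, so every neural backend gets the same CSR-to-examples conversion.
A minimal sketch of what that slicing does (the matrix values are made up for
illustration; zero entries are dropped by the CSR encoding, which is how
padding columns disappear):

    import numpy as np
    from scipy.sparse import csr_matrix

    # Two documents encoded as token indices; zeros act as implicit padding.
    X = csr_matrix(np.array([[3, 5, 0],
                             [7, 2, 9]]))

    # The per-example slicing performed inside convert_data:
    for i in range(len(X.indptr) - 1):
        start, end = X.indptr[i], X.indptr[i + 1]
        print({'inputs': X.data[start:end]})
    # prints roughly:
    #   {'inputs': array([3, 5])}
    #   {'inputs': array([7, 2, 9])}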
diff --git a/bunruija/classifiers/lstm.py b/bunruija/classifiers/lstm.py
index d623742..91c23d6 100644
--- a/bunruija/classifiers/lstm.py
+++ b/bunruija/classifiers/lstm.py
@@ -56,32 +56,6 @@ def init_layer(self, data):
             len(num_classes),
             bias=True)
 
-    def convert_data(self, X, y=None):
-        if len(X) == 2 and isinstance(X[1], list):
-            indices = X[0]
-            raw_words = X[1]
-            has_raw_words = True
-        else:
-            has_raw_words = False
-            indices = X
-            raw_words = None
-
-        data = []
-        for i in range(len(indices.indptr) - 1):
-            start = indices.indptr[i]
-            end = indices.indptr[i + 1]
-            data_i = {
-                'inputs': indices.data[start: end],
-            }
-
-            if y is not None:
-                data_i['label'] = y[i]
-
-            if has_raw_words:
-                data_i['raw_words'] = raw_words[start: end]
-            data.append(data_i)
-        return data
-
     def __call__(self, batch):
         x = batch['inputs']
         lengths = (x != self.pad).sum(dim=1)
diff --git a/bunruija/classifiers/transformer.py b/bunruija/classifiers/transformer.py
new file mode 100644
index 0000000..f222a60
--- /dev/null
+++ b/bunruija/classifiers/transformer.py
@@ -0,0 +1,38 @@
+import numpy as np
+import torch
+from transformers import AutoModel
+from transformers import AutoTokenizer
+
+from bunruija.classifiers.classifier import NeuralBaseClassifier
+
+
+class TransformerClassifier(NeuralBaseClassifier):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        from_pretrained = kwargs.pop('from_pretrained', None)
+        self.model = AutoModel.from_pretrained(from_pretrained)
+        self.dropout = torch.nn.Dropout(kwargs.get('dropout', 0.1))
+
+        tokenizer = AutoTokenizer.from_pretrained(from_pretrained)
+        self.pad = tokenizer.pad_token_id
+
+    def init_layer(self, data):
+        y = []
+        max_input_idx = 0
+        for data_i in data:
+            y.append(data_i['label'])
+
+        num_classes = np.unique(y)
+        self.out = torch.nn.Linear(
+            self.model.config.hidden_size,
+            len(num_classes),
+            bias=True)
+
+    def __call__(self, batch):
+        input_ids = batch['inputs']
+        x = self.model(input_ids)
+        pooled_output = x[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.out(pooled_output)
+        return logits
diff --git a/bunruija/feature_extraction/sequence.py b/bunruija/feature_extraction/sequence.py
index 7fbd46c..52e1d06 100644
--- a/bunruija/feature_extraction/sequence.py
+++ b/bunruija/feature_extraction/sequence.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.sparse import csr_matrix
 from sklearn.base import TransformerMixin
+import transformers
 
 from bunruija import tokenizers
 from bunruija.data import Dictionary
@@ -81,17 +82,28 @@ def transform(self, raw_documents):
         tokenizer = self.build_tokenizer()
         for row_id, document in enumerate(raw_documents):
             elements = tokenizer(document)
-            max_col = max(max_col, len(elements))
-            for i, element in enumerate(elements):
-                if element in self.dictionary:
-                    if self.keep_raw_word:
-                        raw_words.append(element)
-                    index = self.dictionary.get_index(element)
+            if isinstance(elements, transformers.tokenization_utils_base.BatchEncoding):
+                input_ids = elements['input_ids']
+                max_col = max(max_col, len(input_ids))
+
+                for i, index in enumerate(input_ids):
                     data.append(index)
                     row.append(row_id)
                     col.append(i)
+            else:
+                max_col = max(max_col, len(elements))
+
+                for i, element in enumerate(elements):
+                    if element in self.dictionary:
+                        if self.keep_raw_word:
+                            raw_words.append(element)
+                        index = self.dictionary.get_index(element)
+                        data.append(index)
+                        row.append(row_id)
+                        col.append(i)
+
         data = np.array(data)
         row = np.array(row)
         col = np.array(col)
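Note on TransformerClassifier above: the forward pass feeds only input_ids to
the wrapped model and classifies from its pooled output (x[1]). A minimal
usage sketch, assuming a checkpoint with a pooler such as 'bert-base-uncased'
(the checkpoint name and batch values are illustrative, and fit() normally
calls init_layer itself):

    import torch
    from bunruija.classifiers.transformer import TransformerClassifier

    clf = TransformerClassifier(from_pretrained='bert-base-uncased')
    clf.init_layer([{'label': 0}, {'label': 1}])  # infers 2 output classes

    batch = {'inputs': torch.tensor([[101, 2023, 2003, 102]])}  # toy input_ids
    logits = clf(batch)  # shape: (1, 2)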
diff --git a/bunruija/tokenizers/__init__.py b/bunruija/tokenizers/__init__.py
index 6ad4308..9d22451 100644
--- a/bunruija/tokenizers/__init__.py
+++ b/bunruija/tokenizers/__init__.py
@@ -2,15 +2,22 @@
 from .tokenizer import BaseTokenizer
 from .mecab_tokenizer import MeCabTokenizer
 
+from transformers import AutoTokenizer
+
 
 BUNRUIJA_REGISTRY['mecab'] = MeCabTokenizer
+BUNRUIJA_REGISTRY['auto'] = AutoTokenizer
 
 
 def build_tokenizer(config):
     tokenizer_type = config.get('tokenizer', {}).get('type', 'mecab')
     tokenizer_args = config.get('tokenizer', {}).get('args', {})
-    tokenizer = BUNRUIJA_REGISTRY[tokenizer_type](**tokenizer_args)
+    if 'from_pretrained' in tokenizer_args:
+        tokenizer = BUNRUIJA_REGISTRY[tokenizer_type].from_pretrained(
+            tokenizer_args['from_pretrained'])
+    else:
+        tokenizer = BUNRUIJA_REGISTRY[tokenizer_type](**tokenizer_args)
     return tokenizer
diff --git a/setup.py b/setup.py
index 1372814..9a68f97 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@
         'mecab-python3==0.996.5',
         'pyyaml',
         'torch>=1.6.0',
+        'transformers>=3.5.1',
         'scikit-learn>=0.23.2',
         'unidic-lite',
     ],
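Note on the tokenizer registry change above: a pretrained tokenizer is now
selected through a 'from_pretrained' entry in the config args. A minimal
sketch of wiring it together (the checkpoint name is only an example):

    from bunruija.tokenizers import build_tokenizer

    config = {
        'tokenizer': {
            'type': 'auto',  # resolves to transformers.AutoTokenizer
            'args': {'from_pretrained': 'bert-base-uncased'},  # example checkpoint
        },
    }
    tokenizer = build_tokenizer(config)
    encoding = tokenizer('this is a test')  # a BatchEncoding with input_ids,
    # which transform() in sequence.py above detects and stores directly.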