Add prado #10

Merged (7 commits, Mar 7, 2021)
Changes from all commits
8 changes: 8 additions & 0 deletions bunruija/classifiers/__init__.py
@@ -14,9 +14,11 @@

from ..registry import BUNRUIJA_REGISTRY
from ..tokenizers import build_tokenizer
from .classifier import NeuralBaseClassifier
from .lstm import LSTMClassifier
from .prado import PRADO
from .transformer import TransformerClassifier
from . import util


BUNRUIJA_REGISTRY['prado'] = PRADO
@@ -55,17 +57,22 @@ def build_estimator(self, estimator_data):
estimator_args = estimator_data.get('args', {})
additional_args = self.maybe_need_more_arg(estimator_type)
estimator_args.update(additional_args)

if issubclass(BUNRUIJA_REGISTRY[estimator_type], NeuralBaseClassifier):
estimator_args['saver'] = self.saver
estimator = BUNRUIJA_REGISTRY[estimator_type](**estimator_args)
return estimator_type, estimator

def build(self):
setting = self.config['classifier']
self.saver = util.Saver(self.config)

if isinstance(setting, list):
model = self.build_estimator(setting)[1]
elif isinstance(setting, dict):
model_type = setting['type']
model_args = setting.get('args', {})
model_args['saver'] = self.saver

if model_type in ['stacking', 'voting']:
estimators = model_args.pop('estimators')
@@ -91,6 +98,7 @@ def build(self):
logger.info(f'model args: {model_args}')
model = BUNRUIJA_REGISTRY[model_type](**model_args)
logger.info(model)

return model


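For reference, the BUNRUIJA_REGISTRY assignment above is what exposes the new model to config files. A minimal sketch of the lookup-and-instantiate flow used by build_estimator; the absolute import path follows from the relative import in this diff, and the 'dim_hid' argument is illustrative, not a documented option:

from bunruija.registry import BUNRUIJA_REGISTRY
import bunruija.classifiers  # noqa: F401  importing this module runs the registration above

# Parsed 'classifier' section of a config; keys under 'args' are assumptions.
setting = {'type': 'prado', 'args': {'dim_hid': 64}}

model_cls = BUNRUIJA_REGISTRY[setting['type']]  # the PRADO class registered above
model = model_cls(**setting.get('args', {}))    # mirrors build_estimator's instantiation
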
55 changes: 48 additions & 7 deletions bunruija/classifiers/classifier.py
@@ -15,6 +15,18 @@
logger = getLogger(__name__)


def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s


class BaseClassifier(BaseEstimator, ClassifierMixin):
def __init__(self):
super().__init__()
@@ -62,11 +74,38 @@ def __init__(self, **kwargs):
self.max_epochs = kwargs.get('max_epochs', 3)
self.batch_size = kwargs.get('batch_size', 20)

self.log_interval = kwargs.get('log_interval', 100)
self.optimizer_type = kwargs.get('optimizer', 'adam')
self.save_every_step = kwargs.get('save_every_step', -1)
self.saver = kwargs.get('saver', None)
self.labels = set()

def __repr__(self):
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines

main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'

main_str += ')'
return main_str

def init_layer(self, data):
raise NotImplementedError

def convert_data(self, X, y=None):
logger.info('Loading data')
@@ -113,8 +152,6 @@ def fit(self, X, y):
loss_accum = 0
n_samples_accum = 0
for epoch in range(self.max_epochs):
for batch in torch.utils.data.DataLoader(
data,
batch_size=self.batch_size,
@@ -132,24 +169,27 @@
batch['labels'],
reduction='sum',
)
loss_accum += loss.item()
n_samples_accum += len(batch['labels'])
(loss / len(batch['labels'])).backward()
optimizer.step()
step += 1
del loss

if step % self.log_interval == 0:
loss_accum /= n_samples_accum
elapsed = time.perf_counter() - start_at
logger.info(f'epoch:{epoch+1} step:{step} '
f'loss:{loss_accum:.2f} elapsed:{elapsed:.2f}')
loss_accum = 0
n_samples_accum = 0

if (
self.save_every_step > -1
and self.saver
and step % self.save_every_step == 0
):
self.saver(self)

def reset_module(self, **kwargs):
pass
@@ -174,6 +214,7 @@ def zero_grad(self):
p.grad = None

def predict(self, X):
self.to(self.device)
self.eval()

data = self.convert_data(X)
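
The new __repr__ mirrors torch.nn.Module's nested module printing, which implies NeuralBaseClassifier mixes torch.nn.Module into its base classes (the use of self._modules and self._get_name() above points the same way). A minimal sketch to exercise it; TinyClassifier and its layer sizes are hypothetical, not part of this PR:

import torch

from bunruija.classifiers.classifier import NeuralBaseClassifier  # path from this PR

class TinyClassifier(NeuralBaseClassifier):
    # Hypothetical subclass, only here to exercise the repr machinery above.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.out = torch.nn.Linear(16, 2)

model = TinyClassifier(log_interval=50, save_every_step=1000)
print(model)  # prints a nested module tree, as when printing a torch.nn.Module
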
133 changes: 68 additions & 65 deletions bunruija/classifiers/prado.py
@@ -22,84 +22,87 @@ def __init__(self, index):

def __call__(self, module, _):
mask = module.raw_weight.new_ones(module.raw_weight.size())
if self.index.device != mask.device:
mask.index_fill_(2, self.index.to(mask.device), 0.)
else:
mask.index_fill_(2, self.index, 0.)
module.weight = module.raw_weight * mask


class ConvolutionLayer(torch.nn.Module):
def __init__(self, kernel_size, dim_hid):
super().__init__()
self.dim_hid = dim_hid
self.conv = torch.nn.Conv2d(
1,
self.dim_hid,
kernel_size=(kernel_size, self.dim_hid),
stride=1
)

self.batch_norm = torch.nn.BatchNorm1d(self.dim_hid)

def __call__(self, x):
x = self.conv(x).squeeze(3)
x = self.batch_norm(x)
return x


class ProjectAttentionLayer(torch.nn.Module):
def __init__(self, kernel_size, dim_hid, skip_bigram=None):
super().__init__()
self.conv_value = ConvolutionLayer(kernel_size, dim_hid)

if isinstance(skip_bigram, list):
self.conv_value.conv.raw_weight = self.conv_value.conv.weight
del self.conv_value.conv.weight
weight_mask = WeightMask(torch.tensor(skip_bigram))
self.conv_value.conv.register_forward_pre_hook(weight_mask)
elif skip_bigram is None:
pass
else:
raise ValueError(skip_bigram)

self.conv_attn = ConvolutionLayer(kernel_size, dim_hid)
self.zero_pad = torch.nn.ZeroPad2d((0, 0, kernel_size - 1, 0))

def __call__(self, x_value_in, x_attn_in):
# (bsz, output_channel, seq_len)
x_value_in = self.zero_pad(x_value_in)
x_value = self.conv_value(x_value_in)

# (bsz, output_channel, seq_len)
x_attn_in = self.zero_pad(x_attn_in)
x_attn = self.conv_attn(x_attn_in)
x_attn = torch.softmax(x_attn, dim=2)

x = x_attn * x_value
# (bsz, output_channel)
x = torch.sum(x, dim=2)
return x


class ProjectedAttention(torch.nn.Module):
def __init__(self, kernel_sizes, dim_hid, skip_bigrams=None):
super().__init__()
self.dim_hid = dim_hid

if isinstance(skip_bigrams, list):
self.layers = torch.nn.ModuleList([
ProjectAttentionLayer(kernel_size, dim_hid, skip_bigram=skip_bigram)
for kernel_size, skip_bigram in zip(kernel_sizes, skip_bigrams)
])
else:
self.layers = torch.nn.ModuleList([
ProjectAttentionLayer(kernel_size, dim_hid, skip_bigram=None)
for kernel_size in kernel_sizes
])

def __call__(self, x_value_in, x_attn_in):
x_list = []

for layer in self.layers:
x = layer(x_value_in, x_attn_in)
x_list.append(x)
x = torch.cat(x_list, dim=1)
return x
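
A quick shape check of the refactored block: each ProjectAttentionLayer runs a value convolution and an attention convolution over the projected sequence, softmaxes the attention over time, and pools. A self-contained sketch; the (bsz, 1, seq_len, dim_hid) input layout is inferred from the Conv2d and ZeroPad2d arguments above rather than stated anywhere in the PR:

import torch

from bunruija.classifiers.prado import ProjectAttentionLayer  # module path from this PR

layer = ProjectAttentionLayer(kernel_size=3, dim_hid=64)
layer.eval()  # use BatchNorm running stats so any batch size works

x = torch.randn(8, 1, 120, 64)  # (bsz, in_channels=1, seq_len, dim_hid)
out = layer(x, x)               # value and attention branches share one input here
print(out.shape)                # torch.Size([8, 64]): one pooled vector per sample
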
15 changes: 15 additions & 0 deletions bunruija/classifiers/util.py
@@ -0,0 +1,15 @@
from pathlib import Path
import pickle


class Saver:
def __init__(self, config):
self.config = config

def __call__(self, model):
with open(Path(self.config.get('bin_dir', '.')) / 'model.bunruija', 'rb') as f:
model_data = pickle.load(f)

with open(Path(self.config.get('bin_dir', '.')) / 'model.bunruija', 'wb') as f:
model_data['classifier'] = model
pickle.dump(model_data, f)
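
Saver deliberately does a read-modify-write on model.bunruija: it loads the existing pickled dict (which presumably also carries the vectorizer and other preprocessing artifacts) and swaps in only the classifier, so mid-training checkpoints leave the rest of the pipeline intact. A minimal usage sketch, assuming 'bin_dir' already holds a model.bunruija written by the preprocessing step:

from bunruija.classifiers.util import Saver  # module added in this PR

saver = Saver({'bin_dir': 'output'})  # only 'bin_dir' is read from the config
saver(trained_model)  # trained_model stands for any fitted classifier;
                      # re-pickles output/model.bunruija with 'classifier' replaced
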
10 changes: 3 additions & 7 deletions bunruija/trainer.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
import pickle

import sklearn
import torch
@@ -18,18 +18,14 @@ def __init__(self, config_file):
self.data = pickle.load(f)

self.model = bunruija.classifiers.build_model(self.config)
self.saver = bunruija.classifiers.util.Saver(self.config)

def train(self):
y_train = self.data['label_train']
X_train = self.data['data_train']
self.model.fit(X_train, y_train)

self.saver(self.model)

if 'label_dev' in self.data:
y_dev = self.data['label_dev']
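
With the duplicated pickle logic replaced by the shared Saver, the end-to-end flow is short. A hedged sketch: the class name Trainer is an assumption (this diff shows only the module's methods), and the config path is illustrative:

from bunruija.trainer import Trainer  # class name assumed; module path from this diff

trainer = Trainer('example/yelp_polarity/settings/prado.yaml')
trainer.train()  # fit, checkpoint via Saver, then score on the dev split if present
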
4 changes: 4 additions & 0 deletions bunruija_cli/train.py
@@ -28,3 +28,7 @@ def main(args):

def cli_main():
main(sys.argv[1:])


if __name__ == '__main__':
cli_main()
1 change: 1 addition & 0 deletions example/yelp_polarity/settings/prado.yaml
@@ -23,3 +23,4 @@ classifier:
lr: 0.001
max_epochs: 3
weight_decay: 0.01
save_every_step: 1000
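
The new save_every_step key feeds the checkpoint condition added to NeuralBaseClassifier.fit above. A small hypothetical helper (not part of the PR) restating that condition; with the value 1000, a checkpoint fires every 1000 optimizer steps:

def should_checkpoint(step: int, save_every_step: int = -1) -> bool:
    # Mirrors the condition in fit(); the default -1 disables periodic saving.
    return save_every_step > -1 and step % save_every_step == 0

assert should_checkpoint(2000, save_every_step=1000)
assert not should_checkpoint(2000)  # default: only save once training finishes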