
Add fastText model #38

Merged
merged 30 commits on Sep 9, 2019

Commits
2b7923e
Integrate BERT into Hedwig (#29)
achyudh Apr 14, 2019
cb14201
Resolve conflicts in the dev fork
achyudh Apr 14, 2019
8346514
Merge branch 'karkaroff-master'
achyudh Apr 14, 2019
fff8e0a
Resolve merge conflicts in README.md
achyudh Apr 14, 2019
0979f77
Add TREC relevance datasets
achyudh Apr 19, 2019
e5f2ee0
Add relevance transfer trainer and evaluator
achyudh Apr 19, 2019
57f0680
Add re-ranking module
achyudh Apr 19, 2019
7d26d71
Add ImbalancedDatasetSampler
achyudh Apr 19, 2019
eab4fc2
Add relevance transfer package
achyudh Apr 19, 2019
a08b2d1
Fix import in classification trainer
achyudh Apr 19, 2019
cb3ca31
Merge remote-tracking branch 'castorini/master'
achyudh Apr 19, 2019
0890eae
Remove unwanted args from models/bert
achyudh Apr 29, 2019
a8de77c
Merge remote-tracking branch 'castorini/master'
achyudh Apr 29, 2019
1116c64
Fix bug where model wasn't in training mode every epoch
achyudh May 2, 2019
8c36691
Merge remote-tracking branch 'castorini/master'
achyudh May 2, 2019
0f34aa0
Add Robust45 preprocessor for BERT
achyudh May 5, 2019
7bed0f1
Add support for BERT for relevance transfer
achyudh May 5, 2019
6c8c728
Add hierarchical BERT model
achyudh Jul 3, 2019
615fa27
Remove tensorboardX logging
achyudh Jul 7, 2019
b40cccb
Add hierarchical BERT for relevance transfer
achyudh Jul 7, 2019
70ec667
Merge remote-tracking branch 'castorini/master'
achyudh Jul 7, 2019
1b031a8
Add learning rate multiplier
achyudh Sep 1, 2019
a987e2c
Merge branch 'master' of github.com:castorini/hedwig
achyudh Sep 1, 2019
e81cfff
Add lr multiplier for relevance transfer
achyudh Sep 2, 2019
4758607
Add MLP model
achyudh Sep 7, 2019
289cde0
Add fastText model
achyudh Sep 8, 2019
12a09da
Add Reuters bag-of-words dataset class
achyudh Sep 8, 2019
bcf1dca
Add input dropout for MLP
achyudh Sep 8, 2019
7aeded5
Merge branch 'master' of github.com:castorini/hedwig
achyudh Sep 8, 2019
448b087
Remove duplicate README files
achyudh Sep 8, 2019
21 changes: 6 additions & 15 deletions datasets/imdb.py
@@ -6,7 +6,7 @@
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

from datasets.reuters import clean_string, split_sents
from datasets.reuters import clean_string, split_sents, process_labels, generate_ngrams


def char_quantize(string, max_length=500):
@@ -18,15 +18,6 @@ def char_quantize(string, max_length=500):
return np.concatenate((quantized_string, np.zeros((max_length - len(quantized_string), len(IMDBCharQuantized.ALPHABET)), dtype=np.float32)))


def process_labels(string):
"""
Returns the label string as a list of integers
:param string:
:return:
"""
return [float(x) for x in string]


class IMDB(TabularDataset):
NAME = 'IMDB'
NUM_CLASSES = 10
@@ -70,6 +61,11 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
sort_within_batch=True, device=device)


class IMDBHierarchical(IMDB):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)


class IMDBCharQuantized(IMDB):
ALPHABET = dict(map(lambda t: (t[1], t[0]), enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"""))))
TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize)
@@ -85,8 +81,3 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
"""
train, val, test = cls.splits(path)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)


class IMDBHierarchical(IMDB):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
22 changes: 15 additions & 7 deletions datasets/reuters.py
@@ -27,6 +27,12 @@ def split_sents(string):
return string.strip().split('.')


def generate_ngrams(tokens, n=2):
n_grams = zip(*[tokens[i:] for i in range(n)])
tokens.extend(['-'.join(x) for x in n_grams])
return tokens


def load_json(string):
split_val = json.loads(string)
return np.asarray(split_val, dtype=np.float32)
@@ -44,8 +50,6 @@ def char_quantize(string, max_length=1000):
def process_labels(string):
"""
Returns the label string as a list of integers
:param string:
:return:
"""
return [float(x) for x in string]

@@ -93,6 +97,15 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
sort_within_batch=True, device=device)


class ReutersBOW(Reuters):
TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, preprocessing=generate_ngrams, include_lengths=True)


class ReutersHierarchical(Reuters):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)


class ReutersCharQuantized(Reuters):
ALPHABET = dict(map(lambda t: (t[1], t[0]), enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"""))))
TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize)
@@ -138,8 +151,3 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
"""
train, val, test = cls.splits(path)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)


class ReutersHierarchical(Reuters):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
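
A side note on the new ReutersBOW field: its preprocessing hook is the generate_ngrams helper added above, which appends hyphen-joined bigrams to the unigram token list, mirroring the bag-of-n-grams input that fastText expects. A minimal standalone sketch of the behaviour (toy tokens; the function body is copied from the diff purely for illustration):

# Standalone copy of datasets/reuters.generate_ngrams, run on a toy token list.
def generate_ngrams(tokens, n=2):
    n_grams = zip(*[tokens[i:] for i in range(n)])
    tokens.extend(['-'.join(x) for x in n_grams])
    return tokens

tokens = ['oil', 'prices', 'rose', 'sharply']
print(generate_ngrams(tokens))
# ['oil', 'prices', 'rose', 'sharply', 'oil-prices', 'prices-rose', 'rose-sharply']

Note that the helper extends its input list in place, so the original unigrams are kept and the bigrams are appended after them.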
Empty file added models/fasttext/__init__.py
150 changes: 150 additions & 0 deletions models/fasttext/__main__.py
@@ -0,0 +1,150 @@
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.onnx

from common.evaluate import EvaluatorFactory
from common.train import TrainerFactory
from datasets.aapd import AAPD
from datasets.imdb import IMDB
from datasets.reuters import ReutersBOW
from datasets.yelp2014 import Yelp2014
from models.fasttext.args import get_args
from models.fasttext.model import FastText


class UnknownWordVecCache(object):
"""
Caches the first randomly generated word vector for a given size so that it is reused.
"""
cache = {}

@classmethod
def unk(cls, tensor):
size_tup = tuple(tensor.size())
if size_tup not in cls.cache:
cls.cache[size_tup] = torch.Tensor(tensor.size())
cls.cache[size_tup].uniform_(-0.25, 0.25)
return cls.cache[size_tup]


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, is_multilabel):
saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
if hasattr(saved_model_evaluator, 'is_multilabel'):
saved_model_evaluator.is_multilabel = is_multilabel

scores, metric_names = saved_model_evaluator.get_scores()
print('Evaluation metrics for', split_name)
print(metric_names)
print(scores)


if __name__ == '__main__':
# Set default configuration in args.py
args = get_args()

# Set random seed for reproducibility
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True

if not args.cuda:
args.gpu = -1

if torch.cuda.is_available() and args.cuda:
print('Note: You are using GPU for training')
torch.cuda.set_device(args.gpu)
torch.cuda.manual_seed(args.seed)
args.gpu = torch.device('cuda:%d' % args.gpu)

if torch.cuda.is_available() and not args.cuda:
print('Warning: Using CPU for training')

dataset_map = {
'Reuters': ReutersBOW,
'AAPD': AAPD,
'IMDB': IMDB,
'Yelp2014': Yelp2014
}

if args.dataset not in dataset_map:
raise ValueError('Unrecognized dataset')
else:
dataset_class = dataset_map[args.dataset]
train_iter, dev_iter, test_iter = dataset_map[args.dataset].iters(args.data_dir, args.word_vectors_file,
args.word_vectors_dir,
batch_size=args.batch_size, device=args.gpu,
unk_init=UnknownWordVecCache.unk)

config = deepcopy(args)
config.dataset = train_iter.dataset
config.target_class = train_iter.dataset.NUM_CLASSES
config.words_num = len(train_iter.dataset.TEXT_FIELD.vocab)

print('Dataset:', args.dataset)
print('No. of target classes:', train_iter.dataset.NUM_CLASSES)
print('No. of train instances', len(train_iter.dataset))
print('No. of dev instances', len(dev_iter.dataset))
print('No. of test instances', len(test_iter.dataset))

if args.resume_snapshot:
if args.cuda:
model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
else:
model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage)
else:
model = FastText(config)
if args.cuda:
model.cuda()

if not args.trained_model:
save_path = os.path.join(args.save_path, dataset_map[args.dataset].NAME)
os.makedirs(save_path, exist_ok=True)

parameter = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)

train_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, dev_iter, args.batch_size, args.gpu)

if hasattr(train_evaluator, 'is_multilabel'):
train_evaluator.is_multilabel = dataset_class.IS_MULTILABEL
if hasattr(test_evaluator, 'is_multilabel'):
test_evaluator.is_multilabel = dataset_class.IS_MULTILABEL
if hasattr(dev_evaluator, 'is_multilabel'):
dev_evaluator.is_multilabel = dataset_class.IS_MULTILABEL

trainer_config = {
'optimizer': optimizer,
'batch_size': args.batch_size,
'log_interval': args.log_every,
'patience': args.patience,
'model_outfile': args.save_path,
'is_multilabel': dataset_class.IS_MULTILABEL
}

trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

if not args.trained_model:
trainer.train(args.epochs)
else:
if args.cuda:
model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu))
else:
model = torch.load(args.trained_model, map_location=lambda storage, location: storage)

# Calculate dev and test metrics
if hasattr(trainer, 'snapshot_path'):
model = torch.load(trainer.snapshot_path)

evaluate_dataset('dev', dataset_map[args.dataset], model, None, dev_iter, args.batch_size,
is_multilabel=dataset_class.IS_MULTILABEL,
device=args.gpu)
evaluate_dataset('test', dataset_map[args.dataset], model, None, test_iter, args.batch_size,
is_multilabel=dataset_class.IS_MULTILABEL,
device=args.gpu)
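
One behaviour of the driver worth calling out: UnknownWordVecCache hands every out-of-vocabulary token of a given embedding size the same uniformly initialised vector, because the first vector generated per size is cached and reused. A small self-contained demo of that caching (assumes only torch; the class body mirrors the one defined above):

import torch

# Mirrors the UnknownWordVecCache in models/fasttext/__main__.py:
# one random vector is created per tensor size and reused on every later call.
class UnknownWordVecCache(object):
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            cls.cache[size_tup] = torch.Tensor(tensor.size())
            cls.cache[size_tup].uniform_(-0.25, 0.25)
        return cls.cache[size_tup]

a = UnknownWordVecCache.unk(torch.zeros(300))
b = UnknownWordVecCache.unk(torch.zeros(300))
print(torch.equal(a, b))  # True: all unknown words of this size share the cached vector

This keeps the OOV embedding fixed across lookups instead of sampling a fresh random vector for each unknown token.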
23 changes: 23 additions & 0 deletions models/fasttext/args.py
@@ -0,0 +1,23 @@
import os

import models.args


def get_args():
parser = models.args.get_args()

parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
parser.add_argument('--mode', type=str, default='rand', choices=['rand', 'static', 'non-static'])
parser.add_argument('--words-dim', type=int, default=300)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--epoch-decay', type=int, default=15)
parser.add_argument('--weight-decay', type=float, default=0)

parser.add_argument('--word-vectors-dir', default=os.path.join(os.pardir, 'hedwig-data', 'embeddings', 'word2vec'))
parser.add_argument('--word-vectors-file', default='GoogleNews-vectors-negative300.txt')
parser.add_argument('--save-path', type=str, default=os.path.join('model_checkpoints', 'kim_cnn'))
parser.add_argument('--resume-snapshot', type=str)
parser.add_argument('--trained-model', type=str)

args = parser.parse_args()
return args
44 changes: 44 additions & 0 deletions models/fasttext/model.py
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn

import torch.nn.functional as F


class FastText(nn.Module):

def __init__(self, config):
super().__init__()
dataset = config.dataset
target_class = config.target_class
words_num = config.words_num
words_dim = config.words_dim
self.mode = config.mode

if config.mode == 'rand':
rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
elif config.mode == 'static':
self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
elif config.mode == 'non-static':
self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
else:
print("Unsupported Mode")
exit()

self.dropout = nn.Dropout(config.dropout)
self.fc1 = nn.Linear(words_dim, target_class)

def forward(self, x, **kwargs):
if self.mode == 'rand':
x = self.embed(x) # (batch, sent_len, embed_dim)
elif self.mode == 'static':
x = self.static_embed(x) # (batch, sent_len, embed_dim)
elif self.mode == 'non-static':
x = self.non_static_embed(x) # (batch, sent_len, embed_dim)

x = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1) # (batch, embed_dim)

logit = self.fc1(x) # (batch, target_size)
return logit
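
To make the pooling step in forward concrete, a small shape walk-through with dummy tensors (all sizes here are hypothetical): avg_pool2d with a kernel of (sent_len, 1) averages the word embeddings over the sentence, and the single linear layer maps that average straight to class logits, which is the defining simplification of fastText.

import torch
import torch.nn.functional as F

batch, sent_len, embed_dim, target_class = 4, 12, 300, 90   # hypothetical sizes

x = torch.randn(batch, sent_len, embed_dim)                  # embedded input
pooled = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1)         # average over all tokens
print(pooled.shape)                                          # torch.Size([4, 300])

fc = torch.nn.Linear(embed_dim, target_class)
print(fc(pooled).shape)                                      # torch.Size([4, 90])

For this 3-D input the pooling call is equivalent to x.mean(dim=1); only the averaged embedding reaches the classifier.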


48 changes: 0 additions & 48 deletions models/mlp/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion models/mlp/args.py
@@ -8,7 +8,7 @@ def get_args():

parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
parser.add_argument('--embed-dim', type=int, default=300)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--dropout', type=float, default=0)
parser.add_argument('--epoch-decay', type=int, default=15)
parser.add_argument('--weight-decay', type=float, default=0)

4 changes: 2 additions & 2 deletions models/mlp/model.py
@@ -8,11 +8,11 @@ def __init__(self, config):
super().__init__()
dataset = config.dataset
target_class = config.target_class
# self.dropout = nn.Dropout(config.dropout)
self.dropout = nn.Dropout(config.dropout)
self.fc1 = nn.Linear(dataset.VOCAB_SIZE, target_class)

def forward(self, x, **kwargs):
x = torch.squeeze(x) # (batch, vocab_size)
# x = self.dropout(x)
x = self.dropout(x)
logit = self.fc1(x) # (batch, target_size)
return logit
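
On the MLP change: the dropout layer is re-enabled in the model while the default --dropout is lowered to 0, so default behaviour is unchanged (nn.Dropout with p=0 is a no-op) but input dropout can now be switched on from the command line. A tiny sanity check of that, assuming nothing beyond torch:

import torch
import torch.nn as nn

x = torch.randn(2, 5)

no_drop = nn.Dropout(p=0)
no_drop.train()
print(torch.equal(no_drop(x), x))   # True: p=0 passes inputs through unchanged

drop = nn.Dropout(p=0.5)
drop.train()
print(drop(x))                      # roughly half the entries zeroed, survivors scaled by 1/(1-p)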