Skip to content

Commit

Permalink
Add support for albert models
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Mar 9, 2023
1 parent b0f5bb3 commit 55b27fc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 3 deletions.
20 changes: 20 additions & 0 deletions scripts/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@
'default',
],
},
'albert': {
'albert-base-v2': [
'default',
'masked-lm',
],
'albert-large-v2': [
'default',
'masked-lm',
],
'sentence-transformers/paraphrase-albert-small-v2': [
'default',
],
'sentence-transformers/paraphrase-albert-base-v2': [
'default',
],
},
'distilbert': {
'distilbert-base-uncased': [
'default',
Expand Down Expand Up @@ -150,6 +166,10 @@
'token-classification',
'question-answering'
],
'albert': [
'default',
'masked-lm',
],
'distilbert': [
'default',
'masked-lm',
Expand Down
35 changes: 33 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class AutoModel {
switch (config.model_type) {
case 'bert':
return new BertModel(config, session);
case 'albert':
return new AlbertModel(config, session);
case 'distilbert':
return new DistilBertModel(config, session);
case 't5':
Expand Down Expand Up @@ -129,6 +131,8 @@ class AutoModelForSequenceClassification {
switch (config.model_type) {
case 'bert':
return new BertForSequenceClassification(config, session);
case 'albert':
return new AlbertForSequenceClassification(config, session);
case 'distilbert':
return new DistilBertForSequenceClassification(config, session);
case 'roberta':
Expand Down Expand Up @@ -224,6 +228,8 @@ class AutoModelForMaskedLM {
switch (config.model_type) {
case 'bert':
return new BertForMaskedLM(config, session);
case 'albert':
return new AlbertForMaskedLM(config, session);
case 'distilbert':
return new DistilBertForMaskedLM(config, session);
case 'roberta':
Expand Down Expand Up @@ -255,10 +261,10 @@ class AutoModelForQuestionAnswering {
switch (config.model_type) {
case 'bert':
return new BertForQuestionAnswering(config, session);

case 'albert':
return new AlbertForQuestionAnswering(config, session);
case 'distilbert':
return new DistilBertForQuestionAnswering(config, session);

case 'roberta':
return new RobertaForQuestionAnswering(config, session);

Expand Down Expand Up @@ -543,6 +549,31 @@ class DistilBertForMaskedLM extends DistilBertPreTrainedModel {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// DistilBert models
class AlbertPreTrainedModel extends PreTrainedModel { }
class AlbertModel extends AlbertPreTrainedModel { }
class AlbertForSequenceClassification extends AlbertPreTrainedModel {
async _call(model_inputs) {
let logits = (await super._call(model_inputs)).logits;
return new SequenceClassifierOutput(logits)
}
}
class AlbertForQuestionAnswering extends AlbertPreTrainedModel {
async _call(model_inputs) {
let outputs = await super._call(model_inputs);
return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits);
}
}
class AlbertForMaskedLM extends AlbertPreTrainedModel {
async _call(model_inputs) {
let logits = (await super._call(model_inputs)).logits;
return new MaskedLMOutput(logits)
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// T5 models
class T5PreTrainedModel extends PreTrainedModel { };
Expand Down
55 changes: 54 additions & 1 deletion src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,12 @@ class Normalizer extends Callable {
return new NormalizerSequence(config);
case 'Replace':
return new Replace(config);
case 'NFKD':
return new NFKD(config);
case 'StripAccents':
return new StripAccents(config);
case 'Lowercase':
return new Lowercase(config);
default:
throw new Error(`Unknown Normalizer type: ${config.type}`);
}
Expand All @@ -338,8 +344,35 @@ class Normalizer extends Callable {
class Replace extends Normalizer {
normalize(text) {
// TODO: this.config.pattern might not be Regex.
if (this.config.pattern.Regex) {
text = text.replace(new RegExp(this.config.pattern.Regex, 'g'), this.config.content)

text = text.replace(new RegExp(this.config.pattern.Regex, 'g'), this.config.content)
} else if (this.config.pattern.String) {
text = text.replace(this.config.pattern.String, this.config.content)

} else {
console.warn('Unknown pattern type:', this.config.pattern)
}

return text;
}
}

class NFKD extends Normalizer {
normalize(text) {
text = text.normalize('NFKD')
return text;
}
}
class StripAccents extends Normalizer {
normalize(text) {
text = text.replace(/[\u0300-\u036f]/g, '');
return text;
}
}
class Lowercase extends Normalizer {
normalize(text) {
text = text.toLowerCase();
return text;
}
}
Expand Down Expand Up @@ -693,6 +726,9 @@ class AutoTokenizer {
case 'BertTokenizer':
return new BertTokenizer(tokenizerJSON, tokenizerConfig);

case 'AlbertTokenizer':
return new AlbertTokenizer(tokenizerJSON, tokenizerConfig);

case 'GPT2Tokenizer':
return new GPT2Tokenizer(tokenizerJSON, tokenizerConfig);

Expand Down Expand Up @@ -733,6 +769,11 @@ class PreTrainedTokenizer extends Callable {

// Set mask token if present (otherwise will be undefined, which is fine)
this.mask_token = this.tokenizerConfig.mask_token;
if (typeof this.mask_token === 'object') {
// sometimes of type: 'AddedToken'
this.mask_token = this.mask_token.content
}

this.mask_token_id = this.model.tokens_to_ids[this.mask_token];

this.pad_token = this.tokenizerConfig.pad_token ?? this.tokenizerConfig.eos_token;
Expand All @@ -742,6 +783,8 @@ class PreTrainedTokenizer extends Callable {
this.sep_token_id = this.model.tokens_to_ids[this.sep_token];

this.model_max_length = this.tokenizerConfig.model_max_length;

this.remove_space = this.tokenizerConfig.remove_space;
}

static async from_pretrained(modelPath, progressCallback = null) {
Expand Down Expand Up @@ -870,6 +913,10 @@ class PreTrainedTokenizer extends Callable {
// Ignore special tokens
return x
} else {
if (this.remove_space) {
// remove_space
x = x.trim().split(/\s+/).join(' ')
}
// Actually perform encoding
if (this.normalizer !== null) {
x = this.normalizer(x);
Expand Down Expand Up @@ -969,6 +1016,12 @@ class BertTokenizer extends PreTrainedTokenizer {
return inputs;
}
}
class AlbertTokenizer extends PreTrainedTokenizer {
prepare_model_inputs(inputs) {
inputs.token_type_ids = inputs.input_ids.map(x => new Array(x.length).fill(0))
return inputs;
}
}
class DistilBertTokenizer extends PreTrainedTokenizer { }
class T5Tokenizer extends PreTrainedTokenizer { }
class GPT2Tokenizer extends PreTrainedTokenizer { }
Expand Down

0 comments on commit 55b27fc

Please sign in to comment.