Skip to content

Commit

Permalink
Add tests for from_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
n1t0 committed Aug 31, 2021
1 parent ad7090a commit 35c96e5
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 1 deletion.
16 changes: 15 additions & 1 deletion bindings/node/lib/bindings/tokenizer.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ import { PreTokenizer } from "./pre-tokenizers";
import { RawEncoding } from "./raw-encoding";
import { Trainer } from "./trainers";

/**
 * Options accepted by `Tokenizer.fromPretrained`, controlling which
 * version of the hosted `tokenizer.json` is downloaded and how the
 * Hugging Face Hub is authenticated against.
 */
export interface FromPretrainedOptions {
/**
* The revision to download
* @default "main"
*/
revision?: string;
/**
* The auth token to use to access private repositories on the Hugging Face Hub
* @default undefined
*/
authToken?: string;
}

export interface TruncationOptions {
/**
* The length of the previous sequence to be included in the overflowing sequence
Expand Down Expand Up @@ -128,8 +141,9 @@ export class Tokenizer {
* Hugging Face Hub. Any model repo containing a `tokenizer.json`
* can be used here.
* @param identifier A model identifier on the Hub
* @param options Additional options
*/
static fromPretrained(s: string): Tokenizer;
static fromPretrained(s: string, options?: FromPretrainedOptions): Tokenizer;

/**
* Add the given tokens to the vocabulary
Expand Down
27 changes: 27 additions & 0 deletions bindings/node/lib/bindings/tokenizer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,33 @@ describe("Tokenizer", () => {
expect(typeof tokenizer.train).toBe("function");
});

it("can be instantiated from the hub", async () => {
  // Build a promisified `encode` bound to the given tokenizer, so each
  // downloaded tokenizer can be exercised with the same call shape.
  const encoderFor = (tok: Tokenizer) =>
    promisify(tok.encode.bind(tok)) as (
      sequence: InputSequence,
      pair?: InputSequence | null,
      options?: EncodeOptions | null
    ) => Promise<RawEncoding>;

  // A hub model at its default revision.
  const bert = Tokenizer.fromPretrained("bert-base-cased");
  const bertOutput = await encoderFor(bert)("Hey there dear friend!", null, {
    addSpecialTokens: false,
  });
  expect(bertOutput.getTokens()).toEqual(["Hey", "there", "dear", "friend", "!"]);

  // The test repo at its default revision (tokens come back lowercased).
  const testRepo = Tokenizer.fromPretrained("anthony/tokenizers-test");
  const testOutput = await encoderFor(testRepo)("Hey there dear friend!", null, {
    addSpecialTokens: false,
  });
  expect(testOutput.getTokens()).toEqual(["hey", "there", "dear", "friend", "!"]);

  // The same repo pinned to the `gpt-2` revision, which selects a
  // different tokenizer (note the `Ġ` space-marker prefixes).
  const gpt2Repo = Tokenizer.fromPretrained("anthony/tokenizers-test", {
    revision: "gpt-2",
  });
  const gpt2Output = await encoderFor(gpt2Repo)("Hey there dear friend!", null, {
    addSpecialTokens: false,
  });
  expect(gpt2Output.getTokens()).toEqual(["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]);
});

describe("addTokens", () => {
it("accepts a list of string as new tokens when initial model is empty", () => {
const model = BPE.empty();
Expand Down
14 changes: 14 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,17 @@ def test_multiprocessing_with_parallelism(self):
tokenizer = Tokenizer(BPE())
multiprocessing_with_parallelism(tokenizer, False)
multiprocessing_with_parallelism(tokenizer, True)

def test_from_pretrained(self):
    # Downloading a hub model by identifier (default revision) should yield
    # a working tokenizer; verify a known encoding end to end.
    tok = Tokenizer.from_pretrained("bert-base-cased")
    encoding = tok.encode("Hey there dear friend!", add_special_tokens=False)
    assert encoding.tokens == ["Hey", "there", "dear", "friend", "!"]

def test_from_pretrained_revision(self):
    # Default revision of the test repo: tokens come back lowercased.
    default_tok = Tokenizer.from_pretrained("anthony/tokenizers-test")
    encoding = default_tok.encode("Hey there dear friend!", add_special_tokens=False)
    assert encoding.tokens == ["hey", "there", "dear", "friend", "!"]

    # Pinning the `gpt-2` revision selects a different tokenizer for the
    # same repo (note the `Ġ` space-marker prefixes in the output).
    gpt2_tok = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
    encoding = gpt2_tok.encode("Hey there dear friend!", add_special_tokens=False)
    assert encoding.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]
37 changes: 37 additions & 0 deletions tokenizers/tests/from_pretrained.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use tokenizers::{FromPretrainedParameters, Result, Tokenizer};

#[test]
fn test_from_pretrained() -> Result<()> {
    // Downloading a hub model by identifier (default revision, hence `None`)
    // should yield a working tokenizer; verify a known encoding end to end.
    let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None)?;
    let encoding = tokenizer.encode("Hey there dear friend!", false)?;
    let expected = ["Hey", "there", "dear", "friend", "!"];
    assert_eq!(encoding.get_tokens(), &expected);
    Ok(())
}

#[test]
fn test_from_pretrained_revision() -> Result<()> {
    // Default revision of the test repo: tokens come back lowercased.
    let default_tokenizer = Tokenizer::from_pretrained("anthony/tokenizers-test", None)?;
    let encoding = default_tokenizer.encode("Hey there dear friend!", false)?;
    assert_eq!(encoding.get_tokens(), &["hey", "there", "dear", "friend", "!"]);

    // Pinning the `gpt-2` revision selects a different tokenizer for the
    // same repo (note the `Ġ` space-marker prefixes in the output).
    let params = FromPretrainedParameters {
        revision: "gpt-2".to_string(),
        ..Default::default()
    };
    let gpt2_tokenizer = Tokenizer::from_pretrained("anthony/tokenizers-test", Some(params))?;
    let encoding = gpt2_tokenizer.encode("Hey there dear friend!", false)?;
    assert_eq!(encoding.get_tokens(), &["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]);

    Ok(())
}

0 comments on commit 35c96e5

Please sign in to comment.