From 35c96e5e3ff9b08d986854149b41e02e8289378e Mon Sep 17 00:00:00 2001 From: Anthony Moi Date: Tue, 24 Aug 2021 11:08:07 +0200 Subject: [PATCH] Add tests for from_pretrained --- bindings/node/lib/bindings/tokenizer.d.ts | 16 +++++++- bindings/node/lib/bindings/tokenizer.test.ts | 27 ++++++++++++++ .../python/tests/bindings/test_tokenizer.py | 14 +++++++ tokenizers/tests/from_pretrained.rs | 37 +++++++++++++++++++ 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 tokenizers/tests/from_pretrained.rs diff --git a/bindings/node/lib/bindings/tokenizer.d.ts b/bindings/node/lib/bindings/tokenizer.d.ts index 5d8dd9566..1b92af8ac 100644 --- a/bindings/node/lib/bindings/tokenizer.d.ts +++ b/bindings/node/lib/bindings/tokenizer.d.ts @@ -7,6 +7,19 @@ import { PreTokenizer } from "./pre-tokenizers"; import { RawEncoding } from "./raw-encoding"; import { Trainer } from "./trainers"; +export interface FromPretrainedOptions { + /** + * The revision to download + * @default "main" + */ + revision?: string; + /** + * The auth token to use to access private repositories on the Hugging Face Hub + * @default undefined + */ + authToken?: string; +} + export interface TruncationOptions { /** * The length of the previous sequence to be included in the overflowing sequence @@ -128,8 +141,9 @@ export class Tokenizer { * Hugging Face Hub. Any model repo containing a `tokenizer.json` * can be used here. 
* @param identifier A model identifier on the Hub + * @param options Additional options */ - static fromPretrained(s: string): Tokenizer; + static fromPretrained(s: string, options?: FromPretrainedOptions): Tokenizer; /** * Add the given tokens to the vocabulary diff --git a/bindings/node/lib/bindings/tokenizer.test.ts b/bindings/node/lib/bindings/tokenizer.test.ts index bd03df180..756e1ea53 100644 --- a/bindings/node/lib/bindings/tokenizer.test.ts +++ b/bindings/node/lib/bindings/tokenizer.test.ts @@ -95,6 +95,33 @@ describe("Tokenizer", () => { expect(typeof tokenizer.train).toBe("function"); }); + it("can be instantiated from the hub", async () => { + let tokenizer: Tokenizer; + let encode: ( + sequence: InputSequence, + pair?: InputSequence | null, + options?: EncodeOptions | null + ) => Promise<RawEncoding>; + let output: RawEncoding; + + tokenizer = Tokenizer.fromPretrained("bert-base-cased"); + encode = promisify(tokenizer.encode.bind(tokenizer)); + output = await encode("Hey there dear friend!", null, { addSpecialTokens: false }); + expect(output.getTokens()).toEqual(["Hey", "there", "dear", "friend", "!"]); + + tokenizer = Tokenizer.fromPretrained("anthony/tokenizers-test"); + encode = promisify(tokenizer.encode.bind(tokenizer)); + output = await encode("Hey there dear friend!", null, { addSpecialTokens: false }); + expect(output.getTokens()).toEqual(["hey", "there", "dear", "friend", "!"]); + + tokenizer = Tokenizer.fromPretrained("anthony/tokenizers-test", { + revision: "gpt-2", + }); + encode = promisify(tokenizer.encode.bind(tokenizer)); + output = await encode("Hey there dear friend!", null, { addSpecialTokens: false }); + expect(output.getTokens()).toEqual(["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]); + }); + describe("addTokens", () => { it("accepts a list of string as new tokens when initial model is empty", () => { const model = BPE.empty(); diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 
f99aa4092..e15d523d3 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -392,3 +392,17 @@ def test_multiprocessing_with_parallelism(self): tokenizer = Tokenizer(BPE()) multiprocessing_with_parallelism(tokenizer, False) multiprocessing_with_parallelism(tokenizer, True) + + def test_from_pretrained(self): + tokenizer = Tokenizer.from_pretrained("bert-base-cased") + output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False) + assert output.tokens == ["Hey", "there", "dear", "friend", "!"] + + def test_from_pretrained_revision(self): + tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test") + output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False) + assert output.tokens == ["hey", "there", "dear", "friend", "!"] + + tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2") + output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False) + assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"] diff --git a/tokenizers/tests/from_pretrained.rs b/tokenizers/tests/from_pretrained.rs new file mode 100644 index 000000000..ec098cf9b --- /dev/null +++ b/tokenizers/tests/from_pretrained.rs @@ -0,0 +1,37 @@ +use tokenizers::{FromPretrainedParameters, Result, Tokenizer}; + +#[test] +fn test_from_pretrained() -> Result<()> { + let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None)?; + let encoding = tokenizer.encode("Hey there dear friend!", false)?; + assert_eq!( + encoding.get_tokens(), + &["Hey", "there", "dear", "friend", "!"] + ); + Ok(()) +} + +#[test] +fn test_from_pretrained_revision() -> Result<()> { + let tokenizer = Tokenizer::from_pretrained("anthony/tokenizers-test", None)?; + let encoding = tokenizer.encode("Hey there dear friend!", false)?; + assert_eq!( + encoding.get_tokens(), + &["hey", "there", "dear", "friend", "!"] + ); + + let tokenizer = Tokenizer::from_pretrained( + 
"anthony/tokenizers-test", + Some(FromPretrainedParameters { + revision: "gpt-2".to_string(), + ..Default::default() + }), + )?; + let encoding = tokenizer.encode("Hey there dear friend!", false)?; + assert_eq!( + encoding.get_tokens(), + &["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"] + ); + + Ok(()) +}