Add tests for from_pretrained

This commit is contained in:
Anthony Moi
2021-08-24 11:08:07 +02:00
committed by Anthony MOI
parent ad7090a5c7
commit 35c96e5e3f
4 changed files with 93 additions and 1 deletions

View File

@ -7,6 +7,19 @@ import { PreTokenizer } from "./pre-tokenizers";
import { RawEncoding } from "./raw-encoding";
import { Trainer } from "./trainers";
export interface FromPretrainedOptions {
/**
* The revision to download
* @default "main"
*/
revision?: string;
/**
* The auth token to use to access private repositories on the Hugging Face Hub
* @default undefined
*/
authToken?: string;
}
export interface TruncationOptions {
/**
* The length of the previous sequence to be included in the overflowing sequence
@ -128,8 +141,9 @@ export class Tokenizer {
* Hugging Face Hub. Any model repo containing a `tokenizer.json`
* can be used here.
* @param identifier A model identifier on the Hub
* @param options Additional options
*/
static fromPretrained(s: string): Tokenizer;
static fromPretrained(s: string, options?: FromPretrainedOptions): Tokenizer;
/**
* Add the given tokens to the vocabulary

View File

@ -95,6 +95,33 @@ describe("Tokenizer", () => {
expect(typeof tokenizer.train).toBe("function");
});
it("can be instantiated from the hub", async () => {
let tokenizer: Tokenizer;
let encode: (
sequence: InputSequence,
pair?: InputSequence | null,
options?: EncodeOptions | null
) => Promise<RawEncoding>;
let output: RawEncoding;
tokenizer = Tokenizer.fromPretrained("bert-base-cased");
encode = promisify(tokenizer.encode.bind(tokenizer));
output = await encode("Hey there dear friend!", null, { addSpecialTokens: false });
expect(output.getTokens()).toEqual(["Hey", "there", "dear", "friend", "!"]);
tokenizer = Tokenizer.fromPretrained("anthony/tokenizers-test");
encode = promisify(tokenizer.encode.bind(tokenizer));
output = await encode("Hey there dear friend!", null, { addSpecialTokens: false });
expect(output.getTokens()).toEqual(["hey", "there", "dear", "friend", "!"]);
tokenizer = Tokenizer.fromPretrained("anthony/tokenizers-test", {
revision: "gpt-2",
});
encode = promisify(tokenizer.encode.bind(tokenizer));
output = await encode("Hey there dear friend!", null, { addSpecialTokens: false });
expect(output.getTokens()).toEqual(["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]);
});
describe("addTokens", () => {
it("accepts a list of string as new tokens when initial model is empty", () => {
const model = BPE.empty();

View File

@ -392,3 +392,17 @@ class TestTokenizer:
tokenizer = Tokenizer(BPE())
multiprocessing_with_parallelism(tokenizer, False)
multiprocessing_with_parallelism(tokenizer, True)
def test_from_pretrained(self):
tokenizer = Tokenizer.from_pretrained("bert-base-cased")
output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
assert output.tokens == ["Hey", "there", "dear", "friend", "!"]
def test_from_pretrained_revision(self):
tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test")
output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
assert output.tokens == ["hey", "there", "dear", "friend", "!"]
tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]

View File

@ -0,0 +1,37 @@
use tokenizers::{FromPretrainedParameters, Result, Tokenizer};
#[test]
fn test_from_pretrained() -> Result<()> {
let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None)?;
let encoding = tokenizer.encode("Hey there dear friend!", false)?;
assert_eq!(
encoding.get_tokens(),
&["Hey", "there", "dear", "friend", "!"]
);
Ok(())
}
#[test]
fn test_from_pretrained_revision() -> Result<()> {
let tokenizer = Tokenizer::from_pretrained("anthony/tokenizers-test", None)?;
let encoding = tokenizer.encode("Hey there dear friend!", false)?;
assert_eq!(
encoding.get_tokens(),
&["hey", "there", "dear", "friend", "!"]
);
let tokenizer = Tokenizer::from_pretrained(
"anthony/tokenizers-test",
Some(FromPretrainedParameters {
revision: "gpt-2".to_string(),
..Default::default()
}),
)?;
let encoding = tokenizer.encode("Hey there dear friend!", false)?;
assert_eq!(
encoding.get_tokens(),
&["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]
);
Ok(())
}