Generate pyi, fix tests and clippy warnings

Authored by Anthony MOI on 2020-11-19 17:57:58 -05:00, committed by Anthony MOI
parent 5059be1a8d
commit 387b8a1033
7 changed files with 56 additions and 74 deletions

View File

@@ -83,6 +83,27 @@ class UnigramTrainer(Trainer):
def __init__(self, vocab_size=8000, show_progress=True, special_tokens=[]):
pass
class WordLevelTrainer(Trainer):
"""
Capable of training a WordLevel model
Args:
vocab_size: unsigned int:
The size of the final vocabulary, including all tokens.
min_frequency: unsigned int:
The minimum frequency a token should have in order to be included in the vocabulary.
show_progress: boolean:
Whether to show progress bars while training.
special_tokens: List[Union[str, AddedToken]]:
A list of special tokens the model should know of.
Returns:
Trainer
"""
class WordPieceTrainer(Trainer):
"""
Capable of training a WordPiece model
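For readers of the new stub, here is a minimal usage sketch. It assumes a WordLevel model that accepts unk_token at construction and local wikitext files; neither appears in this diff, so treat both as assumptions.

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

# Assumed setup: a WordLevel model with an unknown token and whitespace pre-tokenization
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Parameters follow the stub above; special_tokens should include the unknown token
trainer = WordLevelTrainer(vocab_size=8000, special_tokens=["[UNK]"])

# New argument order introduced in this commit: files first, then trainer
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(files, trainer)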

View File

@@ -2,8 +2,13 @@ from ..utils import data_dir, doc_wiki_tokenizer, doc_pipeline_bert_tokenizer
from tokenizers import Tokenizer
disable_printing = True
original_print = print
def print(*args, **kwargs):
pass
if not disable_printing:
original_print(*args, **kwargs)
class TestPipeline:
@@ -103,7 +108,7 @@ class TestPipeline:
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
bert_tokenizer = Tokenizer(WordPiece())
bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# END bert_setup_tokenizer
# START bert_setup_normalizer
from tokenizers import normalizers
@@ -135,10 +140,7 @@ class TestPipeline:
vocab_size=30522, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
bert_tokenizer.train(trainer, files)
model_files = bert_tokenizer.model.save("data", "bert-wiki")
bert_tokenizer.model = WordPiece.from_file(*model_files, unk_token="[UNK]")
bert_tokenizer.train(files, trainer)
bert_tokenizer.save("data/bert-wiki.json")
# END bert_train_tokenizer
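With the unknown token now set when the WordPiece model is constructed, the save-and-reload step removed above is no longer needed; the serialized JSON file is all that is required later. A short sketch, assuming the training above has run; the input sentence is made up:

from tokenizers import Tokenizer

# Reload the tokenizer serialized by the test above and encode a sample sentence
bert_tokenizer = Tokenizer.from_file("data/bert-wiki.json")
encoding = bert_tokenizer.encode("Welcome to the Tokenizers library.")
print(encoding.tokens)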
@@ -173,6 +175,7 @@ if __name__ == "__main__":
from zipfile import ZipFile
import os
disable_printing = False
if not os.path.isdir("data/wikitext-103-raw"):
print("Downloading wikitext-103...")
wiki_text, _ = request.urlretrieve(

View File

@@ -4,6 +4,14 @@ from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
disable_printing = True
original_print = print
def print(*args, **kwargs):
if not disable_printing:
original_print(*args, **kwargs)
class TestQuicktour:
# This method contains everything we don't want to run
@@ -13,12 +21,8 @@ class TestQuicktour:
# START train
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(trainer, files)
tokenizer.train(files, trainer)
# END train
# START reload_model
files = tokenizer.model.save("data", "wiki")
tokenizer.model = BPE.from_file(*files, unk_token="[UNK]")
# END reload_model
# START save
tokenizer.save("data/tokenizer-wiki.json")
# END save
@@ -29,7 +33,7 @@ class TestQuicktour:
from tokenizers import Tokenizer
from tokenizers.models import BPE
tokenizer = Tokenizer(BPE())
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# END init_tokenizer
# START init_trainer
from tokenizers.trainers import BpeTrainer
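Put together, the updated quicktour flow is: construct BPE with its unknown token, build a trainer, train with the files first, then save. A condensed sketch of that flow; the pre-tokenizer and the trainer's special-token list mirror the quicktour but are not shown in these hunks, so treat them as assumptions:

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))      # unknown token set at construction
tokenizer.pre_tokenizer = Whitespace()             # assumed, as in the quicktour
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(files, trainer)                    # new order: files, then trainer
tokenizer.save("data/tokenizer-wiki.json")         # no separate model reload step required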
@@ -181,6 +185,7 @@ if __name__ == "__main__":
from zipfile import ZipFile
import os
disable_printing = False
if not os.path.isdir("data/wikitext-103-raw"):
print("Downloading wikitext-103...")
wiki_text, _ = request.urlretrieve(

View File

@@ -202,35 +202,7 @@ to use:
:end-before: END train
:dedent: 8
This should only take a few seconds to train our tokenizer on the full wikitext dataset! Once this
is done, we need to save the model and reinstantiate it with the unknown token, or this token won't
be used. This will be simplified in a further release, to let you set the :entity:`unk_token` when
first instantiating the model.
.. only:: python
.. literalinclude:: ../../bindings/python/tests/documentation/test_quicktour.py
:language: python
:start-after: START reload_model
:end-before: END reload_model
:dedent: 8
.. only:: rust
.. literalinclude:: ../../tokenizers/tests/documentation.rs
:language: rust
:start-after: START quicktour_reload_model
:end-before: END quicktour_reload_model
:dedent: 4
.. only:: node
.. literalinclude:: ../../bindings/node/examples/documentation/quicktour.test.ts
:language: javascript
:start-after: START reload_model
:end-before: END reload_model
:dedent: 8
This should only take a few seconds to train our tokenizer on the full wikitext dataset!
To save the tokenizer in one file that contains all its configuration and vocabulary, just use the
:entity:`Tokenizer.save` method:
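In Python the save call (and the later reload, which is not part of this diff but the likely companion call) comes down to:

from tokenizers import Tokenizer

# `tokenizer` is assumed to be the trained quicktour tokenizer
tokenizer.save("data/tokenizer-wiki.json")
tokenizer = Tokenizer.from_file("data/tokenizer-wiki.json")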

View File

@@ -84,7 +84,7 @@ fn main() -> Result<()> {
])
.build();
let tokenizer = TokenizerBuilder::new()
let mut tokenizer = TokenizerBuilder::new()
.with_model(BPE::default())
.with_normalizer(Some(Sequence::new(vec![
Strip::new(true, true).into(),

View File

@@ -585,7 +585,7 @@ where
/// Get the vocabulary
pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
let mut final_vocab = self.model.get_vocab().clone();
let mut final_vocab = self.model.get_vocab();
if with_added_tokens {
let added_vocab = self.added_vocabulary.get_vocab();
@@ -763,7 +763,6 @@ where
.filter(|token| {
!skip_special_tokens || !self.added_vocabulary.is_special_token(token)
})
.map(|t| t.to_owned())
})
.collect::<Vec<_>>();
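These two hunks are clippy-driven cleanups rather than behaviour changes: get_vocab drops a now-unneeded clone and decode drops a redundant to_owned. The user-facing behaviour, reachable from the Python bindings roughly as sketched below, stays the same (the ids are made up and `tokenizer` is assumed to be a trained Tokenizer):

# Rough Python-level equivalents of the code paths touched above
vocab = tokenizer.get_vocab(with_added_tokens=True)           # token -> id map, including added tokens
text = tokenizer.decode([1, 2, 3], skip_special_tokens=True)  # special tokens filtered from the output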

View File

@@ -70,7 +70,12 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
PreTokenizerWrapper,
PostProcessorWrapper,
DecoderWrapper,
> = TokenizerImpl::new(BPE::default());
> = TokenizerImpl::new(
BPE::builder()
.unk_token("[UNK]".to_string())
.build()
.unwrap(),
);
// END quicktour_init_tokenizer
// START quicktour_init_trainer
use tokenizers::models::bpe::BpeTrainer;
@@ -99,22 +104,6 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
];
tokenizer.train(&trainer, files)?;
// END quicktour_train
// START quicktour_reload_model
use std::path::Path;
use tokenizers::Model;
let saved_files = tokenizer
.get_model()
.save(&Path::new("data"), Some("wiki"))?;
tokenizer.with_model(
BPE::from_file(
saved_files[0].to_str().unwrap(),
&saved_files[1].to_str().unwrap(),
)
.unk_token("[UNK]".to_string())
.build()?,
);
// END quicktour_reload_model
// START quicktour_save
tokenizer.save("data/tokenizer-wiki.json", false)?;
// END quicktour_save
@@ -375,7 +364,12 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::Tokenizer;
let mut bert_tokenizer = Tokenizer::new(WordPiece::default());
let mut bert_tokenizer = Tokenizer::new(
WordPiece::builder()
.unk_token("[UNK]".to_string())
.build()
.unwrap(),
);
// END bert_setup_tokenizer
// START bert_setup_normalizer
use tokenizers::normalizers::utils::Sequence as NormalizerSequence;
@@ -407,9 +401,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
);
// END bert_setup_processor
// START bert_train_tokenizer
use std::path::Path;
use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};
use tokenizers::Model;
let trainer: TrainerWrapper = WordPieceTrainer::builder()
.vocab_size(30_522)
@@ -429,16 +421,6 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
];
bert_tokenizer.train(&trainer, files)?;
let model_files = bert_tokenizer
.get_model()
.save(&Path::new("data"), Some("bert-wiki"))?;
bert_tokenizer.with_model(
WordPiece::from_file(model_files[0].to_str().unwrap())
.unk_token("[UNK]".to_string())
.build()
.unwrap(),
);
bert_tokenizer.save("data/bert-wiki.json", false)?;
// END bert_train_tokenizer
Ok(())