mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 12:39:21 +00:00
Test BPE keeping its options after training
This commit is contained in:
@ -581,12 +581,12 @@ mod test {
|
||||
match *py_model.model.as_ref().read().unwrap() {
|
||||
ModelWrapper::BPE(_) => (),
|
||||
_ => panic!("Expected Bert postprocessor."),
|
||||
}
|
||||
};
|
||||
|
||||
let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap();
|
||||
match *py_model.model.as_ref().read().unwrap() {
|
||||
ModelWrapper::BPE(_) => (),
|
||||
_ => panic!("Expected Bert postprocessor."),
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1045,7 +1045,9 @@ impl PyTokenizer {
|
||||
let trainer =
|
||||
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| ToPyResult(self.tokenizer.train(&trainer, files)).into())
|
||||
py.allow_threads(|| {
|
||||
ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -1173,15 +1175,15 @@ mod test {
|
||||
use super::*;
|
||||
use crate::models::PyModel;
|
||||
use crate::normalizers::{PyNormalizer, PyNormalizerTypeWrapper};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tempfile::NamedTempFile;
|
||||
use tk::normalizers::{Lowercase, NFKC};
|
||||
|
||||
#[test]
|
||||
fn serialize() {
|
||||
let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(
|
||||
let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
|
||||
tk::models::bpe::BPE::default().into(),
|
||||
)));
|
||||
))));
|
||||
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
|
||||
Arc::new(NFKC.into()),
|
||||
Arc::new(Lowercase.into()),
|
||||
|
@ -41,7 +41,7 @@ class TestUnigram:
|
||||
del os.environ["TOKENIZERS_PARALLELISM"]
|
||||
|
||||
trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
|
||||
bpe_tokenizer.train(trainer, [train_files["small"]])
|
||||
bpe_tokenizer.train([train_files["small"]], trainer=trainer)
|
||||
|
||||
def test_train_with_special_tokens(self):
|
||||
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
|
||||
@ -76,7 +76,7 @@ class TestUnigram:
|
||||
show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
|
||||
)
|
||||
|
||||
tokenizer.train(trainer, [filename])
|
||||
tokenizer.train([filename], trainer=trainer)
|
||||
|
||||
assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
|
||||
"[CLS]",
|
||||
|
@ -189,15 +189,15 @@ pub struct BPE {
|
||||
cache: Option<Cache<String, Word>>,
|
||||
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
|
||||
/// perform no merges, so the result will just be characters.
|
||||
pub(super) dropout: Option<f32>,
|
||||
pub dropout: Option<f32>,
|
||||
/// The unknown token to be used when we encounter an unknown char
|
||||
pub(super) unk_token: Option<String>,
|
||||
pub unk_token: Option<String>,
|
||||
/// An optional prefix to use on any subword that exists only behind another one
|
||||
pub(super) continuing_subword_prefix: Option<String>,
|
||||
pub continuing_subword_prefix: Option<String>,
|
||||
/// An optional suffix to characterize an end-of-word subword
|
||||
pub(super) end_of_word_suffix: Option<String>,
|
||||
pub end_of_word_suffix: Option<String>,
|
||||
/// Do multiple unk tokens get fused
|
||||
pub(super) fuse_unk: bool,
|
||||
pub fuse_unk: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BPE {
|
||||
|
@ -1056,7 +1056,7 @@ where
|
||||
}
|
||||
|
||||
/// Train a model and replace our current Model, using the given Trainer
|
||||
pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<()>
|
||||
pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
|
||||
where
|
||||
T: Trainer<Model = M> + Sync,
|
||||
{
|
||||
@ -1065,7 +1065,7 @@ where
|
||||
let special_tokens = trainer.train(words, &mut self.model)?;
|
||||
self.add_special_tokens(&special_tokens);
|
||||
|
||||
Ok(())
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,7 @@ use tokenizers::{Tokenizer, TokenizerImpl};
|
||||
#[test]
|
||||
fn train_tokenizer() {
|
||||
let vocab_size: usize = 100;
|
||||
let tokenizer = TokenizerBuilder::new()
|
||||
let mut tokenizer = TokenizerBuilder::new()
|
||||
.with_model(BPE::default())
|
||||
.with_normalizer(Some(Sequence::new(vec![
|
||||
Strip::new(true, true).into(),
|
||||
@ -97,7 +97,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
|
||||
"data/wikitext-103-raw/wiki.test.raw".into(),
|
||||
"data/wikitext-103-raw/wiki.valid.raw".into(),
|
||||
];
|
||||
tokenizer.train_and_replace(&trainer, files)?;
|
||||
tokenizer.train(&trainer, files)?;
|
||||
// END quicktour_train
|
||||
// START quicktour_reload_model
|
||||
use std::path::Path;
|
||||
@ -427,7 +427,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
|
||||
"data/wikitext-103-raw/wiki.test.raw".into(),
|
||||
"data/wikitext-103-raw/wiki.valid.raw".into(),
|
||||
];
|
||||
bert_tokenizer.train_and_replace(&trainer, files)?;
|
||||
bert_tokenizer.train(&trainer, files)?;
|
||||
|
||||
let model_files = bert_tokenizer
|
||||
.get_model()
|
||||
|
29
tokenizers/tests/training.rs
Normal file
29
tokenizers/tests/training.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use tokenizers::models::bpe::BPE;
|
||||
use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
|
||||
use tokenizers::{Model, TokenizerBuilder};
|
||||
|
||||
/// Regression test: options set on a `BPE` model through its builder must
/// still be present on the tokenizer's model after `train` has run.
///
/// Builds a tokenizer around a `BPE` configured with `unk_token = "[UNK]"`
/// and `dropout = 0.1`, trains it on `./data/small.txt` with the trainer
/// returned by the model itself, then checks both fields.
#[test]
|
||||
fn bpe_values_after_training() {
|
||||
// The wrapper types are spelled out explicitly on the builder.
let mut tokenizer = TokenizerBuilder::<
|
||||
BPE,
|
||||
NormalizerWrapper,
|
||||
PreTokenizerWrapper,
|
||||
PostProcessorWrapper,
|
||||
DecoderWrapper,
|
||||
>::default()
|
||||
.with_model(
|
||||
BPE::builder()
|
||||
.unk_token("[UNK]".to_string())
|
||||
.dropout(0.1)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap();
|
||||
// Use the trainer that the configured model hands back.
let trainer = tokenizer.get_model().get_trainer();
|
||||
tokenizer
|
||||
.train(&trainer, vec!["./data/small.txt".to_string()])
|
||||
.unwrap();
|
||||
// The values configured before training must be unchanged afterwards.
assert_eq!(tokenizer.get_model().dropout, Some(0.1));
|
||||
assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string()));
|
||||
}
|
Reference in New Issue
Block a user