Python - Fix cases where str expected instead of AddedToken
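The high-level Python tokenizers accept special tokens either as plain strings or as AddedToken objects, but several of the calls they make underneath (WordPiece.from_files, token_to_id, add_special_tokens, BertProcessing) expect plain strings. The Rust binding therefore gains __str__/__repr__ for AddedToken, and the Python implementations wrap the affected arguments in str(...). A minimal sketch of the usage this is meant to support, assuming the bindings of this commit's vintage (the vocab path and the AddedToken keyword arguments are illustrative):

from tokenizers import BertWordPieceTokenizer, AddedToken

# An AddedToken carries matching options a plain str cannot, e.g. a [MASK]
# that also swallows the whitespace on its left; str(...) is applied
# internally wherever a plain string is needed.
tokenizer = BertWordPieceTokenizer(
    "vocab.txt",
    mask_token=AddedToken("[MASK]", lstrip=True),
)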
@@ -3,6 +3,7 @@ extern crate tokenizers as tk;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::PyObjectProtocol;
 
 use super::decoders::Decoder;
 use super::encoding::Encoding;
@@ -45,6 +46,19 @@ impl AddedToken {
         Ok(())
     }
 }
+#[pyproto]
+impl PyObjectProtocol for AddedToken {
+    fn __str__(&'p self) -> PyResult<&'p str> {
+        Ok(&self.token.content)
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!(
+            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={})",
+            self.token.content, self.token.rstrip, self.token.lstrip, self.token.single_word
+        ))
+    }
+}
 
 #[pyclass(dict)]
 pub struct Tokenizer {
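The #[pyproto] impl above is what makes the str(...) conversions in the Python layer meaningful: __str__ returns the raw token content and __repr__ is built with Rust's Display formatting. Roughly, on the Python side (constructor keywords inferred from the repr fields, values illustrative):

from tokenizers import AddedToken

mask = AddedToken("[MASK]", lstrip=True, rstrip=False, single_word=False)
print(str(mask))   # [MASK]  -- __str__ hands back self.token.content
print(repr(mask))  # AddedToken("[MASK]", rstrip=false, lstrip=true, single_word=false)
                   # (booleans rendered by Rust's Display, hence lowercase)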
@@ -28,21 +28,21 @@ class BertWordPieceTokenizer(BaseTokenizer):
     ):
 
         if vocab_file is not None:
-            tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
+            tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=str(unk_token)))
         else:
             tokenizer = Tokenizer(WordPiece.empty())
 
         # Let the tokenizer know about special tokens if they are part of the vocab
-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
-        if tokenizer.token_to_id(sep_token) is not None:
-            tokenizer.add_special_tokens([sep_token])
-        if tokenizer.token_to_id(cls_token) is not None:
-            tokenizer.add_special_tokens([cls_token])
-        if tokenizer.token_to_id(pad_token) is not None:
-            tokenizer.add_special_tokens([pad_token])
-        if tokenizer.token_to_id(mask_token) is not None:
-            tokenizer.add_special_tokens([mask_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
+        if tokenizer.token_to_id(str(sep_token)) is not None:
+            tokenizer.add_special_tokens([str(sep_token)])
+        if tokenizer.token_to_id(str(cls_token)) is not None:
+            tokenizer.add_special_tokens([str(cls_token)])
+        if tokenizer.token_to_id(str(pad_token)) is not None:
+            tokenizer.add_special_tokens([str(pad_token)])
+        if tokenizer.token_to_id(str(mask_token)) is not None:
+            tokenizer.add_special_tokens([str(mask_token)])
 
         tokenizer.normalizer = BertNormalizer(
             clean_text=clean_text,
@@ -53,15 +53,15 @@ class BertWordPieceTokenizer(BaseTokenizer):
         tokenizer.pre_tokenizer = BertPreTokenizer()
 
         if add_special_tokens and vocab_file is not None:
-            sep_token_id = tokenizer.token_to_id(sep_token)
+            sep_token_id = tokenizer.token_to_id(str(sep_token))
             if sep_token_id is None:
                 raise TypeError("sep_token not found in the vocabulary")
-            cls_token_id = tokenizer.token_to_id(cls_token)
+            cls_token_id = tokenizer.token_to_id(str(cls_token))
             if cls_token_id is None:
                 raise TypeError("cls_token not found in the vocabulary")
 
             tokenizer.post_processor = BertProcessing(
-                (sep_token, sep_token_id), (cls_token, cls_token_id)
+                (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
             )
         tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
 
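The edits in this class all follow one pattern: normalise the possibly-AddedToken argument to str before the vocabulary lookup, then register it. Reduced to the low-level calls visible in the hunks (vocab path illustrative):

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import WordPiece

sep_token = AddedToken("[SEP]")
tokenizer = Tokenizer(WordPiece.from_files("vocab.txt", unk_token="[UNK]"))
if tokenizer.token_to_id(str(sep_token)) is not None:   # the id lookup wants a str
    tokenizer.add_special_tokens([str(sep_token)])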
@@ -28,15 +28,15 @@ class CharBPETokenizer(BaseTokenizer):
                     vocab_file,
                     merges_file,
                     dropout=dropout,
-                    unk_token=unk_token,
+                    unk_token=str(unk_token),
                     end_of_word_suffix=suffix,
                 )
             )
         else:
             tokenizer = Tokenizer(BPE.empty())
 
-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
 
         # Check for Unicode normalization first (before everything else)
         normalizers = []
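CharBPETokenizer gets the same treatment for its unknown token, so unk_token may be an AddedToken as well; a short sketch under the same assumptions (file paths illustrative):

from tokenizers import CharBPETokenizer, AddedToken

# BPE.from_files and the vocab lookup now receive str(unk_token).
tokenizer = CharBPETokenizer("vocab.json", "merges.txt", unk_token=AddedToken("<unk>"))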
@@ -28,8 +28,8 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         else:
             tokenizer = Tokenizer(BPE.empty())
 
-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
 
         tokenizer.normalizer = NFKC()
         tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(