Add unigram bytefallback (#1217)
* current updates will go red
* cargo fmt
* npm install
* refactor train for unigram to allow bytefallbakc (breaking)
* fmt
* nits
* update
* add a proper test
* fix encode optimised fallback + add trainer arg
* fixes
* fixes
* fix tests
* add test
* fmt
* fix rust test
* update python bindings
* update
* pub is okay and needed
* more fix
* cleanup
* remove useles id
* MissingUnkId error
* nits
* fix offset
* add a test in python
* update src bindings
* remove bytefallback from trainer
* styling
* update pckg
* lint
* fmt
* stup with dev
* update code based on review
* remove unused function
* udpate python test to compare ids
* fix option bool issues
* final fix
* clippy
* fix npm isntall
* update
* update test
* more in depth testing
* Lint
* last attempt to fix node
* update node bindings
* fmt
* Update tokenizers/src/models/unigram/model.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* update based on review
* simpler test
* lint

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
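In short, byte fallback lets a Unigram model represent text it has no token for by emitting one "<0xXX>" token per UTF-8 byte instead of the unknown token. A minimal sketch of the idea; the helper below is illustrative only and not part of this PR:

    def byte_fallback_tokens(text: str) -> list[str]:
        # Each UTF-8 byte becomes a "<0xXX>" token; these byte tokens must
        # themselves be present in the vocabulary for the fallback to apply.
        return [f"<0x{b:02X}>" for b in text.encode("utf-8")]

    print(byte_fallback_tokens("🤗"))  # ['<0xF0>', '<0x9F>', '<0xA4>', '<0x97>']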
@@ -162,6 +162,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         vocab = [(piece.piece, piece.score) for piece in m.pieces]
         unk_id = m.trainer_spec.unk_id
         model_type = m.trainer_spec.model_type
+        byte_fallback = m.trainer_spec.byte_fallback
         if model_type != 1:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
@@ -170,7 +171,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         replacement = "▁"
         add_prefix_space = True

-        tokenizer = Tokenizer(Unigram(vocab, unk_id))
+        tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback))

         tokenizer.normalizer = normalizers.Sequence(
             [
@@ -242,11 +242,11 @@ class Unigram(Model):
     An implementation of the Unigram algorithm

     Args:
-        vocab (:obj:`List[Tuple[str, float]]`, `optional`):
+        vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
             A list of vocabulary items and their relative score [("am", -0.2442),...]
     """

-    def __init__(self, vocab):
+    def __init__(self, vocab, unk_id, byte_fallback):
         pass
     def get_trainer(self):
         """
@@ -804,24 +804,32 @@ impl PyWordLevel {
 /// An implementation of the Unigram algorithm
 ///
 /// Args:
-///     vocab (:obj:`List[Tuple[str, float]]`, `optional`):
+///     vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
 ///         A list of vocabulary items and their relative score [("am", -0.2442),...]
 #[pyclass(extends=PyModel, module = "tokenizers.models", name = "Unigram")]
-#[pyo3(text_signature = "(self, vocab)")]
+#[pyo3(text_signature = "(self, vocab, unk_id, byte_fallback)")]
 pub struct PyUnigram {}

 #[pymethods]
 impl PyUnigram {
     #[new]
-    fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
-        match (vocab, unk_id) {
-            (Some(vocab), unk_id) => {
-                let model = Unigram::from(vocab, unk_id).map_err(|e| {
-                    exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
-                })?;
+    fn new(
+        vocab: Option<Vec<(String, f64)>>,
+        unk_id: Option<usize>,
+        byte_fallback: Option<bool>,
+    ) -> PyResult<(Self, PyModel)> {
+        match (vocab, unk_id, byte_fallback) {
+            (Some(vocab), unk_id, byte_fallback) => {
+                let model =
+                    Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(|e| {
+                        exceptions::PyException::new_err(format!(
+                            "Error while loading Unigram: {}",
+                            e
+                        ))
+                    })?;
                 Ok((PyUnigram {}, model.into()))
             }
-            (None, None) => Ok((PyUnigram {}, Unigram::default().into())),
+            (None, None, _) => Ok((PyUnigram {}, Unigram::default().into())),
             _ => Err(exceptions::PyValueError::new_err(
                 "`vocab` and `unk_id` must be both specified",
             )),
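On the binding side, `byte_fallback` is an `Option<bool>` resolved with `unwrap_or(false)`, so existing `Unigram(vocab, unk_id)` constructions keep their behaviour. A minimal sketch of the resulting Python-side call, using an illustrative toy vocabulary (scores chosen arbitrarily):

    from tokenizers import Tokenizer
    from tokenizers.models import Unigram

    # "<0x62>" is the byte token for b"b"; it must be in the vocab for fallback to use it.
    vocab = [("<unk>", 0.0), ("a", -1.0), ("<0x62>", -2.0)]
    tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))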
@@ -5,7 +5,7 @@ import pytest

 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
-from tokenizers.models import BPE, Model, WordPiece
+from tokenizers.models import BPE, Model, WordPiece, Unigram
 from tokenizers.normalizers import Lowercase
 from tokenizers.pre_tokenizers import ByteLevel
 from tokenizers.processors import BertProcessing, RobertaProcessing
@@ -412,3 +412,29 @@ class TestTokenizer:
         tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
         output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
         assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]
+
+    def test_unigram_byte_fallback(self):
+        vocab = [
+            ("<unk>", 0.0),
+            ("A", -0.01),
+            ("sen", -0.02),
+            ("te", -0.03),
+            ("n", -0.04),
+            ("ce", -0.05),
+            ("<0xF0>", -0.06),
+            ("<0x9F>", -0.06),
+            ("<0xA4>", -0.06),
+            ("<0x97>", -0.06),
+            (" ", -0.4),
+        ]
+        tokenizer = tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+
+        output = tokenizer.encode("A sentence 🤗")
+        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 0]
+        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "🤗"]
+
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
+
+        output = tokenizer.encode("A sentence 🤗")
+        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
+        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]
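The ids in the new test line up with the vocabulary above: with `byte_fallback=False` the emoji has no token of its own and collapses to the unk id 0, while with `byte_fallback=True` it is encoded as ids 6-9, i.e. the byte tokens `<0xF0>`, `<0x9F>`, `<0xA4>`, `<0x97>` that make up the UTF-8 encoding of "🤗".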