Add unigram bytefallback (#1217)
* current updates will go red
* cargo fmt
* npm install
* refactor train for unigram to allow byte fallback (breaking)
* fmt
* nits
* update
* add a proper test
* fix encode optimised fallback + add trainer arg
* fixes
* fixes
* fix tests
* add test
* fmt
* fix rust test
* update python bindings
* update
* pub is okay and needed
* more fix
* cleanup
* remove useless id
* MissingUnkId error
* nits
* fix offset
* add a test in python
* update src bindings
* remove bytefallback from trainer
* styling
* update package
* lint
* fmt
* setup with dev
* update code based on review
* remove unused function
* update python test to compare ids
* fix option bool issues
* final fix
* clippy
* fix npm install
* update
* update test
* more in-depth testing
* Lint
* last attempt to fix node
* update node bindings
* fmt
* Update tokenizers/src/models/unigram/model.rs
  Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* update based on review
* simpler test
* lint

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
bindings/node/lib/bindings/models.d.ts (vendored; 5 changes)
@@ -170,6 +170,11 @@ export interface UnigramOptions {
    * @default undefined
    */
   unkId?: number;
+  /**
+   * Whether or not bytefallback support should be enabled.
+   * @default false
+   */
+  byte_fallback?: boolean;
 }

 export namespace Unigram {
@@ -124,6 +124,7 @@ describe("Unigram", () => {
       ],
       {
         unkId: 0,
+        byte_fallback: false,
       }
     );
     expect(unigram.constructor.name).toEqual("Model");
@@ -191,6 +191,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult<JsModel> {
 ///   unkToken?: string,
 ///   continuingSubwordPrefix?: string,
 ///   endOfWordSuffix?: string
+///   byteFallback?: bool
 /// }, callback)
 fn bpe_from_file(mut cx: FunctionContext) -> JsResult<JsUndefined> {
     let (options, callback) = match cx.extract_opt::<BpeOptions>(2) {

@@ -369,16 +370,16 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
 #[serde(rename_all = "camelCase")]
 struct UnigramOptions {
     unk_id: Option<usize>,
+    byte_fallback: Option<bool>,
 }

 /// unigram_init(vocab: [string, number][], options?: {
 ///   unkId?: number
 /// })
 fn unigram_init(mut cx: FunctionContext) -> JsResult<JsModel> {
     let vocab = cx.extract::<Vec<(String, f64)>>(0)?;
     let options = cx.extract_opt::<UnigramOptions>(1)?.unwrap_or_default();
-    let unigram = tk::models::unigram::Unigram::from(vocab, options.unk_id)
+    let byte_fallback = options.byte_fallback.unwrap_or(false);
+    let unigram = tk::models::unigram::Unigram::from(vocab, options.unk_id, byte_fallback)
         .map_err(|e| Error(e.to_string()))?;

     let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
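Since `UnigramOptions` appears to be extracted via serde and carries `rename_all = "camelCase"`, a JS caller would be expected to pass `byteFallback` (note the `.d.ts` above declares the option as `byte_fallback`, which does not match that rename). A minimal standalone sketch of the mapping, assuming serde (with the derive feature) and serde_json as dependencies; this `UnigramOptions` is a local stand-in, not the binding's actual struct:

use serde::Deserialize;

// Local stand-in for the binding's options struct.
#[derive(Deserialize, Default, Debug)]
#[serde(rename_all = "camelCase")]
struct UnigramOptions {
    unk_id: Option<usize>,
    byte_fallback: Option<bool>,
}

fn main() {
    // The camelCase rename means the JSON/JS side spells it `byteFallback`.
    let opts: UnigramOptions =
        serde_json::from_str(r#"{"unkId": 0, "byteFallback": true}"#).unwrap();
    assert_eq!(opts.unk_id, Some(0));
    assert_eq!(opts.byte_fallback, Some(true));
    // When the option is omitted, the binding falls back to `false`.
    let defaults = UnigramOptions::default();
    assert!(!defaults.byte_fallback.unwrap_or(false));
}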
bindings/node/package-lock.json (generated; 664 changes; diff suppressed because it is too large)
@@ -16,7 +16,9 @@
   "license": "Apache-2.0",
   "dependencies": {
     "@types/node": "^13.13.52",
-    "node-pre-gyp": "^0.14.0"
+    "native": "^0.3.3",
+    "node-pre-gyp": "^0.14.0",
+    "package.json": "^2.0.1"
   },
   "devDependencies": {
     "@types/jest": "^26.0.24",
@@ -162,6 +162,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         vocab = [(piece.piece, piece.score) for piece in m.pieces]
         unk_id = m.trainer_spec.unk_id
         model_type = m.trainer_spec.model_type
+        byte_fallback = m.trainer_spec.byte_fallback
         if model_type != 1:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"

@@ -170,7 +171,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         replacement = "▁"
         add_prefix_space = True

-        tokenizer = Tokenizer(Unigram(vocab, unk_id))
+        tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback))

         tokenizer.normalizer = normalizers.Sequence(
             [
@@ -242,11 +242,11 @@ class Unigram(Model):
     An implementation of the Unigram algorithm

     Args:
-        vocab (:obj:`List[Tuple[str, float]]`, `optional`):
+        vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
             A list of vocabulary items and their relative score [("am", -0.2442),...]
     """

-    def __init__(self, vocab):
+    def __init__(self, vocab, unk_id, byte_fallback):
         pass

     def get_trainer(self):
         """
@@ -804,24 +804,32 @@ impl PyWordLevel {
 /// An implementation of the Unigram algorithm
 ///
 /// Args:
-///     vocab (:obj:`List[Tuple[str, float]]`, `optional`):
+///     vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`):
 ///         A list of vocabulary items and their relative score [("am", -0.2442),...]
 #[pyclass(extends=PyModel, module = "tokenizers.models", name = "Unigram")]
-#[pyo3(text_signature = "(self, vocab)")]
+#[pyo3(text_signature = "(self, vocab, unk_id, byte_fallback)")]
 pub struct PyUnigram {}

 #[pymethods]
 impl PyUnigram {
     #[new]
-    fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
-        match (vocab, unk_id) {
-            (Some(vocab), unk_id) => {
-                let model = Unigram::from(vocab, unk_id).map_err(|e| {
-                    exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
-                })?;
+    fn new(
+        vocab: Option<Vec<(String, f64)>>,
+        unk_id: Option<usize>,
+        byte_fallback: Option<bool>,
+    ) -> PyResult<(Self, PyModel)> {
+        match (vocab, unk_id, byte_fallback) {
+            (Some(vocab), unk_id, byte_fallback) => {
+                let model =
+                    Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(|e| {
+                        exceptions::PyException::new_err(format!(
+                            "Error while loading Unigram: {}",
+                            e
+                        ))
+                    })?;
                 Ok((PyUnigram {}, model.into()))
             }
-            (None, None) => Ok((PyUnigram {}, Unigram::default().into())),
+            (None, None, _) => Ok((PyUnigram {}, Unigram::default().into())),
             _ => Err(exceptions::PyValueError::new_err(
                 "`vocab` and `unk_id` must be both specified",
             )),
@@ -5,7 +5,7 @@ import pytest

 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
-from tokenizers.models import BPE, Model, WordPiece
+from tokenizers.models import BPE, Model, WordPiece, Unigram
 from tokenizers.normalizers import Lowercase
 from tokenizers.pre_tokenizers import ByteLevel
 from tokenizers.processors import BertProcessing, RobertaProcessing
@@ -412,3 +412,29 @@ class TestTokenizer:
         tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
         output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
         assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]
+
+    def test_unigram_byte_fallback(self):
+        vocab = [
+            ("<unk>", 0.0),
+            ("A", -0.01),
+            ("sen", -0.02),
+            ("te", -0.03),
+            ("n", -0.04),
+            ("ce", -0.05),
+            ("<0xF0>", -0.06),
+            ("<0x9F>", -0.06),
+            ("<0xA4>", -0.06),
+            ("<0x97>", -0.06),
+            (" ", -0.4),
+        ]
+        tokenizer = tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+
+        output = tokenizer.encode("A sentence 🤗")
+        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 0]
+        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "🤗"]
+
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
+
+        output = tokenizer.encode("A sentence 🤗")
+        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
+        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]
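The expected ids in the test above follow from UTF-8: "🤗" encodes to the four bytes F0 9F A4 97, so with byte_fallback=True the out-of-vocabulary piece resolves to the four byte tokens (ids 6 through 9) instead of collapsing to <unk> (id 0). A standard-library-only Rust sketch of the piece naming the fallback relies on:

// Sketch of the `<0xXX>` naming convention used for byte-fallback pieces.
fn byte_fallback_pieces(piece: &str) -> Vec<String> {
    piece.bytes().map(|b| format!("<0x{:02X}>", b)).collect()
}

fn main() {
    // "🤗" is four UTF-8 bytes, which is why the test expects four tokens.
    assert_eq!(
        byte_fallback_pieces("🤗"),
        ["<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]
    );
}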
@@ -27,6 +27,7 @@ pub struct Unigram {

     fuse_unk: bool,
     is_optimized: bool,
+    byte_fallback: bool,
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {

@@ -50,6 +51,7 @@ impl Clone for Unigram {
             eos_id: self.eos_id,
             fuse_unk: self.fuse_unk,
             is_optimized: self.is_optimized,
+            byte_fallback: self.byte_fallback,
         }
     }
 }

@@ -59,6 +61,7 @@ impl std::fmt::Debug for Unigram {
         fmt.debug_struct("Unigram")
             .field("vocab", &self.vocab.len())
             .field("unk_id", &self.unk_id)
+            .field("byte_fallback", &self.byte_fallback)
             .finish()
     }
 }
@@ -78,7 +81,7 @@ pub enum UnigramError {
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, Some(0)).unwrap()
+        Self::from(vocab, Some(0), false).unwrap()
     }
 }

@@ -89,7 +92,11 @@ impl Unigram {
     /// unk_id, is the index within the vocabulary.
     /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
     /// Further versions might allow that part to be hidden.
-    pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
+    pub fn from(
+        vocab: Vec<(String, f64)>,
+        unk_id: Option<usize>,
+        byte_fallback: bool,
+    ) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();

@@ -102,7 +109,6 @@ impl Unigram {
                 return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
             }
         }
-
         let bos_id = n + 1;
         let eos_id = n + 2;

@@ -130,6 +136,7 @@ impl Unigram {
             fuse_unk,
             cache: Cache::default(),
             is_optimized,
+            byte_fallback,
         })
     }

@@ -143,7 +150,9 @@ impl Unigram {
     pub(super) fn set_optimized(&mut self, is_optimized: bool) {
         self.is_optimized = is_optimized;
     }
-
+    pub fn byte_fallback(&self) -> bool {
+        self.byte_fallback
+    }
     pub(super) fn len(&self) -> usize {
         self.vocab.len()
     }

@@ -205,7 +214,7 @@ impl Unigram {
     ///     ("abc".to_string(), 5.0),
     ///     ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(pieces, Some(0)).unwrap();
+    /// let model = Unigram::from(pieces, Some(0), false).unwrap();
     /// let result = model.encode("abcdacdxx").unwrap();
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
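Because the constructor change is breaking, every call site now passes the flag explicitly; existing callers migrate by appending `false`. A minimal usage sketch, assuming the tokenizers crate with this commit applied:

use tokenizers::models::unigram::Unigram;

fn main() -> tokenizers::Result<()> {
    let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
    // Third argument is new; pre-existing callers pass `false` to keep behavior.
    let plain = Unigram::from(vocab.clone(), Some(0), false)?;
    let with_fallback = Unigram::from(vocab, Some(0), true)?;
    // The new public getter exposes the flag, e.g. for serialization.
    assert!(!plain.byte_fallback());
    assert!(with_fallback.byte_fallback());
    Ok(())
}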
@@ -407,12 +416,31 @@ impl Model for Unigram {
         let mut offset = 0;
         let mut tokens = Vec::with_capacity(str_tokens.len());
         for string in str_tokens {
-            let id: u32 = match self.token_to_ids.get(&string) {
-                Some(id) => *id,
-                None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
-            };
             let len = string.len();
             let offsets = (offset, offset + len);
+            let id: u32 = match self.token_to_ids.get(&string) {
+                Some(id) => *id,
+                None => {
+                    if self.byte_fallback {
+                        let byte_tokens: Option<Vec<_>> = string
+                            .bytes()
+                            .map(|byte| -> Option<Token> {
+                                let byte_string = format!("<0x{:02X}>", byte);
+                                let id = self.token_to_ids.get(&byte_string);
+                                id.map(|id| Token::new(*id, byte_string, (offset, offset + len)))
+                            })
+                            .collect();
+                        if let Some(byte_tokens) = byte_tokens {
+                            for token in byte_tokens {
+                                tokens.push(token);
+                            }
+                            offset += len;
+                            continue;
+                        }
+                    }
+                    self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32
+                }
+            };
             offset += len;
             tokens.push(Token::new(id, string, offsets));
         }
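Collecting into `Option<Vec<_>>` makes the fallback all-or-nothing: the map yields `None` for any byte whose `<0xXX>` token is missing from the vocabulary, which collapses the whole collection to `None`, and the piece then degrades to `unk` rather than a partial byte spelling. A self-contained sketch of that behavior with a toy vocabulary:

use std::collections::HashMap;

// Toy stand-in for the model's token_to_ids lookup.
fn fallback_ids(vocab: &HashMap<String, u32>, piece: &str, unk_id: u32) -> Vec<u32> {
    let byte_ids: Option<Vec<u32>> = piece
        .bytes()
        .map(|b| vocab.get(&format!("<0x{:02X}>", b)).copied())
        .collect(); // becomes `None` as soon as one byte token is absent
    byte_ids.unwrap_or_else(|| vec![unk_id])
}

fn main() {
    let vocab = HashMap::from([("<0xC3>".to_string(), 1), ("<0xA9>".to_string(), 2)]);
    assert_eq!(fallback_ids(&vocab, "é", 0), vec![1, 2]); // "é" = C3 A9
    assert_eq!(fallback_ids(&vocab, "n", 0), vec![0]); // 0x6E is missing -> unk
}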
@@ -452,7 +480,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, Some(0)).unwrap();
+        let model = Unigram::from(pieces, Some(0), false).unwrap();

         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);

@@ -477,7 +505,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, Some(0)).unwrap();
+        let model = Unigram::from(pieces, Some(0), false).unwrap();

         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);

@@ -514,7 +542,7 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];

-        let model = Unigram::from(sentencepieces, Some(0)).unwrap();
+        let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
         let result = model.encode("abcd").unwrap();
         assert_eq!(result, vec!["abcd"]);
     }

@@ -536,7 +564,7 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];

-        let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
+        let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();

         for is_optimized in &[true, false] {
             model.set_optimized(*is_optimized);
@@ -573,4 +601,35 @@ mod tests {
             assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
         }
     }
+
+    #[test]
+    fn test_unigram_bytefallback() {
+        // In [97]: processor.encode_as_pieces("⅐⅛⅑ ")
+        // Out[97]: ['▁', '<0xE2>', '<0x85>', '<0x90>', '⅛', '<0xE2>', '<0x85>', '<0x91>', '▁']
+        let sentencepieces = vec![
+            ("<unk>".to_string(), 0.0),
+            ("<0xC3>".to_string(), -0.01),
+            ("<0xA9>".to_string(), -0.03),
+        ];
+        let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
+        let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
+        assert_eq!(
+            tokens,
+            [
+                Token {
+                    id: 1,
+                    value: "<0xC3>".to_string(),
+                    offsets: (0, 2)
+                },
+                Token {
+                    id: 2,
+                    value: "<0xA9>".to_string(),
+                    offsets: (0, 2)
+                }
+            ]
+        );
+
+        let tokens = unigram.tokenize("?é").unwrap();
+        assert_eq!(tokens[0].id, 0);
+    }
 }
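A usage sketch mirroring the test above, assuming the tokenizers crate at this commit (the `Model` trait must be in scope for `tokenize`). Note that both byte tokens report the byte span of the whole original piece, (0, 2) for the two-byte "é", since offsets stay anchored to the input string:

use tokenizers::models::unigram::Unigram;
use tokenizers::{Model, Result};

fn main() -> Result<()> {
    let pieces = vec![
        ("<unk>".to_string(), 0.0),
        ("<0xC3>".to_string(), -0.01),
        ("<0xA9>".to_string(), -0.03),
    ];
    let unigram = Unigram::from(pieces, Some(0), true)?;
    let tokens = unigram.tokenize("é")?; // "é" = C3 A9 in UTF-8
    assert_eq!(tokens.len(), 2);
    assert_eq!(tokens[0].value, "<0xC3>");
    assert_eq!(tokens[1].value, "<0xA9>");
    assert_eq!(tokens[0].offsets, (0, 2)); // offsets are byte-based
    Ok(())
}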
@@ -15,6 +15,7 @@ impl Serialize for Unigram {
         model.serialize_field("type", "Unigram")?;
         model.serialize_field("unk_id", &self.unk_id)?;
         model.serialize_field("vocab", &self.vocab)?;
+        model.serialize_field("byte_fallback", &self.byte_fallback())?;

         model.end()
     }

@@ -25,7 +26,11 @@ impl<'de> Deserialize<'de> for Unigram {
     where
         D: Deserializer<'de>,
     {
-        deserializer.deserialize_struct("Unigram", &["type", "vocab", "unk_id"], UnigramVisitor)
+        deserializer.deserialize_struct(
+            "Unigram",
+            &["type", "vocab", "unk_id", "byte_fallback"],
+            UnigramVisitor,
+        )
     }
 }

@@ -43,11 +48,13 @@ impl<'de> Visitor<'de> for UnigramVisitor {
     {
         let mut vocab: Option<Vec<(String, f64)>> = None;
         let mut unk_id: Option<usize> = None;
+        let mut byte_fallback: bool = false;
         while let Some(key) = map.next_key::<String>()? {
             match key.as_ref() {
                 "unk_id" => {
                     unk_id = map.next_value()?;
                 }
+                "byte_fallback" => byte_fallback = map.next_value()?,
                 "vocab" => vocab = Some(map.next_value()?),
                 "type" => match map.next_value()? {
                     "Unigram" => {}

@@ -61,10 +68,10 @@ impl<'de> Visitor<'de> for UnigramVisitor {
                 _ => (),
             }
         }
-        match (vocab, unk_id) {
-            (Some(vocab), unk_id) => Ok(Unigram::from(vocab, unk_id)
+        match (vocab, unk_id, byte_fallback) {
+            (Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback)
                 .map_err(|err| Error::custom(format!("Unable to load vocab {:?}", err)))?),
-            (None, _) => Err(Error::custom("Missing vocab")),
+            (None, _, _) => Err(Error::custom("Missing vocab")),
         }
     }
 }
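Defaulting `byte_fallback` to `false` in the visitor keeps deserialization backward compatible: model files written before this change simply lack the field and load as non-fallback models. The hand-written visitor above does what a derived `#[serde(default)]` would; a sketch of the same idea, assuming serde (derive feature) plus serde_json, with `UnigramConfig` as a hypothetical stand-in:

use serde::Deserialize;

#[derive(Deserialize)]
struct UnigramConfig {
    unk_id: Option<usize>,
    vocab: Vec<(String, f64)>,
    #[serde(default)] // absent in old files -> false
    byte_fallback: bool,
}

fn main() {
    // An "old" serialized model without the new field still loads.
    let old_format = r#"{"unk_id": 0, "vocab": [["<unk>", 0.0]]}"#;
    let cfg: UnigramConfig = serde_json::from_str(old_format).unwrap();
    assert_eq!(cfg.unk_id, Some(0));
    assert_eq!(cfg.vocab.len(), 1);
    assert!(!cfg.byte_fallback);
}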
@@ -76,7 +83,7 @@ mod test {
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, Some(0)).unwrap();
+        let model = Unigram::from(vocab, Some(0), false).unwrap();

         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();

@@ -87,7 +94,7 @@ mod test {
     #[test]
     fn test_serialization_unk_id_not_zero() {
         let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(vocab, Some(1)).unwrap();
+        let model = Unigram::from(vocab, Some(1), false).unwrap();

         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();

@@ -98,7 +105,7 @@ mod test {
     #[test]
     fn test_serialization_no_unk_id() {
         let vocab = vec![("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, None).unwrap();
+        let model = Unigram::from(vocab, None, false).unwrap();

         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -177,7 +177,11 @@ impl UnigramTrainer {
             special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
         }

-        Unigram::from(special_tokens.into_iter().chain(pieces).collect(), unk_id)
+        Unigram::from(
+            special_tokens.into_iter().chain(pieces).collect(),
+            unk_id,
+            model.byte_fallback(),
+        )
     }

     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {

@@ -563,7 +567,7 @@ impl UnigramTrainer {
         if required_chars.len() as u32 > self.vocab_size {
             return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
         }
-        let mut new_model = Unigram::from(pieces.clone(), Some(0))?;
+        let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {

@@ -572,7 +576,7 @@ impl UnigramTrainer {

                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                new_model = Unigram::from(pieces.clone(), Some(0))?;
+                new_model = Unigram::from(pieces.clone(), Some(0), false)?;

                 // Useful comment for checking compatibility with spm
                 debug!(

@@ -596,7 +600,7 @@ impl UnigramTrainer {

             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
-            new_model = Unigram::from(pieces.clone(), Some(0))?;
+            new_model = Unigram::from(pieces.clone(), Some(0), false)?;
         }
         self.finalize_progress(&progress, expected_updates);
