AddedVocabulary - Add tests, update bindings + various tweaks

This commit is contained in:
Anthony MOI
2020-06-16 22:34:59 -04:00
parent c6f633eb1c
commit fc63d56eab
12 changed files with 326 additions and 91 deletions

View File

@@ -392,6 +392,14 @@ export interface AddedTokenOptions {
* @default False * @default False
*/ */
singleWord?: boolean; singleWord?: boolean;
/**
* Whether this token should match on the normalized version of the text. For example
* with the added token `yesterday` and a normalizer in charge of lowercasing the text,
* the input `I saw a lion Yesterday` would match the token.
* This defaults to False for special tokens, and to True otherwise
* @default True
*/
normalized?: boolean;
} }
/** /**
@@ -404,9 +412,10 @@ export class AddedToken {
/** /**
* Instantiate a new AddedToken * Instantiate a new AddedToken
* @param content The content of the token * @param content The content of the token
* @param special Whether this is a special token
* @param [options] Options for the token * @param [options] Options for the token
*/ */
constructor(content: string, options?: AddedTokenOptions); constructor(content: string, special: boolean, options?: AddedTokenOptions);
/** /**
* Get the content of the AddedToken * Get the content of the AddedToken
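
The new mandatory `special` argument mirrors the core Rust constructor, where it drives the default for `normalized`. A minimal sketch of that interaction, assuming the crate path `tokenizers::tokenizer::AddedToken` implied by `tk::tokenizer::AddedToken` elsewhere in this diff:

    use tokenizers::tokenizer::AddedToken; // crate path assumed from `tk::tokenizer::AddedToken`

    fn main() {
        // `special == true` flips the default: special tokens do not match on the
        // normalized text unless `normalized(true)` is set explicitly.
        let classic = AddedToken::from("yesterday", false);
        let special = AddedToken::from("[MASK]", true);
        assert!(classic.normalized);
        assert!(!special.normalized);
    }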

View File

@@ -32,17 +32,17 @@ import {
describe("AddedToken", () => { describe("AddedToken", () => {
it("instantiates with only content", () => { it("instantiates with only content", () => {
const addToken = new AddedToken("test"); const addToken = new AddedToken("test", false);
expect(addToken.constructor.name).toEqual("AddedToken"); expect(addToken.constructor.name).toEqual("AddedToken");
}); });
it("instantiates with empty options", () => { it("instantiates with empty options", () => {
const addToken = new AddedToken("test", {}); const addToken = new AddedToken("test", false, {});
expect(addToken.constructor.name).toEqual("AddedToken"); expect(addToken.constructor.name).toEqual("AddedToken");
}); });
it("instantiates with options", () => { it("instantiates with options", () => {
const addToken = new AddedToken("test", { const addToken = new AddedToken("test", false, {
leftStrip: true, leftStrip: true,
rightStrip: true, rightStrip: true,
singleWord: true singleWord: true
@@ -52,7 +52,7 @@ describe("AddedToken", () => {
describe("getContent", () => { describe("getContent", () => {
it("returns the string content of AddedToken", () => { it("returns the string content of AddedToken", () => {
const addedToken = new AddedToken("test"); const addedToken = new AddedToken("test", false);
expect(addedToken.getContent()).toEqual("test"); expect(addedToken.getContent()).toEqual("test");
}); });
}); });
@@ -107,7 +107,7 @@ describe("Tokenizer", () => {
it("accepts a list of AddedToken as new tokens when initial model is empty", () => { it("accepts a list of AddedToken as new tokens when initial model is empty", () => {
const model = BPE.empty(); const model = BPE.empty();
const tokenizer = new Tokenizer(model); const tokenizer = new Tokenizer(model);
const addedToken = new AddedToken("test"); const addedToken = new AddedToken("test", false);
const nbAdd = tokenizer.addTokens([addedToken]); const nbAdd = tokenizer.addTokens([addedToken]);
expect(nbAdd).toBe(1); expect(nbAdd).toBe(1);
@@ -132,7 +132,7 @@ describe("Tokenizer", () => {
const model = BPE.empty(); const model = BPE.empty();
tokenizer = new Tokenizer(model); tokenizer = new Tokenizer(model);
tokenizer.addTokens(["my", "name", "is", "john", new AddedToken("pair")]); tokenizer.addTokens(["my", "name", "is", "john", new AddedToken("pair", false)]);
encode = promisify(tokenizer.encode.bind(tokenizer)); encode = promisify(tokenizer.encode.bind(tokenizer));
encodeBatch = promisify(tokenizer.encodeBatch.bind(tokenizer)); encodeBatch = promisify(tokenizer.encodeBatch.bind(tokenizer));

View File

@@ -30,10 +30,11 @@ struct AddedTokenOptions {
singleWord: Option<bool>, singleWord: Option<bool>,
leftStrip: Option<bool>, leftStrip: Option<bool>,
rightStrip: Option<bool>, rightStrip: Option<bool>,
normalized: Option<bool>,
} }
impl AddedTokenOptions { impl AddedTokenOptions {
fn into_added_token(self, content: String) -> tk::AddedToken { fn into_added_token(self, content: String, special: bool) -> tk::AddedToken {
let mut token = tk::AddedToken::from(content); let mut token = tk::AddedToken::from(content, special);
if let Some(sw) = self.singleWord { if let Some(sw) = self.singleWord {
token = token.single_word(sw); token = token.single_word(sw);
} }
@@ -43,6 +44,9 @@ impl AddedTokenOptions {
if let Some(rs) = self.rightStrip { if let Some(rs) = self.rightStrip {
token = token.rstrip(rs); token = token.rstrip(rs);
} }
if let Some(n) = self.normalized {
token = token.normalized(n);
}
token token
} }
} }
@@ -52,18 +56,20 @@ declare_types! {
init(mut cx) { init(mut cx) {
// init( // init(
// content: string, // content: string,
// special: boolean,
// options?: { // options?: {
// singleWord?: boolean = false, // singleWord?: boolean = false,
// leftStrip?: boolean = false, // leftStrip?: boolean = false,
// rightStrip?: boolean = false // rightStrip?: boolean = false
// normalized?: boolean = true,
// } // }
// ) // )
let content = cx.extract::<String>(0) let content = cx.extract::<String>(0)?;
.map_err(|_| Error("First argument must be string".into()))?; let special = cx.extract::<bool>(1)?;
let token = cx.extract_opt::<AddedTokenOptions>(1)? let token = cx.extract_opt::<AddedTokenOptions>(2)?
.unwrap_or_else(AddedTokenOptions::default) .unwrap_or_else(AddedTokenOptions::default)
.into_added_token(content); .into_added_token(content, special);
Ok(AddedToken { token }) Ok(AddedToken { token })
} }
@@ -87,7 +93,7 @@ impl FromJsValue for AddedToken {
fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> { fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
if let Ok(token) = from.downcast::<JsString>() { if let Ok(token) = from.downcast::<JsString>() {
Ok(AddedToken { Ok(AddedToken {
token: tk::AddedToken::from(token.value()), token: tk::AddedToken::from(token.value(), false),
}) })
} else if let Ok(token) = from.downcast::<JsAddedToken>() { } else if let Ok(token) = from.downcast::<JsAddedToken>() {
let guard = cx.lock(); let guard = cx.lock();
@@ -99,6 +105,21 @@ impl FromJsValue for AddedToken {
} }
} }
struct SpecialToken(tk::AddedToken);
impl FromJsValue for SpecialToken {
fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
if let Ok(token) = from.downcast::<JsString>() {
Ok(SpecialToken(tk::AddedToken::from(token.value(), true)))
} else if let Ok(token) = from.downcast::<JsAddedToken>() {
let guard = cx.lock();
let token = token.borrow(&guard);
Ok(SpecialToken(token.token.clone()))
} else {
Err(Error("Expected `string | AddedToken`".into()))
}
}
}
// encode & encodeBatch types // encode & encodeBatch types
struct TextInputSequence(tk::InputSequence); struct TextInputSequence(tk::InputSequence);
@@ -623,7 +644,7 @@ declare_types! {
let this = cx.this(); let this = cx.this();
let guard = cx.lock(); let guard = cx.lock();
let token = this.borrow(&guard).tokenizer.id_to_token(id); let token = this.borrow(&guard).tokenizer.id_to_token(id).map(|t| t.to_owned());
if let Some(token) = token { if let Some(token) = token {
Ok(cx.string(token).upcast()) Ok(cx.string(token).upcast())
@@ -650,9 +671,9 @@ declare_types! {
method addSpecialTokens(mut cx) { method addSpecialTokens(mut cx) {
// addSpecialTokens(tokens: (string | AddedToken)[]): number // addSpecialTokens(tokens: (string | AddedToken)[]): number
let tokens = cx.extract_vec::<AddedToken>(0)? let tokens = cx.extract_vec::<SpecialToken>(0)?
.into_iter() .into_iter()
.map(|token| token.into()) .map(|token| token.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut this = cx.this(); let mut this = cx.this();
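
`into_added_token` only overrides the defaults implied by `special` when an option was actually provided, so an explicit `normalized` setting always wins. A minimal sketch of that folding; the helper below is hypothetical and the crate path is assumed from `tk::tokenizer::AddedToken` used elsewhere in this diff:

    use tokenizers::tokenizer::AddedToken; // crate path assumed

    // Hypothetical helper mirroring `into_added_token`: start from the defaults
    // implied by `special`, then apply only the options that were set.
    fn build(content: &str, special: bool, normalized: Option<bool>) -> AddedToken {
        let mut token = AddedToken::from(content, special);
        if let Some(n) = normalized {
            token = token.normalized(n);
        }
        token
    }

    fn main() {
        assert!(!build("[ENT]", true, None).normalized);      // special default
        assert!(build("[ENT]", true, Some(true)).normalized); // explicit override
    }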

View File

@@ -29,7 +29,7 @@ impl AddedToken {
#[new] #[new]
#[args(kwargs = "**")] #[args(kwargs = "**")]
fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> { fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token); let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
if let Some(kwargs) = kwargs { if let Some(kwargs) = kwargs {
for (key, value) in kwargs { for (key, value) in kwargs {

View File

@@ -200,7 +200,13 @@ class AddedToken:
""" """
def __new__( def __new__(
cls, content: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False, cls,
content: str,
is_special_token: bool,
single_word: bool = False,
lstrip: bool = False,
rstrip: bool = False,
normalized: bool = True,
) -> AddedToken: ) -> AddedToken:
""" Instantiate a new AddedToken """ Instantiate a new AddedToken
@@ -208,19 +214,30 @@ class AddedToken:
content: str: content: str:
The content of the token The content of the token
is_special_token: bool:
Whether this token is a special token. This has an impact on the default value for
`normalized` which is False for special tokens, but True for others.
single_word: bool single_word: bool
Whether this token should only match against single word. If True, Whether this token should only match against single words. If True,
this token will never match inside of a word. this token will never match inside of a word. For example the token `ing` would
match on `tokenizing` if this option is False, but not if this option is True.
lstrip: bool lstrip: bool
Whether this token should strip all potential whitespaces on the left side. Whether this token should strip all potential whitespaces on the left side.
If True, this token will greedily match any whitespace on the left and then strip If True, this token will greedily match any whitespace on the left. For example,
them out. if we try to match the token `[MASK]` with lstrip=True, in the text `I saw a [MASK]`
we will match on ` [MASK]`.
rstrip: bool rstrip: bool
Whether this token should strip all potential whitespaces on the right side. Whether this token should strip all potential whitespaces on the right side.
If True, this token will greedily match any whitespace on the right and then strip If True, this token will greedily match any whitespace on the right. It works just
them out. like lstrip, but on the right.
normalized: bool:
Whether this token should match the normalized version of the input text. For
example, with the added token `yesterday` and a normalizer in charge of lowercasing
the text, the token could be extracted from the input `I saw a lion Yesterday`.
""" """
pass pass
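
A hedged sketch mapping the options described above onto the Rust builder that appears later in this diff (behaviour as documented; crate path assumed):

    use tokenizers::tokenizer::AddedToken; // crate path assumed

    fn main() {
        // single_word: `ing` no longer matches inside `tokenizing`
        let _ing = AddedToken::from("ing", false).single_word(true);
        // lstrip: `[MASK]` greedily grabs the preceding space in `I saw a [MASK]`
        let _mask = AddedToken::from("[MASK]", true).lstrip(true);
        // normalized: with a lowercasing normalizer, `yesterday` matches `Yesterday`
        let _yesterday = AddedToken::from("yesterday", false).normalized(true);
    }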

View File

@@ -56,11 +56,11 @@ fn main() -> Result<()>{
.vocab_size(vocab_size) .vocab_size(vocab_size)
.min_frequency(0) .min_frequency(0)
.special_tokens(vec![ .special_tokens(vec![
AddedToken::from("<s>".into()), AddedToken::from("<s>", true),
AddedToken::from("<pad>".into()), AddedToken::from("<pad>", true),
AddedToken::from("</s>".into()), AddedToken::from("</s>", true),
AddedToken::from("<unk>".into()), AddedToken::from("<unk>", true),
AddedToken::from("<mask>".into()), AddedToken::from("<mask>", true),
]) ])
.build(), .build(),
); );
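
The `.into()` calls disappear because `AddedToken::from` is now generic over `S: Into<String>` (see added_vocabulary.rs below), so both borrowed and owned strings are accepted. A tiny sketch, crate path assumed:

    use tokenizers::tokenizer::AddedToken; // crate path assumed

    fn main() {
        let _a = AddedToken::from("<s>", true);                  // &str
        let _b = AddedToken::from(String::from("<pad>"), true);  // String
    }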

View File

@@ -17,9 +17,8 @@ fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
let mut tokenizer = Tokenizer::new(Box::new(bpe)); let mut tokenizer = Tokenizer::new(Box::new(bpe));
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default())); tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
tokenizer.with_decoder(Box::new(ByteLevel::default())); tokenizer.with_decoder(Box::new(ByteLevel::default()));
tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]); tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
tokenizer tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
.add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
tokenizer tokenizer
} }

View File

@@ -20,13 +20,14 @@ pub struct AddedToken {
/// Whether this token should be normalized /// Whether this token should be normalized
pub normalized: bool, pub normalized: bool,
} }
impl AddedToken { impl AddedToken {
/// Build this token from the given content, specifying if it is intended to be a /// Build this token from the given content, specifying if it is intended to be a
/// special token. Special tokens are not normalized by default. /// special token. Special tokens are not normalized by default.
pub fn from(content: String, special_token: bool) -> Self { pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
AddedToken { AddedToken {
content, content: content.into(),
normalized: !special_token, normalized: !special,
..Default::default() ..Default::default()
} }
} }
@@ -48,7 +49,7 @@ impl AddedToken {
self.rstrip = rstrip; self.rstrip = rstrip;
self self
} }
/// Specify whether this token should be normalized, and/or match against its normalized /// Specify whether this token should be normalized and match against its normalized
/// version in the input text. /// version in the input text.
pub fn normalized(mut self, normalized: bool) -> Self { pub fn normalized(mut self, normalized: bool) -> Self {
self.normalized = normalized; self.normalized = normalized;
@@ -108,7 +109,7 @@ impl Default for AddedToken {
single_word: false, single_word: false,
lstrip: false, lstrip: false,
rstrip: false, rstrip: false,
normalized: false, normalized: true,
} }
} }
} }
@@ -144,22 +145,22 @@ type MatchingSet = (regex::RegexSet, Vec<u32>);
/// exist as required. /// exist as required.
/// ///
pub(super) struct AddedVocabulary { pub(super) struct AddedVocabulary {
/// The size of the original vocabulary. This is what we use to determine the new /// Contains the mapping from String (token content) to ID. This map contains both special
/// ids we need to generate /// tokens and classic added tokens that were added to this vocabulary.
original_vocab_size: usize,
/// Contains the mapping from String to ID as the user intended it. This map
/// contains both special tokens and classic added tokens.
added_tokens_map: HashMap<String, u32>, added_tokens_map: HashMap<String, u32>,
/// Contains the mapping from ID to AddedToken for all the added tokens, both special /// Contains the mapping from ID to AddedToken for all the added tokens, both special
/// and classic. /// and classic.
added_tokens_map_r: HashMap<u32, AddedToken>, added_tokens_map_r: HashMap<u32, AddedToken>,
/// Contains only the classic AddedToken, in the specific order the user gave them. /// Contains only the classic AddedToken, in the specific order the user gave them.
added_tokens: Vec<AddedToken>, added_tokens: Vec<AddedToken>,
/// Contains only the special AddedToken, in the specific order the user gave them. /// Contains only the special AddedToken, in the specific order the user gave them.
special_tokens: Vec<AddedToken>, special_tokens: Vec<AddedToken>,
/// A Set, containing all the special tokens for easy access while decoding. This lets /// A Set, containing all the special tokens for easy access while decoding. This lets
/// use remove them easily with an O(1) complexity. /// us remove them easily with an O(1) complexity.
special_tokens_set: HashSet<String>, special_tokens_set: HashSet<String>,
/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_re: MatchingSet, split_re: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens /// A RegexSet containing all the normalized patterns used to split on AddedTokens
@@ -167,9 +168,8 @@ pub(super) struct AddedVocabulary {
} }
impl AddedVocabulary { impl AddedVocabulary {
pub fn new(original_vocab_size: usize) -> Self { pub fn new() -> Self {
Self { Self {
original_vocab_size,
added_tokens_map: HashMap::new(), added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(), added_tokens_map_r: HashMap::new(),
added_tokens: vec![], added_tokens: vec![],
@@ -180,12 +180,6 @@ impl AddedVocabulary {
} }
} }
/// Sets the original vocabulary size. We need this value to return IDs that
/// are shifted according to the original vocabulary.
pub fn update_original_vocab_size(&mut self, size: usize) {
self.original_vocab_size = size;
}
/// Size of the additional vocabulary /// Size of the additional vocabulary
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.added_tokens_map.len() self.added_tokens_map.len()
@@ -252,7 +246,7 @@ impl AddedVocabulary {
ignored += 1; ignored += 1;
id id
} else { } else {
let new_id = (self.original_vocab_size + self.added_tokens_map.len()) as u32; let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
self.added_tokens_map.insert(token.content.clone(), new_id); self.added_tokens_map.insert(token.content.clone(), new_id);
if !self.special_tokens_set.contains(&token.content) { if !self.special_tokens_set.contains(&token.content) {
@@ -400,7 +394,6 @@ impl AddedVocabulary {
splits splits
.into_iter() .into_iter()
.map(|(idx, (start, end))| { .map(|(idx, (start, end))| {
// TODO: Check this works (especially for offsets)
let normalized = sentence let normalized = sentence
.slice_bytes(Range::Normalized(start..end)) .slice_bytes(Range::Normalized(start..end))
.expect("Error while extracting normalized Range"); .expect("Error while extracting normalized Range");
@@ -472,7 +465,6 @@ impl Serialize for AddedVocabulary {
.added_tokens_map_r .added_tokens_map_r
.iter() .iter()
.map(|(id, token)| AddedTokenWithId { .map(|(id, token)| AddedTokenWithId {
// TODO: Make sure these are the right IDs (related to the model)
id: *id, id: *id,
special: self.special_tokens_set.contains(&token.content), special: self.special_tokens_set.contains(&token.content),
token: token.clone(), token: token.clone(),
@@ -488,3 +480,211 @@ impl Serialize for AddedVocabulary {
vocabulary.end() vocabulary.end()
} }
} }
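
Since `AddedVocabulary` no longer caches `original_vocab_size`, a new token's id is computed against the model at insertion time: `model.get_vocab_size() + added_tokens_map.len()`. A small worked sketch of that rule, matching the expectations of the tests below (the mock model there is empty, so the model contributes 0):

    // Hypothetical helper restating the id rule used in `add_tokens` above.
    fn next_added_id(model_vocab_size: usize, already_added: usize) -> u32 {
        (model_vocab_size + already_added) as u32
    }

    fn main() {
        // `can_extract_added_tokens`: "my", "name", then "[CLS]", "[SEP]"
        assert_eq!(next_added_id(0, 0), 0);
        assert_eq!(next_added_id(0, 1), 1);
        assert_eq!(next_added_id(0, 2), 2);
        assert_eq!(next_added_id(0, 3), 3);
    }
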
#[cfg(test)]
mod tests {
use super::*;
use crate::normalizers::utils::Lowercase;
use crate::{Offsets, Result, Token};
use std::path::{Path, PathBuf};
#[derive(Serialize, Deserialize)]
struct ModelMock {
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
}
impl ModelMock {
pub fn new<I>(iter: I) -> Self
where
I: IntoIterator<Item = &'static (&'static str, u32)>,
{
let vocab: HashMap<String, u32> = iter
.into_iter()
.map(|&(tok, id)| (tok.to_string(), id))
.collect();
Self {
vocab_r: vocab
.iter()
.map(|(tok, id)| (*id, tok.to_owned()))
.collect(),
vocab,
}
}
}
#[typetag::serde]
impl Model for ModelMock {
fn tokenize(&self, _tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
unimplemented!()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.vocab_r.get(&id).map(String::as_ref)
}
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
}
fn get_vocab_size(&self) -> usize {
self.vocab.len()
}
fn save(&self, _folder: &Path, _name: Option<&str>) -> Result<Vec<PathBuf>> {
unimplemented!()
}
}
#[test]
fn can_add_tokens() {
let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
let mut vocab = AddedVocabulary::new();
// Add tokens normally
assert_eq!(
vocab.add_tokens(&[AddedToken::from("added_token_1", false)], &model, None),
1
);
assert_eq!(vocab.len(), 1);
// Does not add the same token multiple times
assert_eq!(
vocab.add_tokens(
&[
AddedToken::from("added_token_2", false),
AddedToken::from("added_token_2", false)
],
&model,
None
),
1
);
assert_eq!(vocab.len(), 2);
// Does not add tokens already covered by the model
assert_eq!(
vocab.add_tokens(&[AddedToken::from("test", false)], &model, None),
0
);
assert_eq!(vocab.len(), 2);
}
#[test]
fn can_add_special_tokens() {
let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
let mut vocab = AddedVocabulary::new();
// Add tokens normally
assert_eq!(
vocab.add_special_tokens(&[AddedToken::from("added_token_1", true)], &model, None),
1
);
assert_eq!(vocab.len(), 1);
// Does not add the same token multiple times
assert_eq!(
vocab.add_special_tokens(
&[
AddedToken::from("added_token_2", true),
AddedToken::from("added_token_2", true)
],
&model,
None
),
1
);
assert_eq!(vocab.len(), 2);
// Can add tokens already covered by the model
assert_eq!(
vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, None),
0
);
assert_eq!(vocab.len(), 2); // Did not add a new token, since it exists in the original model
assert_eq!(vocab.is_special_token("test"), true);
assert_eq!(vocab.added_tokens_map.contains_key("test"), false);
}
#[test]
fn can_extract_added_tokens() {
// Is able to extract both normal and special tokens
let model = ModelMock::new(&[]);
let mut vocab = AddedVocabulary::new();
vocab.add_tokens(
&[
AddedToken::from("my", false),
AddedToken::from("name", false),
],
&model,
None,
);
vocab.add_special_tokens(
&[
AddedToken::from("[CLS]", true),
AddedToken::from("[SEP]", true),
],
&model,
None,
);
let result = vocab.extract_and_normalize(None, "[CLS] My name is Anthony [SEP]");
assert_eq!(
result
.iter()
.map(|(normalized, id)| (normalized.get(), *id))
.collect::<Vec<_>>(),
vec![
("[CLS]", Some(2)),
(" My ", None),
("name", Some(1)),
(" is Anthony ", None),
("[SEP]", Some(3))
]
);
}
#[test]
fn options_use_cases() {
// Is able to extract both normal and special tokens, with various options (lstrip, rstrip,
// single_word, normalized)
let model = ModelMock::new(&[]);
let normalizer = Lowercase;
let mut vocab = AddedVocabulary::new();
vocab.add_tokens(
&[
AddedToken::from("my", false).lstrip(true).rstrip(true),
AddedToken::from("name", false),
AddedToken::from("ony", false).single_word(true),
],
&model,
Some(&normalizer),
);
vocab.add_special_tokens(
&[
AddedToken::from("[CLS]", true),
AddedToken::from("[SEP]", true),
],
&model,
Some(&normalizer),
);
let result =
vocab.extract_and_normalize(Some(&normalizer), "[CLS] My name is Anthony [SEP]");
assert_eq!(
result
.iter()
.map(|(normalized, id)| (normalized.get(), *id))
.collect::<Vec<_>>(),
vec![
("[CLS]", Some(3)),
// This one includes both spaces because of the lstrip & rstrip
// And it matches because normalized == true
(" my ", Some(0)),
("name", Some(1)),
// `ony` is not extracted here thanks to single_word
(" is anthony ", None),
("[SEP]", Some(4))
]
);
}
}

View File

@@ -211,7 +211,6 @@ impl std::str::FromStr for Tokenizer {
impl Tokenizer { impl Tokenizer {
/// Instantiate a new Tokenizer, with the given Model /// Instantiate a new Tokenizer, with the given Model
pub fn new(model: Box<dyn Model>) -> Self { pub fn new(model: Box<dyn Model>) -> Self {
let original_vocab_size = model.get_vocab_size();
Tokenizer { Tokenizer {
normalizer: None, normalizer: None,
pre_tokenizer: None, pre_tokenizer: None,
@@ -219,7 +218,7 @@ impl Tokenizer {
post_processor: None, post_processor: None,
decoder: None, decoder: None,
added_vocabulary: AddedVocabulary::new(original_vocab_size), added_vocabulary: AddedVocabulary::new(),
truncation: None, truncation: None,
padding: None, padding: None,
@@ -303,8 +302,6 @@ impl Tokenizer {
/// Set the model /// Set the model
pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self { pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self {
self.model = model; self.model = model;
self.added_vocabulary
.update_original_vocab_size(self.model.get_vocab_size());
self self
} }
@@ -669,8 +666,6 @@ impl Tokenizer {
let (model, special_tokens) = trainer.train(words)?; let (model, special_tokens) = trainer.train(words)?;
self.model = model; self.model = model;
self.added_vocabulary
.update_original_vocab_size(self.model.get_vocab_size());
self.add_special_tokens(&special_tokens); self.add_special_tokens(&special_tokens);
Ok(()) Ok(())

View File

@@ -50,7 +50,7 @@ where
/// It is possible to retrieve a part of the original string, by indexing it with offsets from the /// It is possible to retrieve a part of the original string, by indexing it with offsets from the
/// normalized one, and the other way around too. It is also possible to convert offsets from one /// normalized one, and the other way around too. It is also possible to convert offsets from one
/// referential to the other one easily. /// referential to the other one easily.
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct NormalizedString { pub struct NormalizedString {
/// The original version of the string, before any modification /// The original version of the string, before any modification
original: String, original: String,
@@ -61,12 +61,6 @@ pub struct NormalizedString {
alignments: Vec<(usize, usize)>, alignments: Vec<(usize, usize)>,
} }
impl std::cmp::PartialEq for NormalizedString {
fn eq(&self, other: &NormalizedString) -> bool {
self.normalized == other.normalized
}
}
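
With the manual impl removed and `PartialEq` derived, equality now compares `original`, `normalized` and `alignments`, not just the normalized text. A minimal sketch of the stricter behaviour, using only constructors shown in this diff (module path assumed):

    use tokenizers::tokenizer::NormalizedString; // module path assumed

    fn main() {
        let a = NormalizedString::from("Hello");
        let b = NormalizedString::from("Hello");
        assert_eq!(a, b);             // every field matches
        assert_eq!(a.get(), "Hello"); // `get()` returns the normalized text
    }
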
impl NormalizedString { impl NormalizedString {
/// Create a NormalizedString from the given str /// Create a NormalizedString from the given str
pub fn from(s: &str) -> Self { pub fn from(s: &str) -> Self {
@@ -441,7 +435,7 @@ impl NormalizedString {
/// Merge with the given NormalizedString by appending it to self /// Merge with the given NormalizedString by appending it to self
pub fn merge_with(&mut self, other: &NormalizedString) { pub fn merge_with(&mut self, other: &NormalizedString) {
self.original.push_str(&other.original); self.original.push_str(&other.original);
let len = self.len(); let len = self.len() - 1;
self.alignments.extend( self.alignments.extend(
other other
.alignments .alignments
@@ -879,7 +873,7 @@ mod tests {
Some(NormalizedString { Some(NormalizedString {
original: "𝕞𝕠𝕣𝕟𝕚𝕟𝕘".to_string(), original: "𝕞𝕠𝕣𝕟𝕚𝕟𝕘".to_string(),
normalized: "morning".to_string(), normalized: "morning".to_string(),
alignments: vec![(5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12)] alignments: vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
}) })
); );
assert_eq!( assert_eq!(

View File

@@ -9,8 +9,8 @@ fn add_tokens() {
assert_eq!( assert_eq!(
tokenizer.add_special_tokens(&[ tokenizer.add_special_tokens(&[
AddedToken::from("<cls>".into(), true), AddedToken::from("<cls>", true),
AddedToken::from("<sep>".into(), true) AddedToken::from("<sep>", true)
]), ]),
2 2
); );
@@ -19,8 +19,8 @@ fn add_tokens() {
assert_eq!( assert_eq!(
tokenizer.add_tokens(&[ tokenizer.add_tokens(&[
AddedToken::from("hello".into(), false), AddedToken::from("hello", false),
AddedToken::from("world".into(), false) AddedToken::from("world", false)
]), ]),
2 2
); );
@@ -31,7 +31,7 @@ fn add_tokens() {
#[test] #[test]
fn lstrip_tokens() { fn lstrip_tokens() {
let mut tokenizer = get_byte_level(true, false); let mut tokenizer = get_byte_level(true, false);
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).lstrip(true)]); tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).lstrip(true)]);
let input = "I saw a <mask> 😺"; let input = "I saw a <mask> 😺";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -49,7 +49,7 @@ fn lstrip_tokens() {
#[test] #[test]
fn rstrip_tokens() { fn rstrip_tokens() {
let mut tokenizer = get_byte_level(false, false); let mut tokenizer = get_byte_level(false, false);
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]); tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).rstrip(true)]);
let input = "I saw a <mask> 😺"; let input = "I saw a <mask> 😺";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -62,7 +62,7 @@ fn rstrip_tokens() {
// When `add_prefix_space = true` rstrip cannot work as a prefix space is added // When `add_prefix_space = true` rstrip cannot work as a prefix space is added
// to the next token // to the next token
let mut tokenizer = get_byte_level(true, false); let mut tokenizer = get_byte_level(true, false);
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]); tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).rstrip(true)]);
let input = "I saw a <mask> 😺"; let input = "I saw a <mask> 😺";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -77,7 +77,7 @@ fn rstrip_tokens() {
fn single_word_tokens() { fn single_word_tokens() {
// If `single_word = true` it shouldn't split `dancing` // If `single_word = true` it shouldn't split `dancing`
let mut tokenizer = get_byte_level(false, false); let mut tokenizer = get_byte_level(false, false);
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(true)]); tokenizer.add_special_tokens(&[AddedToken::from("ing", true).single_word(true)]);
let input = "I like dancing"; let input = "I like dancing";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -86,7 +86,7 @@ fn single_word_tokens() {
// If `single_word = false` it should split `dancing` // If `single_word = false` it should split `dancing`
let mut tokenizer = get_byte_level(false, false); let mut tokenizer = get_byte_level(false, false);
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(false)]); tokenizer.add_special_tokens(&[AddedToken::from("ing", true).single_word(false)]);
let input = "I like dancing"; let input = "I like dancing";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -98,9 +98,9 @@ fn single_word_tokens() {
fn overlapping_tokens() { fn overlapping_tokens() {
let mut tokenizer = get_byte_level(false, false); let mut tokenizer = get_byte_level(false, false);
tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("danc", true)]);
tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("nci", true)]);
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("ing", true)]);
let input = "I like dancing"; let input = "I like dancing";
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
@@ -109,10 +109,10 @@ fn overlapping_tokens() {
let mut tokenizer = get_byte_level(false, false); let mut tokenizer = get_byte_level(false, false);
tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("nci", true)]);
tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("danc", true)]);
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("ing", true)]);
tokenizer.add_special_tokens(&[AddedToken::from("ike".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("ike", true)]);
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();

View File

@@ -158,7 +158,7 @@ fn split_on_added_tokens_bert() {
let input = "Yesterday I saw a [MASK] far away"; let input = "Yesterday I saw a [MASK] far away";
let mut tokenizer = get_bert(); let mut tokenizer = get_bert();
tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]); tokenizer.add_special_tokens(&[AddedToken::from("[MASK]", true)]);
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
assert_eq!( assert_eq!(