mirror of https://github.com/mii443/tokenizers.git
synced 2025-12-06 12:48:18 +00:00

AddedVocabulary - Add tests, update bindings + various tweaks

bindings/node/lib/bindings/tokenizer.d.ts
@@ -392,6 +392,14 @@ export interface AddedTokenOptions {
   * @default False
   */
  singleWord?: boolean;
  /**
   * Whether this token should match on the normalized version of the text. For example
   * with the added token `yesterday` and a normalizer in charge of lowercasing the text,
   * the input `I saw a lion Yesterday` would match the token.
   * This is False for special tokens by default, true otherwise
   * @default True
   */
  normalized?: boolean;
}

/**
@@ -404,9 +412,10 @@ export class AddedToken {
  /**
   * Instantiate a new AddedToken
   * @param content The content of the token
   * @param special Whether this is a special token
   * @param [options] Options for the token
   */
  constructor(content: string, options?: AddedTokenOptions);
  constructor(content: string, special: boolean, options?: AddedTokenOptions);

  /**
   * Get the content of the AddedToken
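A note on the new `normalized` option: it mirrors the behaviour documented on the Rust side, where an added token is searched in the normalizer's output rather than in the raw input. A minimal standalone sketch of that idea (plain illustrative Rust assuming a simple lowercasing normalizer, not the library's actual matcher):

// Sketch only: `normalized: true` means the token is looked up in the
// normalized text, so `yesterday` can match `Yesterday` when a lowercasing
// normalizer runs first.
fn find_added_token(raw: &str, token: &str, normalized: bool) -> Option<usize> {
    let haystack = if normalized { raw.to_lowercase() } else { raw.to_string() };
    haystack.find(token)
}

fn main() {
    let input = "I saw a lion Yesterday";
    assert_eq!(find_added_token(input, "yesterday", false), None);
    assert_eq!(find_added_token(input, "yesterday", true), Some(13));
}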
@@ -32,17 +32,17 @@ import {

describe("AddedToken", () => {
  it("instantiates with only content", () => {
    const addToken = new AddedToken("test");
    const addToken = new AddedToken("test", false);
    expect(addToken.constructor.name).toEqual("AddedToken");
  });

  it("instantiates with empty options", () => {
    const addToken = new AddedToken("test", {});
    const addToken = new AddedToken("test", false, {});
    expect(addToken.constructor.name).toEqual("AddedToken");
  });

  it("instantiates with options", () => {
    const addToken = new AddedToken("test", {
    const addToken = new AddedToken("test", false, {
      leftStrip: true,
      rightStrip: true,
      singleWord: true
@@ -52,7 +52,7 @@ describe("AddedToken", () => {

  describe("getContent", () => {
    it("returns the string content of AddedToken", () => {
      const addedToken = new AddedToken("test");
      const addedToken = new AddedToken("test", false);
      expect(addedToken.getContent()).toEqual("test");
    });
  });
@@ -107,7 +107,7 @@ describe("Tokenizer", () => {
  it("accepts a list of AddedToken as new tokens when initial model is empty", () => {
    const model = BPE.empty();
    const tokenizer = new Tokenizer(model);
    const addedToken = new AddedToken("test");
    const addedToken = new AddedToken("test", false);

    const nbAdd = tokenizer.addTokens([addedToken]);
    expect(nbAdd).toBe(1);
@@ -132,7 +132,7 @@ describe("Tokenizer", () => {

    const model = BPE.empty();
    tokenizer = new Tokenizer(model);
    tokenizer.addTokens(["my", "name", "is", "john", new AddedToken("pair")]);
    tokenizer.addTokens(["my", "name", "is", "john", new AddedToken("pair", false)]);

    encode = promisify(tokenizer.encode.bind(tokenizer));
    encodeBatch = promisify(tokenizer.encodeBatch.bind(tokenizer));
@@ -30,10 +30,11 @@ struct AddedTokenOptions {
    singleWord: Option<bool>,
    leftStrip: Option<bool>,
    rightStrip: Option<bool>,
    normalized: Option<bool>,
}
impl AddedTokenOptions {
    fn into_added_token(self, content: String) -> tk::AddedToken {
        let mut token = tk::AddedToken::from(content);
    fn into_added_token(self, content: String, special: bool) -> tk::AddedToken {
        let mut token = tk::AddedToken::from(content, special);
        if let Some(sw) = self.singleWord {
            token = token.single_word(sw);
        }
@@ -43,6 +44,9 @@ impl AddedTokenOptions {
        if let Some(rs) = self.rightStrip {
            token = token.rstrip(rs);
        }
        if let Some(n) = self.normalized {
            token = token.normalized(n);
        }
        token
    }
}
@@ -52,18 +56,20 @@ declare_types! {
    init(mut cx) {
        // init(
        //  content: string,
        //  special: boolean,
        //  options?: {
        //    singleWord?: boolean = false,
        //    leftStrip?: boolean = false,
        //    rightStrip?: boolean = false
        //    normalized?: boolean = true,
        //  }
        // )

        let content = cx.extract::<String>(0)
            .map_err(|_| Error("First argument must be string".into()))?;
        let token = cx.extract_opt::<AddedTokenOptions>(1)?
        let content = cx.extract::<String>(0)?;
        let special = cx.extract::<bool>(1)?;
        let token = cx.extract_opt::<AddedTokenOptions>(2)?
            .unwrap_or_else(AddedTokenOptions::default)
            .into_added_token(content);
            .into_added_token(content, special);

        Ok(AddedToken { token })
    }
@@ -87,7 +93,7 @@ impl FromJsValue for AddedToken {
    fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
        if let Ok(token) = from.downcast::<JsString>() {
            Ok(AddedToken {
                token: tk::AddedToken::from(token.value()),
                token: tk::AddedToken::from(token.value(), false),
            })
        } else if let Ok(token) = from.downcast::<JsAddedToken>() {
            let guard = cx.lock();
@@ -99,6 +105,21 @@ impl FromJsValue for AddedToken {
    }
}

struct SpecialToken(tk::AddedToken);
impl FromJsValue for SpecialToken {
    fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
        if let Ok(token) = from.downcast::<JsString>() {
            Ok(SpecialToken(tk::AddedToken::from(token.value(), true)))
        } else if let Ok(token) = from.downcast::<JsAddedToken>() {
            let guard = cx.lock();
            let token = token.borrow(&guard);
            Ok(SpecialToken(token.token.clone()))
        } else {
            Err(Error("Expected `string | AddedToken`".into()))
        }
    }
}

// encode & encodeBatch types

struct TextInputSequence(tk::InputSequence);
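The `SpecialToken` newtype above exists so that a plain string handed to `addSpecialTokens` is coerced with `special = true`, while `addTokens` keeps coercing strings with `special = false`. A standalone sketch of that pattern (illustrative names only, not the binding's real types):

// Sketch: the same `string | AddedToken` input is interpreted differently
// depending on which newtype it is extracted into.
struct Token {
    content: String,
    special: bool,
}

enum JsInput {
    Str(String),
    Token(Token),
}

struct SpecialToken(Token);

impl From<JsInput> for SpecialToken {
    fn from(input: JsInput) -> Self {
        match input {
            // A bare string becomes a special token...
            JsInput::Str(s) => SpecialToken(Token { content: s, special: true }),
            // ...while an explicit Token keeps whatever the user built.
            JsInput::Token(t) => SpecialToken(t),
        }
    }
}

fn main() {
    let st: SpecialToken = JsInput::Str("[CLS]".into()).into();
    assert!(st.0.special);
    assert_eq!(st.0.content, "[CLS]");

    let st2: SpecialToken = JsInput::Token(Token { content: "custom".into(), special: false }).into();
    assert!(!st2.0.special);
}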
@@ -623,7 +644,7 @@ declare_types! {

        let this = cx.this();
        let guard = cx.lock();
        let token = this.borrow(&guard).tokenizer.id_to_token(id);
        let token = this.borrow(&guard).tokenizer.id_to_token(id).map(|t| t.to_owned());

        if let Some(token) = token {
            Ok(cx.string(token).upcast())
@@ -650,9 +671,9 @@ declare_types! {
    method addSpecialTokens(mut cx) {
        // addSpecialTokens(tokens: (string | AddedToken)[]): number

        let tokens = cx.extract_vec::<AddedToken>(0)?
        let tokens = cx.extract_vec::<SpecialToken>(0)?
            .into_iter()
            .map(|token| token.into())
            .map(|token| token.0)
            .collect::<Vec<_>>();

        let mut this = cx.this();
@@ -29,7 +29,7 @@ impl AddedToken {
    #[new]
    #[args(kwargs = "**")]
    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
        let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token);
        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
@@ -200,7 +200,13 @@ class AddedToken:
    """

    def __new__(
        cls, content: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False,
        cls,
        content: str,
        is_special_token: bool,
        single_word: bool = False,
        lstrip: bool = False,
        rstrip: bool = False,
        normalized: bool = True,
    ) -> AddedToken:
        """ Instantiate a new AddedToken

@@ -208,19 +214,30 @@ class AddedToken:
            content: str:
                The content of the token

            is_special_token: bool:
                Whether this token is a special token. This has an impact on the default value for
                `normalized`, which is False for special tokens, but True for others.

            single_word: bool
                Whether this token should only match against single word. If True,
                this token will never match inside of a word.
                Whether this token should only match against single words. If True,
                this token will never match inside of a word. For example the token `ing` would
                match on `tokenizing` if this option is False, but not if this option is True.

            lstrip: bool
                Whether this token should strip all potential whitespaces on the left side.
                If True, this token will greedily match any whitespace on the left and then strip
                them out.
                If True, this token will greedily match any whitespace on the left. For example,
                if we try to match the token `[MASK]` with lstrip=True, in the text `I saw a [MASK]`
                we will match on ` [MASK]`.

            rstrip: bool
                Whether this token should strip all potential whitespaces on the right side.
                If True, this token will greedily match any whitespace on the right and then strip
                them out.
                If True, this token will greedily match any whitespace on the right. It works just
                like lstrip, but on the right.

            normalized: bool:
                Whether this token should match the normalized version of the input text. For
                example, with the added token `yesterday` and a normalizer in charge of lowercasing
                the text, the token could be extracted from the input `I saw a lion Yesterday`.
        """
        pass
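The lstrip/rstrip wording in this docstring can be pictured with a small standalone sketch (plain illustrative Rust, not the crate's regex-based matcher):

// Sketch only: lstrip widens the match for "[MASK]" so it also covers the
// whitespace directly before it, as described in the docstring above.
fn match_with_lstrip(text: &str, token: &str, lstrip: bool) -> Option<(usize, usize)> {
    let start = text.find(token)?;
    let end = start + token.len();
    let start = if lstrip {
        // Widen the match over any whitespace directly before the token.
        text[..start].rfind(|c: char| !c.is_whitespace()).map(|i| i + 1).unwrap_or(0)
    } else {
        start
    };
    Some((start, end))
}

fn main() {
    let text = "I saw a [MASK]";
    assert_eq!(match_with_lstrip(text, "[MASK]", false), Some((8, 14)));
    // With lstrip, the extracted slice is " [MASK]", space included.
    assert_eq!(match_with_lstrip(text, "[MASK]", true), Some((7, 14)));
    assert_eq!(&text[7..14], " [MASK]");
}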
@@ -56,11 +56,11 @@ fn main() -> Result<()>{
        .vocab_size(vocab_size)
        .min_frequency(0)
        .special_tokens(vec![
            AddedToken::from("<s>".into()),
            AddedToken::from("<pad>".into()),
            AddedToken::from("</s>".into()),
            AddedToken::from("<unk>".into()),
            AddedToken::from("<mask>".into()),
            AddedToken::from("<s>", true),
            AddedToken::from("<pad>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
            AddedToken::from("<mask>", true),
        ])
        .build(),
);
@@ -17,9 +17,8 @@ fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
    let mut tokenizer = Tokenizer::new(Box::new(bpe));
    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
    tokenizer.with_decoder(Box::new(ByteLevel::default()));
    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
    tokenizer
        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
    tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
    tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
    tokenizer
}
@@ -20,13 +20,14 @@ pub struct AddedToken {
    /// Whether this token should be normalized
    pub normalized: bool,
}

impl AddedToken {
    /// Build this token from the given content, specifying if it is intended to be a
    /// special token. Special tokens are not normalized by default.
    pub fn from(content: String, special_token: bool) -> Self {
    pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
        AddedToken {
            content,
            normalized: !special_token,
            content: content.into(),
            normalized: !special,
            ..Default::default()
        }
    }
@@ -48,7 +49,7 @@ impl AddedToken {
        self.rstrip = rstrip;
        self
    }
    /// Specify whether this token should be normalized, and/or match against its normalized
    /// Specify whether this token should be normalized and match against its normalized
    /// version in the input text.
    pub fn normalized(mut self, normalized: bool) -> Self {
        self.normalized = normalized;
@@ -108,7 +109,7 @@ impl Default for AddedToken {
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: false,
            normalized: true,
        }
    }
}
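Taken together, the new `from<S: Into<String>>` constructor and the flipped `Default` mean that a classic added token now matches on normalized text unless the caller opts out, while special tokens stay verbatim. A standalone sketch of the resulting defaults (mock type, not the crate's `AddedToken`):

// Sketch of the effective defaults after this change: Default now says
// normalized = true, and from(content, special) only flips it off for
// special tokens; the remaining fields come from ..Default::default().
struct MockToken {
    content: String,
    single_word: bool,
    lstrip: bool,
    rstrip: bool,
    normalized: bool,
}

impl Default for MockToken {
    fn default() -> Self {
        MockToken {
            content: String::new(),
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: true, // was `false` before this commit
        }
    }
}

impl MockToken {
    fn from<S: Into<String>>(content: S, special: bool) -> Self {
        MockToken { content: content.into(), normalized: !special, ..Default::default() }
    }
}

fn main() {
    assert!(MockToken::from("hello", false).normalized); // classic token: matched on normalized text
    assert!(!MockToken::from("[CLS]", true).normalized); // special token: matched verbatim
    let d = MockToken::default();
    assert!(d.normalized && !d.single_word && !d.lstrip && !d.rstrip && d.content.is_empty());
}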
@@ -144,22 +145,22 @@ type MatchingSet = (regex::RegexSet, Vec<u32>);
/// exist as required.
///
pub(super) struct AddedVocabulary {
    /// The size of the original vocabulary. This is what we use to determine the new
    /// ids we need to generate
    original_vocab_size: usize,
    /// Contains the mapping from String to ID as the user intended it. This map
    /// contains both special tokens and classic added tokens.
    /// Contains the mapping from String (token content) to ID. This map contains both special
    /// tokens and classic added tokens that were added to this vocabulary.
    added_tokens_map: HashMap<String, u32>,
    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
    /// and classic.
    added_tokens_map_r: HashMap<u32, AddedToken>,

    /// Contains only the classic AddedToken, in the specific order the user gave them.
    added_tokens: Vec<AddedToken>,
    /// Contains only the special AddedToken, in the specific order the user gave them.
    special_tokens: Vec<AddedToken>,

    /// A Set, containing all the special tokens for easy access while decoding. This lets
    /// use remove them easily with an O(1) complexity.
    /// us remove them easily with an O(1) complexity.
    special_tokens_set: HashSet<String>,

    /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
    split_re: MatchingSet,
    /// A RegexSet containing all the normalized patterns used to split on AddedTokens
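A standalone sketch of the bookkeeping these comments describe (illustrative only: the real struct stores `AddedToken` values and regex sets on top of this): two mirrored maps for content/id lookups, plus a set that makes "is this special?" an O(1) question while decoding.

use std::collections::{HashMap, HashSet};

#[derive(Default)]
struct VocabSketch {
    added_tokens_map: HashMap<String, u32>,   // content -> id (special + classic)
    added_tokens_map_r: HashMap<u32, String>, // id -> content (special + classic)
    special_tokens_set: HashSet<String>,      // membership test while decoding
}

impl VocabSketch {
    fn add(&mut self, content: &str, id: u32, special: bool) {
        self.added_tokens_map.insert(content.to_string(), id);
        self.added_tokens_map_r.insert(id, content.to_string());
        if special {
            self.special_tokens_set.insert(content.to_string());
        }
    }
}

fn main() {
    let mut v = VocabSketch::default();
    v.add("[CLS]", 2, true);
    v.add("name", 1, false);
    assert!(v.special_tokens_set.contains("[CLS]"));
    assert_eq!(v.added_tokens_map_r.get(&2).map(String::as_str), Some("[CLS]"));
}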
@@ -167,9 +168,8 @@ pub(super) struct AddedVocabulary {
}

impl AddedVocabulary {
    pub fn new(original_vocab_size: usize) -> Self {
    pub fn new() -> Self {
        Self {
            original_vocab_size,
            added_tokens_map: HashMap::new(),
            added_tokens_map_r: HashMap::new(),
            added_tokens: vec![],
@@ -180,12 +180,6 @@ impl AddedVocabulary {
        }
    }

    /// Sets the original vocabulary size. We need this value to return IDs that
    /// are shifted according to the original vocabulary.
    pub fn update_original_vocab_size(&mut self, size: usize) {
        self.original_vocab_size = size;
    }

    /// Size of the additional vocabulary
    pub fn len(&self) -> usize {
        self.added_tokens_map.len()
@@ -252,7 +246,7 @@ impl AddedVocabulary {
            ignored += 1;
            id
        } else {
            let new_id = (self.original_vocab_size + self.added_tokens_map.len()) as u32;
            let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
            self.added_tokens_map.insert(token.content.clone(), new_id);

            if !self.special_tokens_set.contains(&token.content) {
@@ -400,7 +394,6 @@ impl AddedVocabulary {
        splits
            .into_iter()
            .map(|(idx, (start, end))| {
                // TODO: Check this works (especially for offsets)
                let normalized = sentence
                    .slice_bytes(Range::Normalized(start..end))
                    .expect("Error while extracting normalized Range");
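The id assignment above is the one place where the removed `original_vocab_size` cache used to matter: new ids simply continue after the model's own vocabulary, and the model is now asked for its size at add time. A standalone sketch of the rule (not the crate's code):

use std::collections::HashMap;

// Sketch: the next added id starts right after the model's vocabulary and
// grows with the number of tokens already added.
fn next_added_id(model_vocab_size: usize, added: &HashMap<String, u32>) -> u32 {
    (model_vocab_size + added.len()) as u32
}

fn main() {
    let mut added: HashMap<String, u32> = HashMap::new();
    let model_vocab_size = 2; // e.g. the ModelMock below with {"test": 0, "tost": 1}
    let id = next_added_id(model_vocab_size, &added);
    added.insert("added_token_1".to_string(), id);
    assert_eq!(id, 2);
    assert_eq!(next_added_id(model_vocab_size, &added), 3);
}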
@@ -472,7 +465,6 @@ impl Serialize for AddedVocabulary {
            .added_tokens_map_r
            .iter()
            .map(|(id, token)| AddedTokenWithId {
                // TODO: Make sure these are the right IDs (related to the model)
                id: *id,
                special: self.special_tokens_set.contains(&token.content),
                token: token.clone(),
@@ -488,3 +480,211 @@ impl Serialize for AddedVocabulary {
        vocabulary.end()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalizers::utils::Lowercase;
    use crate::{Offsets, Result, Token};
    use std::path::{Path, PathBuf};

    #[derive(Serialize, Deserialize)]
    struct ModelMock {
        vocab: HashMap<String, u32>,
        vocab_r: HashMap<u32, String>,
    }
    impl ModelMock {
        pub fn new<I>(iter: I) -> Self
        where
            I: IntoIterator<Item = &'static (&'static str, u32)>,
        {
            let vocab: HashMap<String, u32> = iter
                .into_iter()
                .map(|&(tok, id)| (tok.to_string(), id))
                .collect();
            Self {
                vocab_r: vocab
                    .iter()
                    .map(|(tok, id)| (*id, tok.to_owned()))
                    .collect(),
                vocab,
            }
        }
    }
    #[typetag::serde]
    impl Model for ModelMock {
        fn tokenize(&self, _tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
            unimplemented!()
        }
        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
        fn id_to_token(&self, id: u32) -> Option<&str> {
            self.vocab_r.get(&id).map(String::as_ref)
        }
        fn get_vocab(&self) -> &HashMap<String, u32> {
            &self.vocab
        }
        fn get_vocab_size(&self) -> usize {
            self.vocab.len()
        }
        fn save(&self, _folder: &Path, _name: Option<&str>) -> Result<Vec<PathBuf>> {
            unimplemented!()
        }
    }

    #[test]
    fn can_add_tokens() {
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();

        // Add tokens normally
        assert_eq!(
            vocab.add_tokens(&[AddedToken::from("added_token_1", false)], &model, None),
            1
        );
        assert_eq!(vocab.len(), 1);

        // Does not add the same token multiple times
        assert_eq!(
            vocab.add_tokens(
                &[
                    AddedToken::from("added_token_2", false),
                    AddedToken::from("added_token_2", false)
                ],
                &model,
                None
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        // Does not add tokens already covered by the model
        assert_eq!(
            vocab.add_tokens(&[AddedToken::from("test", false)], &model, None),
            0
        );
        assert_eq!(vocab.len(), 2);
    }

    #[test]
    fn can_add_special_tokens() {
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();

        // Add tokens normally
        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("added_token_1", true)], &model, None),
            1
        );
        assert_eq!(vocab.len(), 1);

        // Does not add the same token multiple times
        assert_eq!(
            vocab.add_special_tokens(
                &[
                    AddedToken::from("added_token_2", true),
                    AddedToken::from("added_token_2", true)
                ],
                &model,
                None
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        // Can add tokens already covered by the model
        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, None),
            0
        );
        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exists in the original model
        assert_eq!(vocab.is_special_token("test"), true);
        assert_eq!(vocab.added_tokens_map.contains_key("test"), false);
    }

    #[test]
    fn can_extract_added_tokens() {
        // Is able to extract both normal and special tokens
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();

        vocab.add_tokens(
            &[
                AddedToken::from("my", false),
                AddedToken::from("name", false),
            ],
            &model,
            None,
        );
        vocab.add_special_tokens(
            &[
                AddedToken::from("[CLS]", true),
                AddedToken::from("[SEP]", true),
            ],
            &model,
            None,
        );

        let result = vocab.extract_and_normalize(None, "[CLS] My name is Anthony [SEP]");
        assert_eq!(
            result
                .iter()
                .map(|(normalized, id)| (normalized.get(), *id))
                .collect::<Vec<_>>(),
            vec![
                ("[CLS]", Some(2)),
                (" My ", None),
                ("name", Some(1)),
                (" is Anthony ", None),
                ("[SEP]", Some(3))
            ]
        );
    }
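    // Why the expected ids above come out as they do (worked example, not part of
    // the original tests): the mock model is empty, so get_vocab_size() == 0 and
    // added ids are handed out in insertion order:
    //   "my"    -> 0 + 0 = 0   (first classic token)
    //   "name"  -> 0 + 1 = 1   (second classic token)
    //   "[CLS]" -> 0 + 2 = 2   (first special token)
    //   "[SEP]" -> 0 + 3 = 3   (second special token)
    // In `options_use_cases` below, `ony` also gets added (id 2), which is why
    // [CLS] and [SEP] shift to 3 and 4 there.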
    #[test]
    fn options_use_cases() {
        // Is able to extract both normal and special tokens, with various options (lstrip, rstrip,
        // single_word, normalized)
        let model = ModelMock::new(&[]);
        let normalizer = Lowercase;
        let mut vocab = AddedVocabulary::new();

        vocab.add_tokens(
            &[
                AddedToken::from("my", false).lstrip(true).rstrip(true),
                AddedToken::from("name", false),
                AddedToken::from("ony", false).single_word(true),
            ],
            &model,
            Some(&normalizer),
        );
        vocab.add_special_tokens(
            &[
                AddedToken::from("[CLS]", true),
                AddedToken::from("[SEP]", true),
            ],
            &model,
            Some(&normalizer),
        );

        let result =
            vocab.extract_and_normalize(Some(&normalizer), "[CLS] My name is Anthony [SEP]");
        assert_eq!(
            result
                .iter()
                .map(|(normalized, id)| (normalized.get(), *id))
                .collect::<Vec<_>>(),
            vec![
                ("[CLS]", Some(3)),
                // This one includes both spaces because of the lstrip & rstrip
                // And it matches because normalized == true
                (" my ", Some(0)),
                ("name", Some(1)),
                // `ony` is not extracted here thanks to single_word
                (" is anthony ", None),
                ("[SEP]", Some(4))
            ]
        );
    }
}
@@ -211,7 +211,6 @@ impl std::str::FromStr for Tokenizer {
impl Tokenizer {
    /// Instantiate a new Tokenizer, with the given Model
    pub fn new(model: Box<dyn Model>) -> Self {
        let original_vocab_size = model.get_vocab_size();
        Tokenizer {
            normalizer: None,
            pre_tokenizer: None,
@@ -219,7 +218,7 @@ impl Tokenizer {
            post_processor: None,
            decoder: None,

            added_vocabulary: AddedVocabulary::new(original_vocab_size),
            added_vocabulary: AddedVocabulary::new(),

            truncation: None,
            padding: None,
@@ -303,8 +302,6 @@ impl Tokenizer {
    /// Set the model
    pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self {
        self.model = model;
        self.added_vocabulary
            .update_original_vocab_size(self.model.get_vocab_size());
        self
    }

@@ -669,8 +666,6 @@ impl Tokenizer {

        let (model, special_tokens) = trainer.train(words)?;
        self.model = model;
        self.added_vocabulary
            .update_original_vocab_size(self.model.get_vocab_size());
        self.add_special_tokens(&special_tokens);

        Ok(())
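These two removals are the payoff of dropping the cached size: `with_model` and `train` no longer have to remember to resynchronize the AddedVocabulary. A standalone sketch of the failure mode the old design invited (mock types, not the crate's):

// Sketch: a cached "original vocab size" goes stale as soon as the model is
// swapped or retrained, while asking the model each time always reflects the
// current state.
struct MockModel { vocab_size: usize }

struct CachedSize { original_vocab_size: usize }

fn main() {
    let mut model = MockModel { vocab_size: 10 };
    let cached = CachedSize { original_vocab_size: model.vocab_size };

    // `train()` / `with_model()` replaces the model with a bigger one...
    model = MockModel { vocab_size: 30_000 };

    // ...so the cached value now disagrees with the live one, unless someone
    // remembers to call the (now removed) update_original_vocab_size().
    assert_ne!(cached.original_vocab_size, model.vocab_size);
}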
@@ -50,7 +50,7 @@ where
/// It is possible to retrieve a part of the original string, by indexing it with offsets from the
/// normalized one, and the other way around too. It is also possible to convert offsets from one
/// referential to the other one easily.
#[derive(Default, Debug, Clone)]
#[derive(Default, Debug, Clone, PartialEq)]
pub struct NormalizedString {
    /// The original version of the string, before any modification
    original: String,
@@ -61,12 +61,6 @@ pub struct NormalizedString {
    alignments: Vec<(usize, usize)>,
}

impl std::cmp::PartialEq for NormalizedString {
    fn eq(&self, other: &NormalizedString) -> bool {
        self.normalized == other.normalized
    }
}

impl NormalizedString {
    /// Create a NormalizedString from the given str
    pub fn from(s: &str) -> Self {
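Swapping the hand-written `PartialEq` (which only compared the `normalized` field) for the derived one means equality now looks at `original` and `alignments` too, which is likely why the expected alignments in the test hunk further down had to be spelled out exactly. A standalone sketch of the difference (mock type, not the crate's `NormalizedString`):

// Sketch: the derived impl compares every field, not just `normalized`.
#[derive(PartialEq, Debug, Clone)]
struct Mock {
    original: String,
    normalized: String,
    alignments: Vec<(usize, usize)>,
}

fn main() {
    let a = Mock {
        original: "Morning".into(),
        normalized: "morning".into(),
        alignments: vec![(0, 7)],
    };
    // Same normalized text, different alignments: the old field-subset impl would
    // have called these equal, the derived impl does not.
    let b = Mock { alignments: vec![(1, 8)], ..a.clone() };
    assert_ne!(a, b);
}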
@@ -441,7 +435,7 @@ impl NormalizedString {
    /// Merge with the given NormalizedString by appending it to self
    pub fn merge_with(&mut self, other: &NormalizedString) {
        self.original.push_str(&other.original);
        let len = self.len();
        let len = self.len() - 1;
        self.alignments.extend(
            other
                .alignments
@@ -879,7 +873,7 @@ mod tests {
            Some(NormalizedString {
                original: "𝕞𝕠𝕣𝕟𝕚𝕟𝕘".to_string(),
                normalized: "morning".to_string(),
                alignments: vec![(5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12)]
                alignments: vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
            })
        );
        assert_eq!(
@@ -9,8 +9,8 @@ fn add_tokens() {

    assert_eq!(
        tokenizer.add_special_tokens(&[
            AddedToken::from("<cls>".into(), true),
            AddedToken::from("<sep>".into(), true)
            AddedToken::from("<cls>", true),
            AddedToken::from("<sep>", true)
        ]),
        2
    );
@@ -19,8 +19,8 @@ fn add_tokens() {

    assert_eq!(
        tokenizer.add_tokens(&[
            AddedToken::from("hello".into(), false),
            AddedToken::from("world".into(), false)
            AddedToken::from("hello", false),
            AddedToken::from("world", false)
        ]),
        2
    );
@@ -31,7 +31,7 @@ fn add_tokens() {
#[test]
fn lstrip_tokens() {
    let mut tokenizer = get_byte_level(true, false);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).lstrip(true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).lstrip(true)]);

    let input = "I saw a <mask> 😺";
    let output = tokenizer.encode(input, false).unwrap();
@@ -49,7 +49,7 @@ fn lstrip_tokens() {
#[test]
fn rstrip_tokens() {
    let mut tokenizer = get_byte_level(false, false);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).rstrip(true)]);

    let input = "I saw a <mask> 😺";
    let output = tokenizer.encode(input, false).unwrap();
@@ -62,7 +62,7 @@ fn rstrip_tokens() {
    // When `add_prefix_space = true` rstrip cannot work as a prefix space is added
    // to the next token
    let mut tokenizer = get_byte_level(true, false);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("<mask>", true).rstrip(true)]);

    let input = "I saw a <mask> 😺";
    let output = tokenizer.encode(input, false).unwrap();
@@ -77,7 +77,7 @@
fn single_word_tokens() {
    // If `single_word = true` it shouldn't split `dancing`
    let mut tokenizer = get_byte_level(false, false);
    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing", true).single_word(true)]);

    let input = "I like dancing";
    let output = tokenizer.encode(input, false).unwrap();
@@ -86,7 +86,7 @@ fn single_word_tokens() {

    // If `single_word = false` it should split `dancing`
    let mut tokenizer = get_byte_level(false, false);
    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(false)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing", true).single_word(false)]);

    let input = "I like dancing";
    let output = tokenizer.encode(input, false).unwrap();
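The single_word tests above hinge on what counts as a word boundary around the match. A standalone sketch of the rule (illustrative Rust, not the crate's regex-based splitter):

// Sketch: a single_word hit only counts when both neighbours of the match are
// non-word characters, so "ing" inside "dancing" is rejected.
fn matches_single_word(text: &str, token: &str) -> bool {
    let start = match text.find(token) {
        Some(s) => s,
        None => return false,
    };
    let end = start + token.len();
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    let left_ok = !text[..start].chars().next_back().map(is_word).unwrap_or(false);
    let right_ok = !text[end..].chars().next().map(is_word).unwrap_or(false);
    left_ok && right_ok
}

fn main() {
    // "ing" sits inside "dancing", so with single_word = true it must not split it.
    assert!(!matches_single_word("I like dancing", "ing"));
    assert!(matches_single_word("a word ing here", "ing"));
}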
@@ -98,9 +98,9 @@
fn overlapping_tokens() {
    let mut tokenizer = get_byte_level(false, false);

    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("danc", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("nci", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing", true)]);

    let input = "I like dancing";
    let output = tokenizer.encode(input, false).unwrap();
@@ -109,10 +109,10 @@ fn overlapping_tokens() {

    let mut tokenizer = get_byte_level(false, false);

    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ike".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("nci", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("danc", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ing", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("ike", true)]);

    let output = tokenizer.encode(input, false).unwrap();
@@ -158,7 +158,7 @@ fn split_on_added_tokens_bert() {
    let input = "Yesterday I saw a [MASK] far away";

    let mut tokenizer = get_bert();
    tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("[MASK]", true)]);
    let output = tokenizer.encode(input, false).unwrap();

    assert_eq!(