Mirror of https://github.com/mii443/tokenizers.git, synced 2025-12-07 21:28:19 +00:00
Python - Add bindings for new AddedToken options
@@ -85,6 +85,7 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<tokenizer::Tokenizer>()?;
+    m.add_class::<tokenizer::AddedToken>()?;
     m.add_class::<encoding::Encoding>()?;
     m.add_wrapped(wrap_pymodule!(models))?;
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
@@ -18,6 +18,34 @@ use tk::tokenizer::{
     PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
 };
 
+#[pyclass(dict)]
+pub struct AddedToken {
+    pub token: tk::tokenizer::AddedToken,
+}
+#[pymethods]
+impl AddedToken {
+    #[new]
+    #[args(kwargs = "**")]
+    fn new(obj: &PyRawObject, content: &str, kwargs: Option<&PyDict>) -> PyResult<()> {
+        let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
+
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "single_word" => token = token.single_word(value.extract()?),
+                    "lstrip" => token = token.lstrip(value.extract()?),
+                    "rstrip" => token = token.rstrip(value.extract()?),
+                    _ => println!("Ignored unknown kwarg option {}", key),
+                }
+            }
+        }
+
+        obj.init({ AddedToken { token } });
+        Ok(())
+    }
+}
+
 #[pyclass(dict)]
 pub struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
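The hunk above exposes a Rust-side AddedToken wrapper whose constructor takes single_word, lstrip, and rstrip as keyword arguments (unknown keys are ignored with a printed warning). A minimal sketch of how this surfaces on the Python side, assuming the bindings are built and importable as tokenizers; the example itself is not part of the commit:

    from tokenizers import AddedToken

    # Keyword arguments map onto the match arms in AddedToken::new above.
    mask = AddedToken("[MASK]", lstrip=True)
    word = AddedToken("procrastination", single_word=True)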
@@ -256,14 +284,11 @@ impl Tokenizer {
                         content,
                         ..Default::default()
                     })
-                } else if let Ok((content, single_word)) = token.extract::<(String, bool)>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        single_word,
-                    })
+                } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                    Ok(token.token.clone())
                 } else {
                     Err(exceptions::Exception::py_err(
-                        "Input must be a list[str] or list[(str, bool)]",
+                        "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
             })
@@ -272,7 +297,25 @@ impl Tokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
-    fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
+    fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
+        let tokens = tokens
+            .into_iter()
+            .map(|token| {
+                if let Ok(content) = token.extract::<String>() {
+                    Ok(tk::tokenizer::AddedToken {
+                        content,
+                        ..Default::default()
+                    })
+                } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                    Ok(token.token.clone())
+                } else {
+                    Err(exceptions::Exception::py_err(
+                        "Input must be a List[Union[str, AddedToken]]",
+                    ))
+                }
+            })
+            .collect::<PyResult<Vec<_>>>()?;
+
         Ok(self.tokenizer.add_special_tokens(&tokens))
     }
 
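With add_tokens and add_special_tokens now taking a &PyList whose items may be strings or AddedToken instances, both methods accept a mixed list from Python. A hedged usage sketch; the BPE.empty() setup is an assumption about the model constructor available at this version and is not taken from the commit:

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    # Assumed empty-model constructor for this version of the library.
    tokenizer = Tokenizer(BPE.empty())

    # Plain strings and AddedToken instances can be mixed freely.
    tokenizer.add_tokens(["my_token", AddedToken("my_other_token", single_word=True)])
    tokenizer.add_special_tokens(["[PAD]", AddedToken("[MASK]", lstrip=True)])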
@@ -1,6 +1,8 @@
 extern crate tokenizers as tk;
 
 use super::utils::Container;
+use crate::tokenizer::AddedToken;
+use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
@@ -28,7 +30,27 @@ impl BpeTrainer {
                     "vocab_size" => builder = builder.vocab_size(val.extract()?),
                     "min_frequency" => builder = builder.min_frequency(val.extract()?),
                     "show_progress" => builder = builder.show_progress(val.extract()?),
-                    "special_tokens" => builder = builder.special_tokens(val.extract()?),
+                    "special_tokens" => {
+                        builder = builder.special_tokens(
+                            val.cast_as::<PyList>()?
+                                .into_iter()
+                                .map(|token| {
+                                    if let Ok(content) = token.extract::<String>() {
+                                        Ok(tk::tokenizer::AddedToken {
+                                            content,
+                                            ..Default::default()
+                                        })
+                                    } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                                        Ok(token.token.clone())
+                                    } else {
+                                        Err(exceptions::Exception::py_err(
+                                            "special_tokens must be a List[Union[str, AddedToken]]",
+                                        ))
+                                    }
+                                })
+                                .collect::<PyResult<Vec<_>>>()?,
+                        );
+                    }
                     "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
                     "initial_alphabet" => {
                         let alphabet: Vec<String> = val.extract()?;
@@ -74,7 +96,27 @@ impl WordPieceTrainer {
                     "vocab_size" => builder = builder.vocab_size(val.extract()?),
                     "min_frequency" => builder = builder.min_frequency(val.extract()?),
                     "show_progress" => builder = builder.show_progress(val.extract()?),
-                    "special_tokens" => builder = builder.special_tokens(val.extract()?),
+                    "special_tokens" => {
+                        builder = builder.special_tokens(
+                            val.cast_as::<PyList>()?
+                                .into_iter()
+                                .map(|token| {
+                                    if let Ok(content) = token.extract::<String>() {
+                                        Ok(tk::tokenizer::AddedToken {
+                                            content,
+                                            ..Default::default()
+                                        })
+                                    } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                                        Ok(token.token.clone())
+                                    } else {
+                                        Err(exceptions::Exception::py_err(
+                                            "special_tokens must be a List[Union[str, AddedToken]]",
+                                        ))
+                                    }
+                                })
+                                .collect::<PyResult<Vec<_>>>()?,
+                        );
+                    }
                     "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
                     "initial_alphabet" => {
                         let alphabet: Vec<String> = val.extract()?;
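The two trainer hunks above route special_tokens through the same str-or-AddedToken conversion. A minimal sketch of the resulting trainer call; the keyword names follow the match keys in the Rust code, everything else is illustrative:

    from tokenizers import AddedToken, trainers

    # special_tokens may now mix plain strings and AddedToken instances.
    trainer = trainers.BpeTrainer(
        vocab_size=30000,
        special_tokens=["<unk>", AddedToken("<mask>", lstrip=True)],
    )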
@@ -1,6 +1,6 @@
 __version__ = "0.6.0"
 
-from .tokenizers import Tokenizer, Encoding
+from .tokenizers import Tokenizer, Encoding, AddedToken
 from .tokenizers import decoders
 from .tokenizers import models
 from .tokenizers import normalizers
@@ -91,6 +91,37 @@ class Encoding:
         """
         pass
 
+class AddedToken:
+    """ AddedToken represents a token to be added to a Tokenizer
+
+    An AddedToken can have special options defining the way it should behave.
+    """
+
+    def __new__(
+        cls, content: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False
+    ) -> AddedToken:
+        """ Instantiate a new AddedToken
+
+        Args:
+            content: str:
+                The content of the token
+
+            single_word: bool
+                Whether this token should only match against single word. If True,
+                this token will never match inside of a word.
+
+            lstrip: bool
+                Whether this token should strip all potential whitespaces on the left side.
+                If True, this token will greedily match any whitespace on the left and then strip
+                them out.
+
+            rstrip: bool
+                Whether this token should strip all potential whitespaces on the right side.
+                If True, this token will greedily match any whitespace on the right and then strip
+                them out.
+        """
+        pass
+
 class Tokenizer:
     """ Tokenizer
 
@@ -320,29 +351,28 @@ class Tokenizer:
            The corresponding string if it exists, None otherwise
        """
        pass
-    def add_tokens(self, tokens: List[Union[str, Tuple[str, bool]]]) -> int:
+    def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int:
        """ Add the given tokens to the vocabulary
 
        Args:
-            tokens: List[Union[str, Tuple[str, bool]]]:
+            tokens: List[Union[str, AddedToken]]:
                A list of tokens to add to the vocabulary. Each token can either be
-                a string, or a tuple with a string representing the token, and a boolean
-                option representing whether to match on single words only.
-                If the boolean is not included, it defaults to False
+                a string, or an instance of AddedToken
 
        Returns:
            The number of tokens that were added to the vocabulary
        """
        pass
-    def add_special_tokens(self, tokens: List[str]) -> int:
+    def add_special_tokens(self, tokens: List[Union[str, AddedToken]]) -> int:
        """ Add the given special tokens to the vocabulary, and treat them as special tokens.
 
        The special tokens will never be processed by the model, and will be
        removed while decoding.
 
        Args:
-            tokens: List[str]:
-                The list of special tokens to add
+            tokens: List[Union[str, AddedToken]]:
+                The list of special tokens to add. Each token can either be a string
+                or an instance of AddedToken
 
        Returns:
            The number of tokens that were added to the vocabulary
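As the stub changes above show, the old Tuple[str, bool] form for add_tokens is gone; the boolean that used to mean "match on single words only" becomes the single_word option of AddedToken. A small migration sketch with a hypothetical tokenizer argument, not part of the commit:

    from tokenizers import AddedToken

    def add_single_word_token(tokenizer):
        # Old form, removed by this commit:
        #     tokenizer.add_tokens([("world", True)])
        # New form:
        return tokenizer.add_tokens([AddedToken("world", single_word=True)])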
@@ -95,30 +95,29 @@ class BaseTokenizer:
        """ Disable truncation """
        return self._tokenizer.no_truncation()
 
-    def add_tokens(self, tokens: List[Union[str, Tuple[str, bool]]]) -> int:
+    def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int:
        """ Add the given tokens to the vocabulary
 
        Args:
-            tokens: List[Union[str, Tuple[str, bool]]]:
+            tokens: List[Union[str, AddedToken]]:
                A list of tokens to add to the vocabulary. Each token can either be
-                a string, or a tuple with a string representing the token, and a boolean
-                option representing whether to match on single words only.
-                If the boolean is not included, it defaults to False
+                a string, or an instance of AddedToken
 
        Returns:
            The number of tokens that were added to the vocabulary
        """
        return self._tokenizer.add_tokens(tokens)
 
-    def add_special_tokens(self, special_tokens: List[str]) -> int:
+    def add_special_tokens(self, special_tokens: List[Union[str, AddedToken]]) -> int:
        """ Add the given special tokens to the vocabulary, and treat them as special tokens.
 
        The special tokens will never be processed by the model, and will be
        removed while decoding.
 
        Args:
-            tokens: List[str]:
-                The list of special tokens to add
+            tokens: List[Union[str, AddedToken]]:
+                A list of special tokens to add to the vocabulary. Each token can either be
+                a string, or an instance of AddedToken
 
        Returns:
            The number of tokens that were added to the vocabulary
@@ -1,4 +1,4 @@
-from tokenizers import Tokenizer, decoders, trainers
+from tokenizers import Tokenizer, AddedToken, decoders, trainers
 from tokenizers.models import WordPiece
 from tokenizers.normalizers import BertNormalizer
 from tokenizers.pre_tokenizers import BertPreTokenizer
@@ -15,11 +15,11 @@ class BertWordPieceTokenizer(BaseTokenizer):
         self,
         vocab_file: Optional[str] = None,
         add_special_tokens: bool = True,
-        unk_token: str = "[UNK]",
-        sep_token: str = "[SEP]",
-        cls_token: str = "[CLS]",
-        pad_token: str = "[PAD]",
-        mask_token: str = "[MASK]",
+        unk_token: Union[str, AddedToken] = "[UNK]",
+        sep_token: Union[str, AddedToken] = "[SEP]",
+        cls_token: Union[str, AddedToken] = "[CLS]",
+        pad_token: Union[str, AddedToken] = "[PAD]",
+        mask_token: Union[str, AddedToken] = "[MASK]",
         clean_text: bool = True,
         handle_chinese_chars: bool = True,
         strip_accents: bool = True,
@@ -89,7 +89,13 @@ class BertWordPieceTokenizer(BaseTokenizer):
         min_frequency: int = 2,
         limit_alphabet: int = 1000,
         initial_alphabet: List[str] = [],
-        special_tokens: List[str] = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
+        special_tokens: List[Union[str, AddedToken]] = [
+            "[PAD]",
+            "[UNK]",
+            "[CLS]",
+            "[SEP]",
+            "[MASK]",
+        ],
         show_progress: bool = True,
         wordpieces_prefix: str = "##",
     ):
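The BertWordPieceTokenizer hunks let both the constructor's special-token arguments and the trainer's special_tokens accept AddedToken values. A sketch under the assumption that constructing the class without a vocab file and training on a local text file work as in this version; the import path and the file path are assumptions:

    from tokenizers import AddedToken
    from tokenizers.implementations import BertWordPieceTokenizer

    # mask_token may now carry matching options instead of being a plain string.
    tokenizer = BertWordPieceTokenizer(mask_token=AddedToken("[MASK]", lstrip=True))

    # The train() special_tokens list accepts the same mix of str and AddedToken.
    tokenizer.train(
        ["data.txt"],  # placeholder path
        special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", AddedToken("[MASK]", lstrip=True)],
    )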
@@ -1,4 +1,4 @@
-from tokenizers import Tokenizer, pre_tokenizers, decoders, trainers, processors
+from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers, processors
 from tokenizers.models import BPE
 from tokenizers.normalizers import unicode_normalizer_from_str, Lowercase, Sequence
 from .base_tokenizer import BaseTokenizer
@@ -76,7 +76,7 @@ class ByteLevelBPETokenizer(BaseTokenizer):
         vocab_size: int = 30000,
         min_frequency: int = 2,
         show_progress: bool = True,
-        special_tokens: List[str] = [],
+        special_tokens: List[Union[str, AddedToken]] = [],
     ):
         """ Train the model using the given files """
 
@@ -1,4 +1,4 @@
-from .. import Tokenizer, pre_tokenizers, decoders, trainers
+from .. import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
 from ..models import BPE
 from ..normalizers import Sequence, Lowercase, unicode_normalizer_from_str
 from .base_tokenizer import BaseTokenizer
@@ -16,8 +16,8 @@ class CharBPETokenizer(BaseTokenizer):
         self,
         vocab_file: Optional[str] = None,
         merges_file: Optional[str] = None,
-        unk_token: Optional[str] = "<unk>",
-        suffix: Optional[str] = "</w>",
+        unk_token: Union[str, AddedToken] = "<unk>",
+        suffix: str = "</w>",
         dropout: Optional[float] = None,
         lowercase: bool = False,
         unicode_normalizer: Optional[str] = None,
@@ -73,7 +73,7 @@ class CharBPETokenizer(BaseTokenizer):
         files: Union[str, List[str]],
         vocab_size: int = 30000,
         min_frequency: int = 2,
-        special_tokens: List[str] = ["<unk>"],
+        special_tokens: List[Union[str, AddedToken]] = ["<unk>"],
         limit_alphabet: int = 1000,
         initial_alphabet: List[str] = [],
         suffix: Optional[str] = "</w>",
@@ -1,4 +1,4 @@
-from tokenizers import Tokenizer, pre_tokenizers, decoders, trainers
+from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
 from tokenizers.models import BPE
 from tokenizers.normalizers import NFKC
 from .base_tokenizer import BaseTokenizer
@@ -16,7 +16,7 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         self,
         vocab_file: Optional[str] = None,
         merges_file: Optional[str] = None,
-        unk_token: str = "<unk>",
+        unk_token: Union[str, AddedToken] = "<unk>",
         replacement: str = "▁",
         add_prefix_space: bool = True,
         dropout: Optional[float] = None,
@@ -54,7 +54,7 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         files: Union[str, List[str]],
         vocab_size: int = 30000,
         min_frequency: int = 2,
-        special_tokens: List[str] = ["<unk>"],
+        special_tokens: List[Union[str, AddedToken]] = ["<unk>"],
         limit_alphabet: int = 1000,
         initial_alphabet: List[str] = [],
         show_progress: bool = True,
@@ -1,4 +1,5 @@
-from typing import Optional, List
+from .. import AddedToken
+from typing import Optional, List, Union
 
 class Trainer:
     """ Base class for all trainers
@@ -18,7 +19,7 @@ class BpeTrainer(Trainer):
         vocab_size: int = 30000,
         min_frequency: int = 0,
         show_progress: bool = True,
-        special_tokens: List[str] = [],
+        special_tokens: List[Union[str, AddedToken]] = [],
         limit_alphabet: Optional[int] = None,
         initial_alphabet: List[str] = [],
         continuing_subword_prefix: Optional[str] = None,
@@ -36,7 +37,7 @@ class BpeTrainer(Trainer):
            show_progress: boolean:
                Whether to show progress bars while training.
 
-            special_tokens: List[str]:
+            special_tokens: List[Union[str, AddedToken]]:
                A list of special tokens the model should know of.
 
            limit_alphabet: unsigned int:
@@ -70,7 +71,7 @@ class WordPieceTrainer(Trainer):
         vocab_size: int = 30000,
         min_frequency: int = 0,
         show_progress: bool = True,
-        special_tokens: List[str] = [],
+        special_tokens: List[Union[str, AddedToken]] = [],
         limit_alphabet: Optional[int] = None,
         initial_alphabet: List[str] = [],
         continuing_subword_prefix: Optional[str] = "##",
@@ -88,7 +89,7 @@ class WordPieceTrainer(Trainer):
            show_progress: boolean:
                Whether to show progress bars while training.
 
-            special_tokens: List[str]:
+            special_tokens: List[Union[str, AddedToken]]:
                A list of special tokens the model should know of.
 
            limit_alphabet: unsigned int: