Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Use Self where possible, some minor refactoring (#923)
* Use Self where possible, some minor refactoring
* fixed test
* fixed n_sequences
* reverted non-Self changes
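Every hunk below applies the same Rust idiom: inside an `impl` block, `Self` is an alias for the implementing type, so constructors and match arms no longer repeat the type name, and glob imports such as `use ModelWrapper::*;` become unnecessary. A minimal sketch of the before/after shape, using an illustrative `Wrapper` enum rather than any type from the crate:

use std::fmt;

// Illustrative enum, not part of the tokenizers crate.
enum Wrapper {
    A(String),
    B(u32),
}

impl Wrapper {
    // Before this style of change: `fn a(s: String) -> Wrapper { Wrapper::A(s) }`
    // `Self` refers to `Wrapper` here, so the type name is not repeated.
    fn a(s: String) -> Self {
        Self::A(s)
    }
}

impl fmt::Display for Wrapper {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Match arms can also use `Self::Variant`, which removes the need
        // for a `use Wrapper::*;` glob import inside the method.
        match self {
            Self::A(s) => write!(f, "A({})", s),
            Self::B(n) => write!(f, "B({})", n),
        }
    }
}

fn main() {
    println!("{}", Wrapper::a("hi".into()));
    println!("{}", Wrapper::B(3));
}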
@@ -13,13 +13,13 @@ pub struct BPEDecoder {
 impl BPEDecoder {
     pub fn new(suffix: String) -> Self {
-        BPEDecoder { suffix }
+        Self { suffix }
     }
 }
 
 impl Default for BPEDecoder {
     fn default() -> Self {
-        BPEDecoder::new("</w>".into())
+        Self::new("</w>".into())
     }
 }
@@ -28,11 +28,11 @@ pub enum DecoderWrapper {
 impl Decoder for DecoderWrapper {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
         match self {
-            DecoderWrapper::BPE(bpe) => bpe.decode(tokens),
-            DecoderWrapper::ByteLevel(bl) => bl.decode(tokens),
-            DecoderWrapper::Metaspace(ms) => ms.decode(tokens),
-            DecoderWrapper::WordPiece(wp) => wp.decode(tokens),
-            DecoderWrapper::CTC(ctc) => ctc.decode(tokens),
+            Self::BPE(bpe) => bpe.decode(tokens),
+            Self::ByteLevel(bl) => bl.decode(tokens),
+            Self::Metaspace(ms) => ms.decode(tokens),
+            Self::WordPiece(wp) => wp.decode(tokens),
+            Self::CTC(ctc) => ctc.decode(tokens),
         }
     }
 }
@@ -23,7 +23,7 @@ impl WordPiece {
 impl Default for WordPiece {
     fn default() -> Self {
         Self {
-            prefix: String::from("##"),
+            prefix: "##".to_owned(),
             cleanup: true,
         }
     }
@@ -30,30 +30,30 @@ pub enum Error {
 
 impl From<io::Error> for Error {
     fn from(error: io::Error) -> Self {
-        Error::Io(error)
+        Self::Io(error)
     }
 }
 
 impl From<serde_json::Error> for Error {
     fn from(error: serde_json::Error) -> Self {
-        Error::JsonError(error)
+        Self::JsonError(error)
     }
 }
 
 impl std::fmt::Display for Error {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            Error::Io(e) => write!(f, "IoError: {}", e),
-            Error::JsonError(e) => write!(f, "JsonError: {}", e),
-            Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
-            Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
-            Error::MergeTokenOutOfVocabulary(token) => {
+            Self::Io(e) => write!(f, "IoError: {}", e),
+            Self::JsonError(e) => write!(f, "JsonError: {}", e),
+            Self::BadVocabulary => write!(f, "Bad vocabulary json file"),
+            Self::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
+            Self::MergeTokenOutOfVocabulary(token) => {
                 write!(f, "Token `{}` out of vocabulary", token)
             }
-            Error::UnkTokenOutOfVocabulary(token) => {
+            Self::UnkTokenOutOfVocabulary(token) => {
                 write!(f, "Unk token `{}` not found in the vocabulary", token)
            }
-            Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
+            Self::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
         }
     }
 }
@@ -61,8 +61,8 @@ impl std::fmt::Display for Error {
 impl std::error::Error for Error {
     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
         match self {
-            Error::Io(e) => Some(e),
-            Error::JsonError(e) => Some(e),
+            Self::Io(e) => Some(e),
+            Self::JsonError(e) => Some(e),
             _ => None,
         }
     }
@@ -284,7 +284,7 @@ impl BPE {
 
     /// Initialize a BpeBuilder model from vocab and merges files
     pub fn from_file(vocab: &str, merges: &str) -> BpeBuilder {
-        BPE::builder().files(vocab.to_owned(), merges.to_owned())
+        Self::builder().files(vocab.to_owned(), merges.to_owned())
     }
 
     /// Read the given files to extract the vocab and merges
@@ -79,7 +79,7 @@ impl Word {
     }
 
     pub(super) fn with_capacity(capacity: usize) -> Self {
-        Word {
+        Self {
             symbols: Vec::with_capacity(capacity),
         }
     }
@@ -58,72 +58,65 @@ impl Model for ModelWrapper {
     type Trainer = TrainerWrapper;
 
     fn tokenize(&self, tokens: &str) -> Result<Vec<Token>> {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.tokenize(tokens),
-            WordPiece(t) => t.tokenize(tokens),
-            BPE(t) => t.tokenize(tokens),
-            Unigram(t) => t.tokenize(tokens),
+            Self::WordLevel(t) => t.tokenize(tokens),
+            Self::WordPiece(t) => t.tokenize(tokens),
+            Self::BPE(t) => t.tokenize(tokens),
+            Self::Unigram(t) => t.tokenize(tokens),
         }
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.token_to_id(token),
-            WordPiece(t) => t.token_to_id(token),
-            BPE(t) => t.token_to_id(token),
-            Unigram(t) => t.token_to_id(token),
+            Self::WordLevel(t) => t.token_to_id(token),
+            Self::WordPiece(t) => t.token_to_id(token),
+            Self::BPE(t) => t.token_to_id(token),
+            Self::Unigram(t) => t.token_to_id(token),
         }
     }
 
     fn id_to_token(&self, id: u32) -> Option<String> {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.id_to_token(id),
-            WordPiece(t) => t.id_to_token(id),
-            BPE(t) => t.id_to_token(id),
-            Unigram(t) => t.id_to_token(id),
+            Self::WordLevel(t) => t.id_to_token(id),
+            Self::WordPiece(t) => t.id_to_token(id),
+            Self::BPE(t) => t.id_to_token(id),
+            Self::Unigram(t) => t.id_to_token(id),
         }
     }
 
     fn get_vocab(&self) -> HashMap<String, u32> {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.get_vocab(),
-            WordPiece(t) => t.get_vocab(),
-            BPE(t) => t.get_vocab(),
-            Unigram(t) => t.get_vocab(),
+            Self::WordLevel(t) => t.get_vocab(),
+            Self::WordPiece(t) => t.get_vocab(),
+            Self::BPE(t) => t.get_vocab(),
+            Self::Unigram(t) => t.get_vocab(),
         }
     }
 
     fn get_vocab_size(&self) -> usize {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.get_vocab_size(),
-            WordPiece(t) => t.get_vocab_size(),
-            BPE(t) => t.get_vocab_size(),
-            Unigram(t) => t.get_vocab_size(),
+            Self::WordLevel(t) => t.get_vocab_size(),
+            Self::WordPiece(t) => t.get_vocab_size(),
+            Self::BPE(t) => t.get_vocab_size(),
+            Self::Unigram(t) => t.get_vocab_size(),
         }
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.save(folder, name),
-            WordPiece(t) => t.save(folder, name),
-            BPE(t) => t.save(folder, name),
-            Unigram(t) => t.save(folder, name),
+            Self::WordLevel(t) => t.save(folder, name),
+            Self::WordPiece(t) => t.save(folder, name),
+            Self::BPE(t) => t.save(folder, name),
+            Self::Unigram(t) => t.save(folder, name),
         }
     }
 
     fn get_trainer(&self) -> Self::Trainer {
-        use ModelWrapper::*;
         match self {
-            WordLevel(t) => t.get_trainer().into(),
-            WordPiece(t) => t.get_trainer().into(),
-            BPE(t) => t.get_trainer().into(),
-            Unigram(t) => t.get_trainer().into(),
+            Self::WordLevel(t) => t.get_trainer().into(),
+            Self::WordPiece(t) => t.get_trainer().into(),
+            Self::BPE(t) => t.get_trainer().into(),
+            Self::Unigram(t) => t.get_trainer().into(),
         }
     }
 }
@@ -140,28 +133,28 @@ impl Trainer for TrainerWrapper {
 
     fn should_show_progress(&self) -> bool {
         match self {
-            TrainerWrapper::BpeTrainer(bpe) => bpe.should_show_progress(),
-            TrainerWrapper::WordPieceTrainer(wpt) => wpt.should_show_progress(),
-            TrainerWrapper::WordLevelTrainer(wpt) => wpt.should_show_progress(),
-            TrainerWrapper::UnigramTrainer(wpt) => wpt.should_show_progress(),
+            Self::BpeTrainer(bpe) => bpe.should_show_progress(),
+            Self::WordPieceTrainer(wpt) => wpt.should_show_progress(),
+            Self::WordLevelTrainer(wpt) => wpt.should_show_progress(),
+            Self::UnigramTrainer(wpt) => wpt.should_show_progress(),
         }
     }
 
     fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
         match self {
-            TrainerWrapper::BpeTrainer(t) => match model {
+            Self::BpeTrainer(t) => match model {
                 ModelWrapper::BPE(bpe) => t.train(bpe),
                 _ => Err("BpeTrainer can only train a BPE".into()),
             },
-            TrainerWrapper::WordPieceTrainer(t) => match model {
+            Self::WordPieceTrainer(t) => match model {
                 ModelWrapper::WordPiece(wp) => t.train(wp),
                 _ => Err("WordPieceTrainer can only train a WordPiece".into()),
             },
-            TrainerWrapper::WordLevelTrainer(t) => match model {
+            Self::WordLevelTrainer(t) => match model {
                 ModelWrapper::WordLevel(wl) => t.train(wl),
                 _ => Err("WordLevelTrainer can only train a WordLevel".into()),
             },
-            TrainerWrapper::UnigramTrainer(t) => match model {
+            Self::UnigramTrainer(t) => match model {
                 ModelWrapper::Unigram(u) => t.train(u),
                 _ => Err("UnigramTrainer can only train a Unigram".into()),
             },
@@ -175,10 +168,10 @@ impl Trainer for TrainerWrapper {
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
         match self {
-            TrainerWrapper::BpeTrainer(bpe) => bpe.feed(iterator, process),
-            TrainerWrapper::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
-            TrainerWrapper::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
-            TrainerWrapper::UnigramTrainer(wpt) => wpt.feed(iterator, process),
+            Self::BpeTrainer(bpe) => bpe.feed(iterator, process),
+            Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
+            Self::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
+            Self::UnigramTrainer(wpt) => wpt.feed(iterator, process),
         }
     }
 }
@@ -16,8 +16,8 @@ struct Hypothesis {
     gx: f64,
 }
 impl Hypothesis {
-    pub fn new(node_ref: NodeRef, next: Option<HypothesisRef>, fx: f64, gx: f64) -> Hypothesis {
-        Hypothesis {
+    pub fn new(node_ref: NodeRef, next: Option<HypothesisRef>, fx: f64, gx: f64) -> Self {
+        Self {
             node_ref,
             next,
             fx,
@@ -26,19 +26,19 @@ impl Hypothesis {
     }
 }
 impl PartialEq for Hypothesis {
-    fn eq(&self, other: &Hypothesis) -> bool {
+    fn eq(&self, other: &Self) -> bool {
         self.fx == other.fx
     }
 }
 impl Eq for Hypothesis {}
 impl PartialOrd for Hypothesis {
-    fn partial_cmp(&self, other: &Hypothesis) -> Option<Ordering> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
         Some(self.cmp(other))
     }
 }
 // TODO Maybe use Ordered Floats (https://docs.rs/ordered-float/1.0.2/ordered_float/)
 impl Ord for Hypothesis {
-    fn cmp(&self, other: &Hypothesis) -> Ordering {
+    fn cmp(&self, other: &Self) -> Ordering {
         if self.fx < other.fx {
             Ordering::Less
         } else {
@@ -102,8 +102,8 @@ impl PartialEq for Node {
 }
 
 impl Node {
-    pub fn new(id: usize, node_id: usize, pos: usize, length: usize, score: f64) -> Node {
-        Node {
+    pub fn new(id: usize, node_id: usize, pos: usize, length: usize, score: f64) -> Self {
+        Self {
             id,
             node_id,
             pos,
@@ -135,7 +135,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
 }
 
 impl<'a> Lattice<'a> {
-    pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Lattice<'a> {
+    pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Self {
         let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
@@ -152,7 +152,7 @@ impl<'a> Lattice<'a> {
         nodes.push(bos);
         nodes.push(eos);
 
-        Lattice {
+        Self {
             sentence,
             len,
             nodes,
@@ -75,13 +75,13 @@ pub enum UnigramError {
 impl std::fmt::Display for UnigramError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            UnigramError::EmptyVocabulary => {
+            Self::EmptyVocabulary => {
                 write!(f, "The vocabulary is empty but at least <unk> is needed")
             }
-            UnigramError::UnkIdNotInVocabulary => {
+            Self::UnkIdNotInVocabulary => {
                 write!(f, "The `unk_id` is larger than vocabulary size")
             }
-            UnigramError::MissingUnkId => {
+            Self::MissingUnkId => {
                 write!(f, "Encountered an unknown token but `unk_id` is missing")
             }
         }
@@ -134,7 +134,7 @@ impl Unigram {
         let fuse_unk = true;
         let is_optimized = true;
 
-        Ok(Unigram {
+        Ok(Self {
             vocab,
             token_to_ids,
             trie,
@@ -69,7 +69,7 @@ where
 
 impl<Label> Default for Trie<Label> {
     fn default() -> Self {
-        Trie {
+        Self {
             root: Node::default(),
         }
     }
@@ -83,7 +83,7 @@ pub struct Node<Label> {
 
 impl<Label> Default for Node<Label> {
     fn default() -> Self {
-        Node {
+        Self {
             is_leaf: false,
             children: HashMap::new(),
         }
@@ -25,11 +25,11 @@ impl std::error::Error for Error {}
 impl fmt::Display for Error {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            Error::MissingUnkToken => write!(
+            Self::MissingUnkToken => write!(
                 fmt,
                 "WordLevel error: Missing [UNK] token from the vocabulary"
             ),
-            Error::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
+            Self::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
         }
     }
 }
@@ -26,7 +26,7 @@ impl std::error::Error for Error {}
 impl fmt::Display for Error {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            Error::MissingUnkToken => write!(
+            Self::MissingUnkToken => write!(
                 fmt,
                 "WordPiece error: Missing [UNK] token from the vocabulary"
             ),
@@ -6,23 +6,21 @@ use unicode_categories::UnicodeCategories;
 /// Checks whether a character is whitespace
 fn is_whitespace(c: char) -> bool {
     // These are technically control characters but we count them as whitespace
-    if c == '\t' || c == '\n' || c == '\r' {
-        true
-    } else {
-        c.is_whitespace()
+    match c {
+        '\t' | '\n' | '\r' => true,
+        _ => c.is_whitespace(),
     }
 }
 
 /// Checks whether a character is a control character
 fn is_control(c: char) -> bool {
     // These are technically control characters but we count them as whitespace
-    if c == '\t' || c == '\n' || c == '\r' {
-        false
-    } else {
+    match c {
+        '\t' | '\n' | '\r' => false,
         // The definition of `is_control` here is quite large and contains also
         // Cc, Cf, Cn or Co
        // cf. https://unicode.org/reports/tr44/ (Table 12)
-        c.is_other()
+        _ => c.is_other(),
     }
 }
 
@@ -83,7 +81,7 @@ impl BertNormalizer {
         strip_accents: Option<bool>,
         lowercase: bool,
     ) -> Self {
-        BertNormalizer {
+        Self {
             clean_text,
             handle_chinese_chars,
             strip_accents,
@@ -37,18 +37,18 @@ pub enum NormalizerWrapper {
 impl Normalizer for NormalizerWrapper {
     fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
         match self {
-            NormalizerWrapper::BertNormalizer(bn) => bn.normalize(normalized),
-            NormalizerWrapper::StripNormalizer(sn) => sn.normalize(normalized),
-            NormalizerWrapper::StripAccents(sn) => sn.normalize(normalized),
-            NormalizerWrapper::NFC(nfc) => nfc.normalize(normalized),
-            NormalizerWrapper::NFD(nfd) => nfd.normalize(normalized),
-            NormalizerWrapper::NFKC(nfkc) => nfkc.normalize(normalized),
-            NormalizerWrapper::NFKD(nfkd) => nfkd.normalize(normalized),
-            NormalizerWrapper::Sequence(sequence) => sequence.normalize(normalized),
-            NormalizerWrapper::Lowercase(lc) => lc.normalize(normalized),
-            NormalizerWrapper::Nmt(lc) => lc.normalize(normalized),
-            NormalizerWrapper::Precompiled(lc) => lc.normalize(normalized),
-            NormalizerWrapper::Replace(lc) => lc.normalize(normalized),
+            Self::BertNormalizer(bn) => bn.normalize(normalized),
+            Self::StripNormalizer(sn) => sn.normalize(normalized),
+            Self::StripAccents(sn) => sn.normalize(normalized),
+            Self::NFC(nfc) => nfc.normalize(normalized),
+            Self::NFD(nfd) => nfd.normalize(normalized),
+            Self::NFKC(nfkc) => nfkc.normalize(normalized),
+            Self::NFKD(nfkd) => nfkd.normalize(normalized),
+            Self::Sequence(sequence) => sequence.normalize(normalized),
+            Self::Lowercase(lc) => lc.normalize(normalized),
+            Self::Nmt(lc) => lc.normalize(normalized),
+            Self::Precompiled(lc) => lc.normalize(normalized),
+            Self::Replace(lc) => lc.normalize(normalized),
         }
     }
 }
@@ -11,13 +11,13 @@ pub enum ReplacePattern {
 
 impl From<String> for ReplacePattern {
     fn from(v: String) -> Self {
-        ReplacePattern::String(v)
+        Self::String(v)
     }
 }
 
 impl From<&str> for ReplacePattern {
     fn from(v: &str) -> Self {
-        ReplacePattern::String(v.to_owned())
+        Self::String(v.to_owned())
     }
 }
 
@@ -34,7 +34,7 @@ impl std::convert::TryFrom<ReplaceDeserializer> for Replace {
     type Error = Box<dyn std::error::Error + Send + Sync>;
 
     fn try_from(v: ReplaceDeserializer) -> Result<Self> {
-        Replace::new(v.pattern, v.content)
+        Self::new(v.pattern, v.content)
     }
 }
 
@@ -51,12 +51,12 @@ pub struct Replace {
 
 impl Clone for Replace {
     fn clone(&self) -> Self {
-        Replace::new(self.pattern.clone(), &self.content).unwrap()
+        Self::new(self.pattern.clone(), &self.content).unwrap()
     }
 }
 
 impl PartialEq for Replace {
-    fn eq(&self, other: &Replace) -> bool {
+    fn eq(&self, other: &Self) -> bool {
         self.pattern == other.pattern && self.content == other.content
     }
 }
@@ -66,7 +66,7 @@ impl Default for ByteLevel {
 
 impl ByteLevel {
     pub fn new(add_prefix_space: bool, trim_offsets: bool) -> Self {
-        ByteLevel {
+        Self {
             add_prefix_space,
             trim_offsets,
         }
@@ -12,7 +12,7 @@ pub struct CharDelimiterSplit {
 
 impl CharDelimiterSplit {
     pub fn new(delimiter: char) -> Self {
-        CharDelimiterSplit { delimiter }
+        Self { delimiter }
     }
 }
 
@@ -34,7 +34,7 @@ impl<'de> Deserialize<'de> for Metaspace {
         }
 
         let helper = MetaspaceHelper::deserialize(deserializer)?;
-        Ok(Metaspace::new(helper.replacement, helper.add_prefix_space))
+        Ok(Self::new(helper.replacement, helper.add_prefix_space))
     }
 }
 
@@ -42,17 +42,17 @@ pub enum PreTokenizerWrapper {
 impl PreTokenizer for PreTokenizerWrapper {
     fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> crate::Result<()> {
         match self {
-            PreTokenizerWrapper::BertPreTokenizer(bpt) => bpt.pre_tokenize(normalized),
-            PreTokenizerWrapper::ByteLevel(bpt) => bpt.pre_tokenize(normalized),
-            PreTokenizerWrapper::Delimiter(dpt) => dpt.pre_tokenize(normalized),
-            PreTokenizerWrapper::Metaspace(mspt) => mspt.pre_tokenize(normalized),
-            PreTokenizerWrapper::Whitespace(wspt) => wspt.pre_tokenize(normalized),
-            PreTokenizerWrapper::Punctuation(tok) => tok.pre_tokenize(normalized),
-            PreTokenizerWrapper::Sequence(tok) => tok.pre_tokenize(normalized),
-            PreTokenizerWrapper::Split(tok) => tok.pre_tokenize(normalized),
-            PreTokenizerWrapper::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
-            PreTokenizerWrapper::Digits(wspt) => wspt.pre_tokenize(normalized),
-            PreTokenizerWrapper::UnicodeScripts(us) => us.pre_tokenize(normalized),
+            Self::BertPreTokenizer(bpt) => bpt.pre_tokenize(normalized),
+            Self::ByteLevel(bpt) => bpt.pre_tokenize(normalized),
+            Self::Delimiter(dpt) => dpt.pre_tokenize(normalized),
+            Self::Metaspace(mspt) => mspt.pre_tokenize(normalized),
+            Self::Whitespace(wspt) => wspt.pre_tokenize(normalized),
+            Self::Punctuation(tok) => tok.pre_tokenize(normalized),
+            Self::Sequence(tok) => tok.pre_tokenize(normalized),
+            Self::Split(tok) => tok.pre_tokenize(normalized),
+            Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
+            Self::Digits(wspt) => wspt.pre_tokenize(normalized),
+            Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
         }
     }
 }
@@ -14,13 +14,13 @@ pub enum SplitPattern {
 
 impl From<String> for SplitPattern {
     fn from(v: String) -> Self {
-        SplitPattern::String(v)
+        Self::String(v)
     }
 }
 
 impl From<&str> for SplitPattern {
     fn from(v: &str) -> Self {
-        SplitPattern::String(v.to_owned())
+        Self::String(v.to_owned())
     }
 }
 
@@ -54,18 +54,18 @@ impl<'de> Deserialize<'de> for Split {
         }
 
         let helper = SplitHelper::deserialize(deserializer)?;
-        Split::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
+        Self::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
     }
 }
 
 impl Clone for Split {
     fn clone(&self) -> Self {
-        Split::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
+        Self::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
     }
 }
 
 impl PartialEq for Split {
-    fn eq(&self, other: &Split) -> bool {
+    fn eq(&self, other: &Self) -> bool {
         self.pattern == other.pattern
             && self.behavior == other.behavior
             && self.invert == other.invert
@@ -21,7 +21,7 @@ impl Default for BertProcessing {
 
 impl BertProcessing {
     pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
-        BertProcessing { sep, cls }
+        Self { sep, cls }
     }
 }
 
@@ -26,10 +26,10 @@ pub enum PostProcessorWrapper {
 impl PostProcessor for PostProcessorWrapper {
     fn added_tokens(&self, is_pair: bool) -> usize {
         match self {
-            PostProcessorWrapper::Bert(bert) => bert.added_tokens(is_pair),
-            PostProcessorWrapper::ByteLevel(bl) => bl.added_tokens(is_pair),
-            PostProcessorWrapper::Roberta(roberta) => roberta.added_tokens(is_pair),
-            PostProcessorWrapper::Template(template) => template.added_tokens(is_pair),
+            Self::Bert(bert) => bert.added_tokens(is_pair),
+            Self::ByteLevel(bl) => bl.added_tokens(is_pair),
+            Self::Roberta(roberta) => roberta.added_tokens(is_pair),
+            Self::Template(template) => template.added_tokens(is_pair),
         }
     }
 
@@ -40,16 +40,10 @@ impl PostProcessor for PostProcessorWrapper {
         add_special_tokens: bool,
     ) -> Result<Encoding> {
         match self {
-            PostProcessorWrapper::Bert(bert) => {
-                bert.process(encoding, pair_encoding, add_special_tokens)
-            }
-            PostProcessorWrapper::ByteLevel(bl) => {
-                bl.process(encoding, pair_encoding, add_special_tokens)
-            }
-            PostProcessorWrapper::Roberta(roberta) => {
-                roberta.process(encoding, pair_encoding, add_special_tokens)
-            }
-            PostProcessorWrapper::Template(template) => {
+            Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
+            Self::Template(template) => {
                 template.process(encoding, pair_encoding, add_special_tokens)
             }
         }
@@ -26,7 +26,7 @@ impl Default for RobertaProcessing {
 
 impl RobertaProcessing {
     pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
-        RobertaProcessing {
+        Self {
             sep,
             cls,
             ..Default::default()
@@ -98,7 +98,7 @@ pub enum Piece {
 }
 
 impl Piece {
-    fn extract_id(s: &str) -> Option<Piece> {
+    fn extract_id(s: &str) -> Option<Self> {
         if s.starts_with('$') {
             let rest = &s['$'.len_utf8()..];
 
@@ -135,10 +135,10 @@ impl Piece {
         }
     }
 
-    fn with_type_id(self, type_id: u32) -> Piece {
+    fn with_type_id(self, type_id: u32) -> Self {
         match self {
-            Piece::Sequence { id, .. } => Piece::Sequence { id, type_id },
-            Piece::SpecialToken { id, .. } => Piece::SpecialToken { id, type_id },
+            Self::Sequence { id, .. } => Self::Sequence { id, type_id },
+            Self::SpecialToken { id, .. } => Self::SpecialToken { id, type_id },
         }
     }
 }
@@ -153,10 +153,10 @@ impl TryFrom<String> for Piece {
         match parts.as_slice() {
             [id, type_id] => {
                 let type_id: u32 = type_id.parse().map_err(|_| err())?;
-                let piece = Piece::extract_id(id).ok_or_else(err)?;
+                let piece = Self::extract_id(id).ok_or_else(err)?;
                 Ok(piece.with_type_id(type_id))
             }
-            [id] => Piece::extract_id(id).ok_or_else(err),
+            [id] => Self::extract_id(id).ok_or_else(err),
             _ => Err(err()),
         }
     }
@@ -31,7 +31,7 @@ impl AddedToken {
     /// Build this token from the given content, specifying if it is intented to be a
     /// special token. Special tokens are not normalized by default.
     pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
-        AddedToken {
+        Self {
            content: content.into(),
            normalized: !special,
            special,
@@ -69,7 +69,7 @@ impl AddedToken {
 }
 impl Default for AddedToken {
     fn default() -> Self {
-        AddedToken {
+        Self {
             content: String::new(),
             single_word: false,
             lstrip: false,
@@ -39,10 +39,10 @@ impl Encoding {
         offsets: Vec<Offsets>,
         special_tokens_mask: Vec<u32>,
         attention_mask: Vec<u32>,
-        overflowing: Vec<Encoding>,
+        overflowing: Vec<Self>,
         sequence_ranges: HashMap<usize, Range<usize>>,
     ) -> Self {
-        Encoding {
+        Self {
             ids,
             type_ids,
             tokens,
@@ -56,7 +56,7 @@ impl Encoding {
     }
 
     pub fn with_capacity(len: usize) -> Self {
-        Encoding {
+        Self {
             ids: Vec::with_capacity(len),
             type_ids: Vec::with_capacity(len),
             tokens: Vec::with_capacity(len),
@@ -85,7 +85,7 @@ impl Encoding {
             },
         );
 
-        Encoding {
+        Self {
             ids,
             tokens,
             offsets,
@@ -151,7 +151,7 @@ pub struct Token {
 }
 impl Token {
     pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
-        Token { id, value, offsets }
+        Self { id, value, offsets }
     }
 }
 
@@ -166,55 +166,55 @@ pub enum InputSequence<'s> {
 
 impl<'s> From<Cow<'s, str>> for InputSequence<'s> {
     fn from(input: Cow<'s, str>) -> Self {
-        InputSequence::Raw(input)
+        Self::Raw(input)
     }
 }
 
 impl<'s> From<&'s str> for InputSequence<'s> {
     fn from(input: &'s str) -> Self {
-        InputSequence::Raw(Cow::Borrowed(input))
+        Self::Raw(Cow::Borrowed(input))
     }
 }
 
 impl From<String> for InputSequence<'_> {
     fn from(input: String) -> Self {
-        InputSequence::Raw(Cow::Owned(input))
+        Self::Raw(Cow::Owned(input))
     }
 }
 
 impl<'s> From<&'s [&'s str]> for InputSequence<'s> {
     fn from(input: &'s [&'s str]) -> Self {
-        InputSequence::PreTokenized(Cow::Borrowed(input))
+        Self::PreTokenized(Cow::Borrowed(input))
     }
 }
 
 impl<'s> From<Vec<&'s str>> for InputSequence<'s> {
     fn from(input: Vec<&'s str>) -> Self {
-        InputSequence::PreTokenized(Cow::Owned(input))
+        Self::PreTokenized(Cow::Owned(input))
     }
 }
 
 impl<'s> From<&'s [String]> for InputSequence<'s> {
     fn from(input: &'s [String]) -> Self {
-        InputSequence::PreTokenizedOwned(Cow::Borrowed(input))
+        Self::PreTokenizedOwned(Cow::Borrowed(input))
     }
 }
 
 impl<'s> From<Vec<String>> for InputSequence<'s> {
     fn from(input: Vec<String>) -> Self {
-        InputSequence::PreTokenizedOwned(Cow::Owned(input))
+        Self::PreTokenizedOwned(Cow::Owned(input))
     }
 }
 
 impl<'s> From<Vec<Cow<'s, str>>> for InputSequence<'s> {
     fn from(input: Vec<Cow<'s, str>>) -> Self {
-        InputSequence::PreTokenizedCow(Cow::Owned(input))
+        Self::PreTokenizedCow(Cow::Owned(input))
     }
 }
 
 impl<'s> From<&'s [Cow<'s, str>]> for InputSequence<'s> {
     fn from(input: &'s [Cow<'s, str>]) -> Self {
-        InputSequence::PreTokenizedCow(Cow::Borrowed(input))
+        Self::PreTokenizedCow(Cow::Borrowed(input))
     }
 }
 
@@ -226,7 +226,7 @@ pub enum EncodeInput<'s> {
 
 impl<'s, I: Into<InputSequence<'s>>> From<I> for EncodeInput<'s> {
     fn from(input: I) -> Self {
-        EncodeInput::Single(input.into())
+        Self::Single(input.into())
     }
 }
 
@@ -236,7 +236,7 @@ where
     I2: Into<InputSequence<'s>>,
 {
     fn from(input: (I1, I2)) -> Self {
-        EncodeInput::Dual(input.0.into(), input.1.into())
+        Self::Dual(input.0.into(), input.1.into())
     }
 }
 
@@ -290,7 +290,7 @@ where
 {
     /// Get an empty TokenizerBuilder.
     pub fn new() -> Self {
-        TokenizerBuilder {
+        Self {
             model: None,
             normalizer: None,
             pre_tokenizer: None,
@@ -494,7 +494,7 @@ where
 {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
-        TokenizerImpl {
+        Self {
             normalizer: None,
             pre_tokenizer: None,
             model,
|
@ -41,8 +41,8 @@ where
|
||||
/// Unwrap the underlying range
|
||||
pub fn unwrap(self) -> T {
|
||||
match self {
|
||||
Range::Original(r) => r,
|
||||
Range::Normalized(r) => r,
|
||||
Self::Original(r) => r,
|
||||
Self::Normalized(r) => r,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -76,16 +76,16 @@ pub enum TruncationStrategy {
 
 impl Default for TruncationStrategy {
     fn default() -> Self {
-        TruncationStrategy::LongestFirst
+        Self::LongestFirst
     }
 }
 
 impl std::convert::AsRef<str> for TruncationStrategy {
     fn as_ref(&self) -> &str {
         match self {
-            TruncationStrategy::LongestFirst => "longest_first",
-            TruncationStrategy::OnlyFirst => "only_first",
-            TruncationStrategy::OnlySecond => "only_second",
+            Self::LongestFirst => "longest_first",
+            Self::OnlyFirst => "only_first",
+            Self::OnlySecond => "only_second",
         }
     }
 }