mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Fix cli + whitespace
This commit is contained in:
@ -4,11 +4,11 @@
|
|||||||
|
|
||||||
use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
|
use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
|
||||||
use std::io::{self, BufRead, Write};
|
use std::io::{self, BufRead, Write};
|
||||||
use tokenizers::models::bpe::{Error, BPE};
|
use tokenizers::models::bpe::BPE;
|
||||||
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
|
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
|
||||||
use tokenizers::tokenizer::{EncodeInput, Tokenizer};
|
use tokenizers::tokenizer::{EncodeInput, Result, Tokenizer};
|
||||||
|
|
||||||
fn shell(matches: &ArgMatches) -> Result<(), Error> {
|
fn shell(matches: &ArgMatches) -> Result<()> {
|
||||||
let vocab = matches
|
let vocab = matches
|
||||||
.value_of("vocab")
|
.value_of("vocab")
|
||||||
.expect("Must give a vocab.json file");
|
.expect("Must give a vocab.json file");
|
||||||
@ -33,7 +33,7 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
|
|||||||
let buffer = buffer.trim_end();
|
let buffer = buffer.trim_end();
|
||||||
|
|
||||||
let timer = std::time::Instant::now();
|
let timer = std::time::Instant::now();
|
||||||
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()));
|
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()))?;
|
||||||
let elapsed = timer.elapsed();
|
let elapsed = timer.elapsed();
|
||||||
println!("\nInput:\t\t{}", buffer);
|
println!("\nInput:\t\t{}", buffer);
|
||||||
println!("Tokens:\t\t{:?}", encoded.get_tokens());
|
println!("Tokens:\t\t{:?}", encoded.get_tokens());
|
||||||
@ -43,7 +43,7 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Error> {
|
fn main() -> Result<()> {
|
||||||
let matches = App::new("tokenizers")
|
let matches = App::new("tokenizers")
|
||||||
.version("0.0.1")
|
.version("0.0.1")
|
||||||
.author("Anthony M. <anthony@huggingface.co>")
|
.author("Anthony M. <anthony@huggingface.co>")
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
use crate::tokenizer::PreTokenizer;
|
use crate::tokenizer::{PreTokenizer, Result};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
|
||||||
pub struct Whitespace;
|
pub struct Whitespace;
|
||||||
impl PreTokenizer for Whitespace {
|
impl PreTokenizer for Whitespace {
|
||||||
fn pre_tokenize(&self, s: &str) -> Vec<String> {
|
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref RE: Regex = Regex::new(r"\w+|[^\w\s]+").unwrap();
|
static ref RE: Regex = Regex::new(r"\w+|[^\w\s]+").unwrap();
|
||||||
}
|
}
|
||||||
RE.captures_iter(s)
|
Ok(RE
|
||||||
|
.captures_iter(s)
|
||||||
.map(|captures| {
|
.map(|captures| {
|
||||||
captures
|
captures
|
||||||
.iter()
|
.iter()
|
||||||
@ -17,7 +18,7 @@ impl PreTokenizer for Whitespace {
|
|||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.collect()
|
.collect())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ mod tests {
|
|||||||
];
|
];
|
||||||
let pretok = Whitespace;
|
let pretok = Whitespace;
|
||||||
for (s, res) in tests {
|
for (s, res) in tests {
|
||||||
assert_eq!(pretok.pre_tokenize(s), res);
|
assert_eq!(pretok.pre_tokenize(s).unwrap(), res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user