mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-30 20:19:21 +00:00
Rust - encode & encode_batch with add_special_tokens
This commit is contained in:
@ -7,6 +7,8 @@ a high number of files as it avoids having too many progress bars on screen.
|
||||
- `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
|
||||
avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
|
||||
whitespaces are part of the actual token.
|
||||
- `encode` and `encode_batch` now take a new argument, specifying whether we should add the
|
||||
special tokens.
|
||||
|
||||
## Fixes:
|
||||
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
|
||||
|
@ -45,7 +45,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
|
||||
let buffer = buffer.trim_end();
|
||||
|
||||
let timer = std::time::Instant::now();
|
||||
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()))?;
|
||||
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()), false)?;
|
||||
let elapsed = timer.elapsed();
|
||||
println!("\nInput:\t\t{}", buffer);
|
||||
println!("Tokens:\t\t{:?}", encoded.get_tokens());
|
||||
|
@ -37,7 +37,7 @@
|
||||
//!
|
||||
//! let mut tokenizer = Tokenizer::new(Box::new(bpe));
|
||||
//!
|
||||
//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()))?;
|
||||
//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()), false)?;
|
||||
//! println!("{:?}", encoding.get_tokens());
|
||||
//!
|
||||
//! Ok(())
|
||||
|
@ -193,7 +193,12 @@ impl PostProcessor for ByteLevel {
|
||||
0
|
||||
}
|
||||
|
||||
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
fn process(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
_add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
let process_offsets = |encoding: &mut Encoding| {
|
||||
if !self.trim_offsets {
|
||||
return;
|
||||
@ -423,13 +428,18 @@ mod tests {
|
||||
);
|
||||
|
||||
let bytelevel = ByteLevel::default().trim_offsets(true);
|
||||
assert_eq!(expected, bytelevel.process(start.clone(), None).unwrap());
|
||||
assert_eq!(
|
||||
expected,
|
||||
bytelevel.process(start.clone(), None, false).unwrap()
|
||||
);
|
||||
|
||||
let mut pair_expected = expected.clone();
|
||||
pair_expected.merge_with(expected);
|
||||
assert_eq!(
|
||||
pair_expected,
|
||||
bytelevel.process(start.clone(), Some(start)).unwrap()
|
||||
bytelevel
|
||||
.process(start.clone(), Some(start), false)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,16 @@ impl PostProcessor for BertProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
fn process(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
if !add_special_tokens {
|
||||
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
|
||||
}
|
||||
|
||||
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
|
||||
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
|
||||
let tokens = [
|
||||
|
@ -20,7 +20,16 @@ impl PostProcessor for RobertaProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
fn process(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
if !add_special_tokens {
|
||||
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
|
||||
}
|
||||
|
||||
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
|
||||
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
|
||||
let tokens = [
|
||||
|
@ -55,7 +55,27 @@ pub trait PostProcessor {
|
||||
/// Returns the number of tokens that will be added during the processing step
|
||||
fn added_tokens(&self, is_pair: bool) -> usize;
|
||||
/// Process both encodings and returns a new merged one
|
||||
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
|
||||
fn process(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding>;
|
||||
}
|
||||
impl dyn PostProcessor {
|
||||
pub fn default_process(
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
_add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
match pair_encoding {
|
||||
None => Ok(encoding),
|
||||
Some(pair) => {
|
||||
encoding.merge_with(pair);
|
||||
Ok(encoding)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
|
||||
@ -277,7 +297,7 @@ impl Tokenizer {
|
||||
}
|
||||
|
||||
/// Encode the given sentence
|
||||
pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
|
||||
pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
|
||||
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
|
||||
// First we need to split into as many sequences as needed to avoid splitting
|
||||
// on our added tokens
|
||||
@ -362,14 +382,18 @@ impl Tokenizer {
|
||||
};
|
||||
|
||||
// 4. Post processing
|
||||
self.post_process(encoding, pair_encoding)
|
||||
self.post_process(encoding, pair_encoding, add_special_tokens)
|
||||
}
|
||||
|
||||
/// Encode all the sentences in parallel, using multiple threads
|
||||
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
|
||||
pub fn encode_batch(
|
||||
&self,
|
||||
inputs: Vec<EncodeInput>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Vec<Encoding>> {
|
||||
let encodings = inputs
|
||||
.into_par_iter()
|
||||
.map(|input| self.encode(input))
|
||||
.map(|input| self.encode(input, add_special_tokens))
|
||||
.collect::<Result<Vec<Encoding>>>()?;
|
||||
|
||||
if let Some(params) = &self.padding {
|
||||
@ -513,9 +537,10 @@ impl Tokenizer {
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
// 1. First we truncate if needed
|
||||
let (mut encoding, pair_encoding) = {
|
||||
let (encoding, pair_encoding) = {
|
||||
if let Some(trunc) = &self.trunc {
|
||||
let n_added_tokens = if let Some(processor) = &self.post_processor {
|
||||
processor.added_tokens(pair_encoding.is_some())
|
||||
@ -523,7 +548,7 @@ impl Tokenizer {
|
||||
0
|
||||
};
|
||||
|
||||
if n_added_tokens > 0 {
|
||||
if add_special_tokens && n_added_tokens > 0 {
|
||||
let params = TruncationParams {
|
||||
max_length: trunc.max_length - n_added_tokens,
|
||||
..*trunc
|
||||
@ -539,15 +564,9 @@ impl Tokenizer {
|
||||
|
||||
// 2. Then We post process
|
||||
let mut final_encoding = if let Some(processor) = &self.post_processor {
|
||||
processor.process(encoding, pair_encoding)?
|
||||
processor.process(encoding, pair_encoding, add_special_tokens)?
|
||||
} else {
|
||||
match pair_encoding {
|
||||
None => encoding,
|
||||
Some(pair) => {
|
||||
encoding.merge_with(pair);
|
||||
encoding
|
||||
}
|
||||
}
|
||||
PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
|
||||
};
|
||||
|
||||
// 3. Then we pad if needed
|
||||
|
Reference in New Issue
Block a user