Rust - encode & encode_batch with add_special_tokens

This commit is contained in:
Anthony MOI
2020-03-10 16:10:07 -04:00
parent 523e173ddf
commit d761d406cf
7 changed files with 71 additions and 22 deletions

View File

@ -7,6 +7,8 @@ a high number of files as it avoids having too many progress bars on screen.
- `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
avoids the unintuitive inclusion of whitespace in the produced offsets, even when that
whitespace is part of the actual token.
- `encode` and `encode_batch` now take a new `add_special_tokens` argument, specifying whether
special tokens should be added.
## Fixes:
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:

View File

@ -45,7 +45,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
let buffer = buffer.trim_end();
let timer = std::time::Instant::now();
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()))?;
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()), false)?;
let elapsed = timer.elapsed();
println!("\nInput:\t\t{}", buffer);
println!("Tokens:\t\t{:?}", encoded.get_tokens());

View File

@ -37,7 +37,7 @@
//!
//! let mut tokenizer = Tokenizer::new(Box::new(bpe));
//!
//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()))?;
//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()), false)?;
//! println!("{:?}", encoding.get_tokens());
//!
//! Ok(())

View File

@ -193,7 +193,12 @@ impl PostProcessor for ByteLevel {
0
}
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
fn process(
&self,
mut encoding: Encoding,
pair_encoding: Option<Encoding>,
_add_special_tokens: bool,
) -> Result<Encoding> {
let process_offsets = |encoding: &mut Encoding| {
if !self.trim_offsets {
return;
@ -423,13 +428,18 @@ mod tests {
);
let bytelevel = ByteLevel::default().trim_offsets(true);
assert_eq!(expected, bytelevel.process(start.clone(), None).unwrap());
assert_eq!(
expected,
bytelevel.process(start.clone(), None, false).unwrap()
);
let mut pair_expected = expected.clone();
pair_expected.merge_with(expected);
assert_eq!(
pair_expected,
bytelevel.process(start.clone(), Some(start)).unwrap()
bytelevel
.process(start.clone(), Some(start), false)
.unwrap()
);
}
}

View File

@ -20,7 +20,16 @@ impl PostProcessor for BertProcessing {
}
}
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
fn process(
&self,
mut encoding: Encoding,
pair_encoding: Option<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
if !add_special_tokens {
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
}
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
let tokens = [

View File

@ -20,7 +20,16 @@ impl PostProcessor for RobertaProcessing {
}
}
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
fn process(
&self,
mut encoding: Encoding,
pair_encoding: Option<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
if !add_special_tokens {
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
}
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
let tokens = [

View File

@ -55,7 +55,27 @@ pub trait PostProcessor {
/// Returns the number of tokens that will be added during the processing step
fn added_tokens(&self, is_pair: bool) -> usize;
/// Process both encodings and returns a new merged one
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
fn process(
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding>;
}
impl dyn PostProcessor {
    /// Fallback processing that adds nothing to the encodings.
    ///
    /// When `pair_encoding` is `Some`, it is merged into `encoding` via
    /// `Encoding::merge_with`; otherwise `encoding` is returned untouched.
    /// The `_add_special_tokens` flag is accepted only so the signature lines
    /// up with `PostProcessor::process`; it is not consulted here.
    ///
    /// This function always returns `Ok`.
    pub fn default_process(
        mut encoding: Encoding,
        pair_encoding: Option<Encoding>,
        _add_special_tokens: bool,
    ) -> Result<Encoding> {
        // Fold the optional second sequence into the first one, if present.
        if let Some(pair) = pair_encoding {
            encoding.merge_with(pair);
        }
        Ok(encoding)
    }
}
/// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
@ -277,7 +297,7 @@ impl Tokenizer {
}
/// Encode the given sentence
pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
@ -362,14 +382,18 @@ impl Tokenizer {
};
// 4. Post processing
self.post_process(encoding, pair_encoding)
self.post_process(encoding, pair_encoding, add_special_tokens)
}
/// Encode all the sentences in parallel, using multiple threads
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
pub fn encode_batch(
&self,
inputs: Vec<EncodeInput>,
add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
let encodings = inputs
.into_par_iter()
.map(|input| self.encode(input))
.map(|input| self.encode(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
if let Some(params) = &self.padding {
@ -513,9 +537,10 @@ impl Tokenizer {
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
// 1. First we truncate if needed
let (mut encoding, pair_encoding) = {
let (encoding, pair_encoding) = {
if let Some(trunc) = &self.trunc {
let n_added_tokens = if let Some(processor) = &self.post_processor {
processor.added_tokens(pair_encoding.is_some())
@ -523,7 +548,7 @@ impl Tokenizer {
0
};
if n_added_tokens > 0 {
if add_special_tokens && n_added_tokens > 0 {
let params = TruncationParams {
max_length: trunc.max_length - n_added_tokens,
..*trunc
@ -539,15 +564,9 @@ impl Tokenizer {
// 2. Then We post process
let mut final_encoding = if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding)?
processor.process(encoding, pair_encoding, add_special_tokens)?
} else {
match pair_encoding {
None => encoding,
Some(pair) => {
encoding.merge_with(pair);
encoding
}
}
PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
};
// 3. Then we pad if needed