Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 08:15:49 +00:00)
Adding an API for decode streaming. (#1677)
* Adding an API for decode streaming.
* Add another missing test case (proving the effect of state).
* Elide lifetime.
* Elide bis.
* Fixing the streaming implementation.
* Adding more docs.
* End of list.
* Fix internal link.
* Skip doctest on Windows (no tokenizer file because no make).
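For context, a minimal sketch of how the new streaming API might be consumed during generation; the `print_stream` helper is illustrative and not part of this commit, `data/roberta.json` is the fixture used by the tests below, and the ids would normally come from a model's sampler.

```
// Minimal sketch: decode a stream of generated ids chunk by chunk.
use tokenizers::Tokenizer;

fn print_stream(ids: impl IntoIterator<Item = u32>) -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/roberta.json")?;
    let mut decode_stream = tokenizer.decode_stream(false);
    for id in ids {
        // `step` returns Ok(None) while the buffered ids cannot yet form a
        // valid chunk (e.g. an incomplete utf-8 sequence under byte_fallback).
        if let Some(chunk) = decode_stream.step(id)? {
            print!("{chunk}");
        }
    }
    Ok(())
}
```

Feeding the roberta ids `[713, 16, 41, 1246]` from the tests in this commit would print "This is an example".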
@@ -12,8 +12,7 @@
use std::{
    collections::HashMap,
    fs::{read_to_string, File},
-    io::prelude::*,
-    io::BufReader,
+    io::{prelude::*, BufReader},
    ops::{Deref, DerefMut},
    path::{Path, PathBuf},
};
@@ -906,6 +905,189 @@ where
        Ok(tokens.join(" "))
    }

    /// Decode the given ids, back to a String
    /// See [`DecodeStream`]
    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
        DecodeStream::new(self, skip_special_tokens)
    }
}

/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding cannot, in general, be done id by id: the string
/// produced for an id depends on its surrounding ids to be valid (typically for stripping
/// extra spaces).
///
/// Example:
///
/// ```
/// # #[cfg(not(target_os = "windows"))]
/// # {
/// use tokenizers::Tokenizer;
/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
/// assert_eq!(
///     decode_stream.step(1246).unwrap(),
///     Some(" example".to_string())
/// );
/// # }
/// ```
///
/// Returning `None` means the given id is not enough to produce a chunk.
/// This typically happens with `byte_fallback` options where some tokens do
/// not represent valid utf-8, and only follow-up token_ids will help produce
/// a valid chunk.
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
///     ("<0x20>".to_string(), 0),
///     ("<0xC3>".to_string(), 1),
///     ("<0xA9>".to_string(), 2),
///     (" This".to_string(), 3),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
///     .vocab_and_merges(vocab, merges)
///     .byte_fallback(true)
///     .build()
///     .unwrap();
/// let tokenizer = TokenizerBuilder::default()
///     .with_model(bpe)
///     .with_decoder(Some(ByteFallback::default()))
///     .with_normalizer(Some(NFC))
///     .with_pre_tokenizer(Some(ByteLevel::default()))
///     .with_post_processor(Some(ByteLevel::default()))
///     .build().unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// // Single byte_fallback is valid utf-8
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
/// // Invalid utf-8
/// assert_eq!(decode_stream.step(1).unwrap(), None);
/// // Valid utf-8 again, this corresponds to both tokens: [1, 2]
/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
/// ```
///
/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would
/// fail.
///
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
///     ("▁This".to_string(), 0),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
///     .vocab_and_merges(vocab, merges)
///     .byte_fallback(true)
///     .build()
///     .unwrap();
/// let tokenizer = TokenizerBuilder::new()
///     .with_model(bpe)
///     .with_decoder(Some(Metaspace::default()))
///     .with_normalizer(Some(NFC))
///     .with_pre_tokenizer(Some(ByteLevel::default()))
///     .with_post_processor(Some(ByteLevel::default()))
///     .build()
///     .unwrap();
///
/// // The Metaspace decoder removes the extra initial space
/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This");
/// // Decoding one token at a time would produce "ThisThis"
/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This");
///
/// // Using a stream fixes it by keeping the necessary state.
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
/// ```
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
    /// A reference to the tokenizer
    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
    /// Regular decode option that is kept throughout.
    skip_special_tokens: bool,
    /// A temporary buffer of the necessary token_ids needed
    /// to produce valid string chunks.
    /// This typically contains 3 parts:
    /// - read
    /// - prefix
    /// - rest
    ///
    /// Read is the part necessary to surround the prefix
    /// so that decoding the whole ids produces a valid prefix.
    /// Prefix is the previously produced string, kept around to trim it off of
    /// the next valid chunk.
    ids: Vec<u32>,
    /// The previously returned chunk that needs to be discarded from the
    /// decoding of the current ids to produce the next chunk
    prefix: String,
    /// The index within the ids corresponding to the prefix, so we can drain
    /// correctly
    prefix_index: usize,
    /// We need to keep track of 2 prefixes.
    /// `prefix` is the most recently emitted one, whose text is discarded from the
    /// decoding of all the buffered ids.
    /// `read` marks the older ids kept only for the side effects they have on the
    /// start of the prefix (e.g. space stripping).
    read_index: usize,
}

#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
    #[error("Invalid prefix encountered")]
    InvalidPrefix,
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
        Self {
            tokenizer,
            ids: vec![],
            skip_special_tokens,
            prefix: "".to_string(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    /// See [`DecodeStream`]
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.ids.push(id);
        let string = self
            .tokenizer
            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
        if string.len() > self.prefix.len() && !string.ends_with('�') {
            if !(string.starts_with(&self.prefix)) {
                return Err(Box::new(DecodeStreamError::InvalidPrefix));
            }
            let new_text = &string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            self.ids = self.ids.drain(self.read_index..).collect();
            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Ok(Some(new_text.to_string()))
        } else {
            Ok(None)
        }
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
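To make the `prefix`/`read_index` bookkeeping concrete, here is an illustrative trace (not part of the diff) of the first two `step` calls for the Metaspace example in the doc comment; the decoded strings are the ones asserted there, and the state values follow from the code above.

```
// Illustrative trace of DecodeStream state, using the single-token "▁This"
// vocabulary where decode([0]) == "This" and decode([0, 0]) == "This This".
//
// after new():  ids = []       prefix = ""            read_index = 0  prefix_index = 0
//
// step(0):
//   decode([0]) == "This"          -> emits "This"  (empty prefix stripped)
//   ids = [0]      prefix = "This"       read_index = 0  prefix_index = 1
//
// step(0):
//   decode([0, 0]) == "This This"  -> emits " This" (prefix "This" stripped)
//   ids = [0, 0]   prefix = "This This"  read_index = 1  prefix_index = 1
```

Keeping the earlier id buffered is what preserves the leading space: decoding the new id on its own would strip it and yield "This" again.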
@@ -1,3 +1,7 @@
use std::collections::HashMap;
use std::iter::FromIterator;

use tokenizers::decoders::byte_fallback::ByteFallback;
use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
use tokenizers::normalizers::{Sequence, Strip, NFC};
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
@@ -58,6 +62,65 @@ fn load_tokenizer() {
    assert_eq!(decoded, example);
}

#[test]
fn streaming_tokenizer() {
    let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();

    let mut decode_stream = tokenizer.decode_stream(false);
    assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
    assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
    assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
    assert_eq!(
        decode_stream.step(1246).unwrap(),
        Some(" example".to_string())
    );

    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
    let encoded = tokenizer.encode("This is an example", false).unwrap();
    assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]);
    let mut decode_stream = tokenizer.decode_stream(false);
    // No space anymore
    assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string()));
    let mut decode_stream = tokenizer.decode_stream(false);
    assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string()));
    assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string()));
    assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string()));
    assert_eq!(
        decode_stream.step(823).unwrap(),
        Some(" example".to_string())
    );

    // None example
    let vocab = HashMap::from_iter([
        ("<0x20>".to_string(), 0),
        ("<0xC3>".to_string(), 1),
        ("<0xA9>".to_string(), 2),
        (" This".to_string(), 3),
    ]);
    let merges = vec![];
    let bpe = BPE::builder()
        .vocab_and_merges(vocab, merges)
        .byte_fallback(true)
        .build()
        .unwrap();
    let tokenizer = TokenizerBuilder::new()
        .with_model(bpe)
        .with_normalizer(Some(Sequence::new(vec![
            Strip::new(true, true).into(),
            NFC.into(),
        ])))
        .with_pre_tokenizer(Some(ByteLevel::default()))
        .with_post_processor(Some(ByteLevel::default()))
        .with_decoder(Some(ByteFallback::default()))
        .build()
        .unwrap();
    let mut decode_stream = tokenizer.decode_stream(false);
    assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
    assert_eq!(decode_stream.step(1).unwrap(), None);
    assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
    assert_eq!(decode_stream.step(2).unwrap(), None);
}
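A possible follow-up assertion, sketched here and not part of the commit: for a fully valid id sequence, the concatenation of the streamed chunks should reproduce a one-shot `decode` over the same ids; `assert_stream_matches_decode` is an illustrative helper name.

```
// Sketch: streamed chunks, concatenated, should match a single decode() call
// for a sequence that never leaves the stream in a pending (None) state.
fn assert_stream_matches_decode(tokenizer: &Tokenizer, ids: &[u32]) {
    let mut decode_stream = tokenizer.decode_stream(false);
    let mut streamed = String::new();
    for &id in ids {
        if let Some(chunk) = decode_stream.step(id).unwrap() {
            streamed.push_str(&chunk);
        }
    }
    assert_eq!(streamed, tokenizer.decode(ids, false).unwrap());
}
```

For example, calling it with the roberta tokenizer above and `&[713, 16, 41, 1246]` should pass, both sides producing "This is an example".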

#[test]
#[ignore]
fn quicktour_slow_train() -> tokenizers::Result<()> {