Mirror of https://github.com/mii443/tokenizers.git
Adding an API for decode streaming. (#1677)
* Adding an API for decode streaming.
* Add another missing test case (proving the effect of state).
* Elide lifetime.
* Elide bis.
* Fixing the streaming implementation.
* Adding more docs.
* End of list.
* Fix internal link.
* Skip doctest on Windows (no tokenizer file because no make).
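The commit's public surface is small: `TokenizerImpl::decode_stream(skip_special_tokens)` creates a stateful `DecodeStream`, and `DecodeStream::step(id)` feeds one token id at a time, returning a printable chunk once the accumulated ids can form one. A minimal usage sketch (not part of the diff; the tokenizer file and the ids are borrowed from the doctest and test added below):

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/roberta.json")?;
    let mut decode_stream = tokenizer.decode_stream(false);

    // Pretend these ids arrive one at a time from a generation loop.
    for &id in &[713u32, 16, 41, 1246] {
        // `step` returns Ok(None) while the accumulated ids cannot yet form
        // a valid chunk (e.g. incomplete utf-8 when using byte_fallback).
        if let Some(chunk) = decode_stream.step(id)? {
            print!("{chunk}");
        }
    }
    println!(); // "This is an example"
    Ok(())
}
```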
@@ -12,8 +12,7 @@
 use std::{
     collections::HashMap,
     fs::{read_to_string, File},
-    io::prelude::*,
-    io::BufReader,
+    io::{prelude::*, BufReader},
     ops::{Deref, DerefMut},
     path::{Path, PathBuf},
 };
@@ -906,6 +905,189 @@ where
             Ok(tokens.join(" "))
         }
     }
+
+    /// Decode the given ids, back to a String
+    /// See [`DecodeStream`]
+    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
+        DecodeStream::new(self, skip_special_tokens)
+    }
+}
+
+/// DecodeStream will keep the state necessary to produce individual chunks of
+/// strings given an input stream of token_ids.
+///
+/// This is necessary because decoding in general cannot achieve that since strings
+/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces
+///
+/// Example:
+///
+/// ```
+/// # #[cfg(not(target_os = "windows"))]
+/// # {
+/// use tokenizers::Tokenizer;
+/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
+///
+/// let mut decode_stream = tokenizer.decode_stream(false);
+/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
+/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
+/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
+/// assert_eq!(
+///     decode_stream.step(1246).unwrap(),
+///     Some(" example".to_string())
+/// );
+/// # }
+/// ```
+///
+/// Returning `None` means the given id is not enough to produce a chunk.
+/// This typically happens with `byte_fallback` options where some tokens do
+/// not represent valid utf-8, and only follow-up token_ids will help produce
+/// a valid chunk.
+/// ```
+/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC};
+/// use std::collections::HashMap;
+/// use std::iter::FromIterator;
+///
+/// let vocab = HashMap::from_iter([
+///     ("<0x20>".to_string(), 0),
+///     ("<0xC3>".to_string(), 1),
+///     ("<0xA9>".to_string(), 2),
+///     (" This".to_string(), 3),
+/// ]);
+/// let merges = vec![];
+/// let bpe = BPE::builder()
+///     .vocab_and_merges(vocab, merges)
+///     .byte_fallback(true)
+///     .build()
+///     .unwrap();
+/// let tokenizer = TokenizerBuilder::default()
+///     .with_model(bpe)
+///     .with_decoder(Some(ByteFallback::default()))
+///     .with_normalizer(Some(NFC))
+///     .with_pre_tokenizer(Some(ByteLevel::default()))
+///     .with_post_processor(Some(ByteLevel::default()))
+///     .build().unwrap();
+///
+/// let mut decode_stream = tokenizer.decode_stream(false);
+/// // Single byte_fallback is valid utf-8
+/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
+/// // Invalid utf-8
+/// assert_eq!(decode_stream.step(1).unwrap(), None);
+/// // Valid utf-8 again, this corresponds to both tokens: [1, 2]
+/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
+/// ```
+///
+/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would
+/// fail.
+///
+/// ```
+/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC};
+/// use std::collections::HashMap;
+/// use std::iter::FromIterator;
+///
+/// let vocab = HashMap::from_iter([
+///     ("▁This".to_string(), 0),
+/// ]);
+/// let merges = vec![];
+/// let bpe = BPE::builder()
+///     .vocab_and_merges(vocab, merges)
+///     .byte_fallback(true)
+///     .build()
+///     .unwrap();
+/// let tokenizer = TokenizerBuilder::new()
+///     .with_model(bpe)
+///     .with_decoder(Some(Metaspace::default()))
+///     .with_normalizer(Some(NFC))
+///     .with_pre_tokenizer(Some(ByteLevel::default()))
+///     .with_post_processor(Some(ByteLevel::default()))
+///     .build()
+///     .unwrap();
+///
+/// // Strip decoder removes the extra initial space
+/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This");
+/// // Decoding one token at a time would produce "ThisThis"
+/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This");
+///
+/// // Using a stream fixes it by keeping the necessary state.
+/// let mut decode_stream = tokenizer.decode_stream(false);
+/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
+/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
+/// ```
+pub struct DecodeStream<'tok, M, N, PT, PP, D> {
+    /// A reference to the tokenizer
+    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
+    /// Regular decode option that is kept throughout.
+    skip_special_tokens: bool,
+    /// A temporary buffer of the necessary token_ids needed
+    /// to produce valid string chunks.
+    /// This typically contains 3 parts:
+    /// - read
+    /// - prefix
+    /// - rest
+    ///
+    /// Read is the bit necessary to surround the prefix
+    /// so decoding the whole ids produces a valid prefix.
+    /// Prefix is the previously produced string, kept around to trim off of
+    /// the next valid chunk
+    ids: Vec<u32>,
+    /// The previously returned chunk that needs to be discarded from the
+    /// decoding of the current ids to produce the next chunk
+    prefix: String,
+    /// The index within the ids corresponding to the prefix so we can drain
+    /// correctly
+    prefix_index: usize,
+    /// We need to keep 2 prefixes.
+    /// Prefix is the second one that was already emitted to discard the part
+    /// of the text of all the ids
+    /// read is the prefix kept only for starting side effects of the prefix
+    read_index: usize,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum DecodeStreamError {
+    #[error("Invalid prefix encountered")]
+    InvalidPrefix,
+}
+
+impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
+where
+    M: Model,
+    N: Normalizer,
+    PT: PreTokenizer,
+    PP: PostProcessor,
+    D: Decoder,
+{
+    fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
+        Self {
+            tokenizer,
+            ids: vec![],
+            skip_special_tokens,
+            prefix: "".to_string(),
+            prefix_index: 0,
+            read_index: 0,
+        }
+    }
+
+    /// See [`DecodeStream`]
+    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
+        self.ids.push(id);
+        let string = self
+            .tokenizer
+            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
+        if string.len() > self.prefix.len() && !string.ends_with('�') {
+            if !(string.starts_with(&self.prefix)) {
+                return Err(Box::new(DecodeStreamError::InvalidPrefix));
+            }
+            let new_text = &string[self.prefix.len()..].to_string();
+            let new_prefix_index = self.ids.len() - self.prefix_index;
+            self.ids = self.ids.drain(self.read_index..).collect();
+            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
+            self.read_index = self.prefix_index;
+            self.prefix_index = new_prefix_index;
+            Ok(Some(new_text.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+}
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
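At its core, `step` re-decodes a window of recent ids and trims off the text that was already emitted, refusing to emit anything while the decoded text still ends in a replacement character. The sketch below is an illustration of that idea only, not the code from this commit: it relies solely on the existing `Tokenizer::decode`, keeps every id instead of draining old ones through `read_index`/`prefix_index`, and the `NaiveStream` name is made up.

```rust
use tokenizers::Tokenizer;

/// Illustration only: the "decode, then trim the already-emitted prefix" idea
/// behind `DecodeStream::step`, without the buffer-draining bookkeeping.
struct NaiveStream {
    ids: Vec<u32>,
    prefix: String,
}

impl NaiveStream {
    fn step(&mut self, tokenizer: &Tokenizer, id: u32) -> tokenizers::Result<Option<String>> {
        self.ids.push(id);
        let text = tokenizer.decode(&self.ids, false)?;
        // Only emit once the decoded text has grown and no longer ends in a
        // replacement character (incomplete utf-8 from byte_fallback tokens).
        if text.len() > self.prefix.len() && !text.ends_with('\u{FFFD}') {
            // The real implementation reports a mismatch as
            // DecodeStreamError::InvalidPrefix; the sketch just asserts.
            assert!(text.starts_with(&self.prefix));
            let chunk = text[self.prefix.len()..].to_string();
            self.prefix = text;
            Ok(Some(chunk))
        } else {
            Ok(None)
        }
    }
}
```

The committed version additionally drops ids older than `read_index` and re-decodes only that window, so each `step` stays proportional to the window size rather than to the whole generated sequence.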
@@ -1,3 +1,7 @@
+use std::collections::HashMap;
+use std::iter::FromIterator;
+
+use tokenizers::decoders::byte_fallback::ByteFallback;
 use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
 use tokenizers::normalizers::{Sequence, Strip, NFC};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
@@ -58,6 +62,65 @@ fn load_tokenizer() {
     assert_eq!(decoded, example);
 }
 
+#[test]
+fn streaming_tokenizer() {
+    let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
+
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
+    assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
+    assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
+    assert_eq!(
+        decode_stream.step(1246).unwrap(),
+        Some(" example".to_string())
+    );
+
+    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
+    let encoded = tokenizer.encode("This is an example", false).unwrap();
+    assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]);
+    let mut decode_stream = tokenizer.decode_stream(false);
+    // No space anymore
+    assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string()));
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string()));
+    assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string()));
+    assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string()));
+    assert_eq!(
+        decode_stream.step(823).unwrap(),
+        Some(" example".to_string())
+    );
+
+    // None example
+    let vocab = HashMap::from_iter([
+        ("<0x20>".to_string(), 0),
+        ("<0xC3>".to_string(), 1),
+        ("<0xA9>".to_string(), 2),
+        (" This".to_string(), 3),
+    ]);
+    let merges = vec![];
+    let bpe = BPE::builder()
+        .vocab_and_merges(vocab, merges)
+        .byte_fallback(true)
+        .build()
+        .unwrap();
+    let tokenizer = TokenizerBuilder::new()
+        .with_model(bpe)
+        .with_normalizer(Some(Sequence::new(vec![
+            Strip::new(true, true).into(),
+            NFC.into(),
+        ])))
+        .with_pre_tokenizer(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevel::default()))
+        .with_decoder(Some(ByteFallback::default()))
+        .build()
+        .unwrap();
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
+    assert_eq!(decode_stream.step(1).unwrap(), None);
+    assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
+    assert_eq!(decode_stream.step(2).unwrap(), None);
+}
+
 #[test]
 #[ignore]
 fn quicktour_slow_train() -> tokenizers::Result<()> {
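The albert case in the test above is worth spelling out: every call to `decode_stream` starts from empty state, so a stream created mid-sequence has no earlier ids to condition on and loses the separating space. A small restatement of that behaviour (not part of the diff; it reuses the ids and tokenizer file from the test):

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json")?;

    // Fresh stream, no earlier ids: id 25 ("is") comes out without a space.
    let mut fresh = tokenizer.decode_stream(false);
    assert_eq!(fresh.step(25)?, Some("is".to_string()));

    // Same id fed after id 48 ("this"): the kept state restores the space.
    let mut full = tokenizer.decode_stream(false);
    assert_eq!(full.step(48)?, Some("this".to_string()));
    assert_eq!(full.step(25)?, Some(" is".to_string()));
    Ok(())
}
```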