Decode stream python (#1678)

* Python binding for decode stream

The API differs from the Rust DecodeStream because Python cannot handle lifetimes properly: instead of the stream borrowing the tokenizer, the tokenizer is passed explicitly to every step() call.

* Clippy.
Nicolas Patry
2024-11-15 19:06:22 +08:00
committed by GitHub
parent 500db282a8
commit cc5fb01a2f
6 changed files with 148 additions and 18 deletions
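
In practice the new API looks like this. A minimal sketch based on the test added in this commit; the toy tokenizer (BPE with a handful of added tokens) is only for illustration:

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from tokenizers.models import BPE

# Toy tokenizer: the added tokens get ids 0..4 (illustration only).
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

# The stream keeps no reference to the tokenizer; it is passed on every call.
stream = DecodeStream(skip_special_tokens=False)
for token_id in [0, 1, 2, 3]:
    chunk = stream.step(tokenizer, token_id)  # None until a printable chunk is ready
    if chunk is not None:
        print(chunk, end="")  # prints "my name is john" overall
print()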


@@ -12,3 +12,4 @@ Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
 Sequence = decoders.Sequence
+DecodeStream = decoders.DecodeStream


@@ -1,4 +1,12 @@
 # Generated content DO NOT EDIT
+class DecodeStream:
+    """
+    Class needed for streaming decode
+    """
+
+    def __init__(self, skip_special_tokens):
+        pass
+
 class Decoder:
     """
     Base class for all decoders


@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};

 use crate::pre_tokenizers::from_string;
+use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
     m.add_class::<PySequenceDecoder>()?;
+    m.add_class::<PyDecodeStream>()?;
     Ok(())
 }
+
+/// Class needed for streaming decode
+///
+#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
+#[derive(Clone)]
+pub struct PyDecodeStream {
+    /// Regular decode option that is kept throughout.
+    skip_special_tokens: bool,
+    /// A temporary buffer of the necessary token_ids needed
+    /// to produce valid string chunks.
+    /// This typically contains 3 parts:
+    ///  - read
+    ///  - prefix
+    ///  - rest
+    ///
+    /// Read is the bit necessary to surround the prefix
+    /// so that decoding the whole ids produces a valid prefix.
+    /// Prefix is the previously produced string, kept around to trim off of
+    /// the next valid chunk.
+    ids: Vec<u32>,
+    /// The previously returned chunk that needs to be discarded from the
+    /// decoding of the current ids to produce the next chunk.
+    prefix: String,
+    /// The index within the ids corresponding to the prefix, so we can drain
+    /// correctly.
+    prefix_index: usize,
+    /// We need to keep 2 prefixes: prefix is the one that was already emitted,
+    /// used to discard its part of the text decoded from all the ids;
+    /// read is the prefix kept only for the starting side effects of the prefix.
+    read_index: usize,
+}
+
+#[pymethods]
+impl PyDecodeStream {
+    #[new]
+    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
+    fn new(skip_special_tokens: bool) -> Self {
+        PyDecodeStream {
+            skip_special_tokens,
+            ids: vec![],
+            prefix: "".to_string(),
+            prefix_index: 0,
+            read_index: 0,
+        }
+    }
+
+    #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
+    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
+        ToPyResult(tk::tokenizer::step_decode_stream(
+            &tokenizer.tokenizer,
+            id,
+            self.skip_special_tokens,
+            &mut self.ids,
+            &mut self.prefix,
+            &mut self.prefix_index,
+            &mut self.read_index,
+        ))
+        .into()
+    }
+}
 #[cfg(test)]
 mod test {
     use std::sync::{Arc, RwLock};
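
The buffering scheme described by the field comments above, and implemented by step_decode_stream further down, boils down to roughly the following Python paraphrase (a sketch for orientation, not part of the diff; the state dict mirrors the ids/prefix/prefix_index/read_index fields):

def step_decode_stream(tokenizer, token_id, skip_special_tokens, state):
    # state: {"ids": list[int], "prefix": str, "prefix_index": int, "read_index": int}
    state["ids"].append(token_id)
    string = tokenizer.decode(state["ids"], skip_special_tokens=skip_special_tokens)
    # Emit only once the decoded text grows past the previous prefix and does not
    # end in U+FFFD, i.e. the accumulated ids no longer form an incomplete UTF-8 sequence.
    if len(string) > len(state["prefix"]) and not string.endswith("\ufffd"):
        if not string.startswith(state["prefix"]):
            raise ValueError("invalid prefix")
        new_text = string[len(state["prefix"]):]           # the chunk to return
        new_prefix_index = len(state["ids"]) - state["prefix_index"]
        state["ids"] = state["ids"][state["read_index"]:]  # drop ids that are no longer needed
        state["prefix"] = tokenizer.decode(state["ids"], skip_special_tokens=skip_special_tokens)
        state["read_index"] = state["prefix_index"]
        state["prefix_index"] = new_prefix_index
        return new_text
    return None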


@@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
 #[derive(Clone, Serialize)]
 #[serde(transparent)]
 pub struct PyTokenizer {
-    tokenizer: Tokenizer,
+    pub(crate) tokenizer: Tokenizer,
 }

 impl PyTokenizer {

@@ -9,6 +9,7 @@ from tokenizers.models import BPE, Model, Unigram
 from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
 from tokenizers.normalizers import Strip, Lowercase, Sequence
+from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace

 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
@@ -365,6 +366,37 @@ class TestTokenizer:
         output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
         assert output == ["my name is john", "pair"]

+        # Can decode stream
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 0) == "my"
+        assert stream.step(tokenizer, 1) == " name"
+        assert stream.step(tokenizer, 2) == " is"
+        assert stream.step(tokenizer, 3) == " john"
+
+    def test_decode_stream(self):
+        vocab = [
+            ("<unk>", 0.0),
+            ("<0x20>", -0.1),
+            ("<0xC3>", -0.2),
+            ("<0xA9>", -0.3),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
+        tokenizer.decoder = ByteFallback()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == " "
+        assert stream.step(tokenizer, 2) == None
+        assert stream.step(tokenizer, 3) == "é"
+
+        vocab = [
+            ("<unk>", 0.0),
+            ("▁This", -0.1),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+        tokenizer.decoder = DecoderMetaspace()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == "This"
+        assert stream.step(tokenizer, 1) == " This"
+
     def test_get_vocab(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
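
The byte-fallback test above pins down the contract of step(): it returns None while the accumulated ids still decode to an incomplete UTF-8 sequence (here the lone 0xC3 byte) and only emits text once the character is complete. A typical consumption loop might look like the following sketch; stream_print is a hypothetical helper, and tokenizer / token_ids are whatever the caller already has:

from tokenizers.decoders import DecodeStream

def stream_print(tokenizer, token_ids):
    # Print text incrementally as token ids arrive (e.g. from a model's sampler).
    stream = DecodeStream(skip_special_tokens=True)
    for token_id in token_ids:
        chunk = stream.step(tokenizer, token_id)
        if chunk is not None:  # None: no complete chunk available yet
            print(chunk, end="", flush=True)
    print()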


@@ -1069,24 +1069,50 @@
     /// See [`DecodeStream`]
     pub fn step(&mut self, id: u32) -> Result<Option<String>> {
-        self.ids.push(id);
-        let string = self
-            .tokenizer
-            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
-        if string.len() > self.prefix.len() && !string.ends_with('�') {
-            if !(string.starts_with(&self.prefix)) {
-                return Err(Box::new(DecodeStreamError::InvalidPrefix));
-            }
-            let new_text = &string[self.prefix.len()..].to_string();
-            let new_prefix_index = self.ids.len() - self.prefix_index;
-            self.ids = self.ids.drain(self.read_index..).collect();
-            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
-            self.read_index = self.prefix_index;
-            self.prefix_index = new_prefix_index;
-            Ok(Some(new_text.to_string()))
-        } else {
-            Ok(None)
-        }
+        step_decode_stream(
+            self.tokenizer,
+            id,
+            self.skip_special_tokens,
+            &mut self.ids,
+            &mut self.prefix,
+            &mut self.prefix_index,
+            &mut self.read_index,
+        )
+    }
+}
+
+/// Internal function exposed only to bypass python limitations
+pub fn step_decode_stream<M, N, PT, PP, D>(
+    tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
+    id: u32,
+    skip_special_tokens: bool,
+    ids: &mut Vec<u32>,
+    prefix: &mut String,
+    prefix_index: &mut usize,
+    read_index: &mut usize,
+) -> Result<Option<String>>
+where
+    M: Model,
+    N: Normalizer,
+    PT: PreTokenizer,
+    PP: PostProcessor,
+    D: Decoder,
+{
+    ids.push(id);
+    let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
+    if string.len() > prefix.len() && !string.ends_with('�') {
+        if !(string.starts_with(&*prefix)) {
+            return Err(Box::new(DecodeStreamError::InvalidPrefix));
+        }
+        let new_text = &string[prefix.len()..].to_string();
+        let new_prefix_index = ids.len() - *prefix_index;
+        *ids = ids.drain(*read_index..).collect();
+        *prefix = tokenizer.decode(ids, skip_special_tokens)?;
+        *read_index = *prefix_index;
+        *prefix_index = new_prefix_index;
+        Ok(Some(new_text.to_string()))
+    } else {
+        Ok(None)
     }
 }