mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Decode stream python (#1678)
* Python binding for decode stream Different API because Python cannot handle lifetimes properly. * Clippy.
This commit is contained in:
@ -12,3 +12,4 @@ Metaspace = decoders.Metaspace
|
||||
BPEDecoder = decoders.BPEDecoder
|
||||
CTC = decoders.CTC
|
||||
Sequence = decoders.Sequence
|
||||
DecodeStream = decoders.DecodeStream
|
||||
|
@ -1,4 +1,12 @@
|
||||
# Generated content DO NOT EDIT
|
||||
class DecodeStream:
|
||||
"""
|
||||
Class needed for streaming decode
|
||||
|
||||
"""
|
||||
def __init__(self, skip_special_tokens):
|
||||
pass
|
||||
|
||||
class Decoder:
|
||||
"""
|
||||
Base class for all decoders
|
||||
|
@ -1,6 +1,7 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use crate::pre_tokenizers::from_string;
|
||||
use crate::tokenizer::PyTokenizer;
|
||||
use crate::utils::PyPattern;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyBPEDecoder>()?;
|
||||
m.add_class::<PyCTCDecoder>()?;
|
||||
m.add_class::<PySequenceDecoder>()?;
|
||||
m.add_class::<PyDecodeStream>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Class needed for streaming decode
|
||||
///
|
||||
#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
|
||||
#[derive(Clone)]
|
||||
pub struct PyDecodeStream {
|
||||
/// Regular decode option that is kept throughout.
|
||||
skip_special_tokens: bool,
|
||||
/// A temporary buffer of the necessary token_ids needed
|
||||
/// to produce valid string chunks.
|
||||
/// This typically contains 3 parts:
|
||||
/// - read
|
||||
/// - prefix
|
||||
/// - rest
|
||||
///
|
||||
/// Read is the bit necessary to surround the prefix
|
||||
/// so decoding the whole ids produces a valid prefix.
|
||||
/// Prefix is the previously produced string, kept around to trim off of
|
||||
/// the next valid chunk
|
||||
ids: Vec<u32>,
|
||||
/// The previously returned chunk that needs to be discarded from the
|
||||
/// decoding of the current ids to produce the next chunk
|
||||
prefix: String,
|
||||
/// The index within the ids corresponding to the prefix so we can drain
|
||||
/// correctly
|
||||
prefix_index: usize,
|
||||
/// We need to keep 2 prefixes.
|
||||
/// Prefix is the second one that was already emitted to discard the part
|
||||
/// of the text of all the ids
|
||||
/// read is the prefix kept only for starting side effects of the prefix
|
||||
read_index: usize,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyDecodeStream {
|
||||
#[new]
|
||||
#[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
|
||||
fn new(skip_special_tokens: bool) -> Self {
|
||||
PyDecodeStream {
|
||||
skip_special_tokens,
|
||||
ids: vec![],
|
||||
prefix: "".to_string(),
|
||||
prefix_index: 0,
|
||||
read_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
|
||||
fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
|
||||
ToPyResult(tk::tokenizer::step_decode_stream(
|
||||
&tokenizer.tokenizer,
|
||||
id,
|
||||
self.skip_special_tokens,
|
||||
&mut self.ids,
|
||||
&mut self.prefix,
|
||||
&mut self.prefix_index,
|
||||
&mut self.read_index,
|
||||
))
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
|
||||
#[derive(Clone, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyTokenizer {
|
||||
tokenizer: Tokenizer,
|
||||
pub(crate) tokenizer: Tokenizer,
|
||||
}
|
||||
|
||||
impl PyTokenizer {
|
||||
|
@ -9,6 +9,7 @@ from tokenizers.models import BPE, Model, Unigram
|
||||
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
|
||||
from tokenizers.processors import RobertaProcessing, TemplateProcessing
|
||||
from tokenizers.normalizers import Strip, Lowercase, Sequence
|
||||
from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace
|
||||
|
||||
|
||||
from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
|
||||
@ -365,6 +366,37 @@ class TestTokenizer:
|
||||
output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
|
||||
assert output == ["my name is john", "pair"]
|
||||
|
||||
# Can decode stream
|
||||
stream = DecodeStream(skip_special_tokens=False)
|
||||
assert stream.step(tokenizer, 0) == "my"
|
||||
assert stream.step(tokenizer, 1) == " name"
|
||||
assert stream.step(tokenizer, 2) == " is"
|
||||
assert stream.step(tokenizer, 3) == " john"
|
||||
|
||||
def test_decode_stream(self):
|
||||
vocab = [
|
||||
("<unk>", 0.0),
|
||||
("<0x20>", -0.1),
|
||||
("<0xC3>", -0.2),
|
||||
("<0xA9>", -0.3),
|
||||
]
|
||||
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
|
||||
tokenizer.decoder = ByteFallback()
|
||||
stream = DecodeStream(skip_special_tokens=False)
|
||||
assert stream.step(tokenizer, 1) == " "
|
||||
assert stream.step(tokenizer, 2) == None
|
||||
assert stream.step(tokenizer, 3) == "é"
|
||||
|
||||
vocab = [
|
||||
("<unk>", 0.0),
|
||||
("▁This", -0.1),
|
||||
]
|
||||
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
|
||||
tokenizer.decoder = DecoderMetaspace()
|
||||
stream = DecodeStream(skip_special_tokens=False)
|
||||
assert stream.step(tokenizer, 1) == "This"
|
||||
assert stream.step(tokenizer, 1) == " This"
|
||||
|
||||
def test_get_vocab(self):
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
|
||||
|
@ -1069,24 +1069,50 @@ where
|
||||
|
||||
/// See [`DecodeStream`]
|
||||
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
|
||||
self.ids.push(id);
|
||||
let string = self
|
||||
.tokenizer
|
||||
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
|
||||
if string.len() > self.prefix.len() && !string.ends_with('<27>') {
|
||||
if !(string.starts_with(&self.prefix)) {
|
||||
return Err(Box::new(DecodeStreamError::InvalidPrefix));
|
||||
}
|
||||
let new_text = &string[self.prefix.len()..].to_string();
|
||||
let new_prefix_index = self.ids.len() - self.prefix_index;
|
||||
self.ids = self.ids.drain(self.read_index..).collect();
|
||||
self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
|
||||
self.read_index = self.prefix_index;
|
||||
self.prefix_index = new_prefix_index;
|
||||
Ok(Some(new_text.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
step_decode_stream(
|
||||
self.tokenizer,
|
||||
id,
|
||||
self.skip_special_tokens,
|
||||
&mut self.ids,
|
||||
&mut self.prefix,
|
||||
&mut self.prefix_index,
|
||||
&mut self.read_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal function exposed only to bypass python limitations
|
||||
pub fn step_decode_stream<M, N, PT, PP, D>(
|
||||
tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
|
||||
id: u32,
|
||||
skip_special_tokens: bool,
|
||||
ids: &mut Vec<u32>,
|
||||
prefix: &mut String,
|
||||
prefix_index: &mut usize,
|
||||
read_index: &mut usize,
|
||||
) -> Result<Option<String>>
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
PP: PostProcessor,
|
||||
D: Decoder,
|
||||
{
|
||||
ids.push(id);
|
||||
let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
|
||||
if string.len() > prefix.len() && !string.ends_with('<27>') {
|
||||
if !(string.starts_with(&*prefix)) {
|
||||
return Err(Box::new(DecodeStreamError::InvalidPrefix));
|
||||
}
|
||||
let new_text = &string[prefix.len()..].to_string();
|
||||
let new_prefix_index = ids.len() - *prefix_index;
|
||||
*ids = ids.drain(*read_index..).collect();
|
||||
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
|
||||
*read_index = *prefix_index;
|
||||
*prefix_index = new_prefix_index;
|
||||
Ok(Some(new_text.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user