Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Decode stream python (#1678)
* Python binding for decode stream. The API differs from the Rust one because Python cannot handle lifetimes properly.
* Clippy fixes.
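The new API surface, taken from the tests in this commit, is DecodeStream(skip_special_tokens=...) plus stream.step(tokenizer, id), which returns either the next printable chunk of text or None. A minimal usage sketch (the checkpoint name and token ids below are placeholders, not part of this commit):

    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    tokenizer = Tokenizer.from_pretrained("gpt2")  # placeholder checkpoint; any Tokenizer works
    stream = DecodeStream(skip_special_tokens=True)

    # Feed ids one at a time, e.g. straight out of a generation loop.
    for token_id in [1820, 1438, 318]:  # placeholder ids
        chunk = stream.step(tokenizer, token_id)
        if chunk is not None:  # None means no printable chunk can be emitted yet
            print(chunk, end="", flush=True)

Unlike the Rust-side DecodeStream, which borrows the tokenizer for its whole lifetime, the Python object keeps only plain state and takes the tokenizer as an argument on every step call; that is the API difference the message above refers to.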
@@ -12,3 +12,4 @@ Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
 Sequence = decoders.Sequence
+DecodeStream = decoders.DecodeStream
@@ -1,4 +1,12 @@
 # Generated content DO NOT EDIT
+class DecodeStream:
+    """
+    Class needed for streaming decode
+
+    """
+    def __init__(self, skip_special_tokens):
+        pass
+
 class Decoder:
     """
     Base class for all decoders
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use crate::pre_tokenizers::from_string;
+use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
     m.add_class::<PySequenceDecoder>()?;
+    m.add_class::<PyDecodeStream>()?;
     Ok(())
 }
 
+/// Class needed for streaming decode
+///
+#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
+#[derive(Clone)]
+pub struct PyDecodeStream {
+    /// Regular decode option that is kept throughout.
+    skip_special_tokens: bool,
+    /// A temporary buffer of the necessary token_ids needed
+    /// to produce valid string chunks.
+    /// This typically contains 3 parts:
+    /// - read
+    /// - prefix
+    /// - rest
+    ///
+    /// Read is the bit necessary to surround the prefix
+    /// so decoding the whole ids produces a valid prefix.
+    /// Prefix is the previously produced string, kept around to trim off of
+    /// the next valid chunk
+    ids: Vec<u32>,
+    /// The previously returned chunk that needs to be discarded from the
+    /// decoding of the current ids to produce the next chunk
+    prefix: String,
+    /// The index within the ids corresponding to the prefix so we can drain
+    /// correctly
+    prefix_index: usize,
+    /// We need to keep 2 prefixes.
+    /// Prefix is the second one that was already emitted to discard the part
+    /// of the text of all the ids
+    /// read is the prefix kept only for starting side effects of the prefix
+    read_index: usize,
+}
+
+#[pymethods]
+impl PyDecodeStream {
+    #[new]
+    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
+    fn new(skip_special_tokens: bool) -> Self {
+        PyDecodeStream {
+            skip_special_tokens,
+            ids: vec![],
+            prefix: "".to_string(),
+            prefix_index: 0,
+            read_index: 0,
+        }
+    }
+
+    #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
+    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
+        ToPyResult(tk::tokenizer::step_decode_stream(
+            &tokenizer.tokenizer,
+            id,
+            self.skip_special_tokens,
+            &mut self.ids,
+            &mut self.prefix,
+            &mut self.prefix_index,
+            &mut self.read_index,
+        ))
+        .into()
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::{Arc, RwLock};
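The prefix bookkeeping documented on the struct fields above boils down to "decode everything we have kept, then cut off what was already returned". With the word-level tokenizer used in the test file further down (ids 0..3 decode to "my", "name", "is", "john"), one step can be spelled out with plain decode() calls; a hedged sketch, assuming that tokenizer is in scope:

    full = tokenizer.decode([0, 1])     # "my name" - decode of all ids kept so far
    prefix = tokenizer.decode([0])      # "my"      - text already emitted by the previous step
    chunk = full[len(prefix):]          # " name"   - exactly what stream.step(tokenizer, 1) returns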
@@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
 #[derive(Clone, Serialize)]
 #[serde(transparent)]
 pub struct PyTokenizer {
-    tokenizer: Tokenizer,
+    pub(crate) tokenizer: Tokenizer,
 }
 
 impl PyTokenizer {
@@ -9,6 +9,7 @@ from tokenizers.models import BPE, Model, Unigram
 from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
 from tokenizers.normalizers import Strip, Lowercase, Sequence
+from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace
 
 
 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
@@ -365,6 +366,37 @@ class TestTokenizer:
         output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
         assert output == ["my name is john", "pair"]
 
+        # Can decode stream
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 0) == "my"
+        assert stream.step(tokenizer, 1) == " name"
+        assert stream.step(tokenizer, 2) == " is"
+        assert stream.step(tokenizer, 3) == " john"
+
+    def test_decode_stream(self):
+        vocab = [
+            ("<unk>", 0.0),
+            ("<0x20>", -0.1),
+            ("<0xC3>", -0.2),
+            ("<0xA9>", -0.3),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
+        tokenizer.decoder = ByteFallback()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == " "
+        assert stream.step(tokenizer, 2) == None
+        assert stream.step(tokenizer, 3) == "é"
+
+        vocab = [
+            ("<unk>", 0.0),
+            ("▁This", -0.1),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+        tokenizer.decoder = DecoderMetaspace()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == "This"
+        assert stream.step(tokenizer, 1) == " This"
+
     def test_get_vocab(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
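The None returned for token 2 in test_decode_stream is the interesting case: that token maps to the single byte 0xC3, which is only the first half of a two-byte UTF-8 sequence, so nothing printable can be emitted until token 3 (0xA9) completes "é". Internally the Rust code detects this by refusing to emit any decode that ends in the U+FFFD replacement character. A standalone check of the underlying encoding fact:

    assert bytes([0xC3, 0xA9]).decode("utf-8") == "é"                   # complete two-byte sequence
    assert bytes([0xC3]).decode("utf-8", errors="replace") == "\ufffd"  # lone lead byte decodes to U+FFFD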
@@ -1069,25 +1069,51 @@ where
 
     /// See [`DecodeStream`]
     pub fn step(&mut self, id: u32) -> Result<Option<String>> {
-        self.ids.push(id);
-        let string = self
-            .tokenizer
-            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
-        if string.len() > self.prefix.len() && !string.ends_with('�') {
-            if !(string.starts_with(&self.prefix)) {
-                return Err(Box::new(DecodeStreamError::InvalidPrefix));
-            }
-            let new_text = &string[self.prefix.len()..].to_string();
-            let new_prefix_index = self.ids.len() - self.prefix_index;
-            self.ids = self.ids.drain(self.read_index..).collect();
-            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
-            self.read_index = self.prefix_index;
-            self.prefix_index = new_prefix_index;
-            Ok(Some(new_text.to_string()))
-        } else {
-            Ok(None)
-        }
-    }
+        step_decode_stream(
+            self.tokenizer,
+            id,
+            self.skip_special_tokens,
+            &mut self.ids,
+            &mut self.prefix,
+            &mut self.prefix_index,
+            &mut self.read_index,
+        )
+    }
+}
+
+/// Internal function exposed only to bypass python limitations
+pub fn step_decode_stream<M, N, PT, PP, D>(
+    tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
+    id: u32,
+    skip_special_tokens: bool,
+    ids: &mut Vec<u32>,
+    prefix: &mut String,
+    prefix_index: &mut usize,
+    read_index: &mut usize,
+) -> Result<Option<String>>
+where
+    M: Model,
+    N: Normalizer,
+    PT: PreTokenizer,
+    PP: PostProcessor,
+    D: Decoder,
+{
+    ids.push(id);
+    let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
+    if string.len() > prefix.len() && !string.ends_with('�') {
+        if !(string.starts_with(&*prefix)) {
+            return Err(Box::new(DecodeStreamError::InvalidPrefix));
+        }
+        let new_text = &string[prefix.len()..].to_string();
+        let new_prefix_index = ids.len() - *prefix_index;
+        *ids = ids.drain(*read_index..).collect();
+        *prefix = tokenizer.decode(ids, skip_special_tokens)?;
+        *read_index = *prefix_index;
+        *prefix_index = new_prefix_index;
+        Ok(Some(new_text.to_string()))
+    } else {
+        Ok(None)
+    }
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
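For readers who follow Python more easily than Rust, the same bookkeeping as step_decode_stream above can be transcribed as a sketch. This is illustrative only, not part of the library; tokenizer is any tokenizers.Tokenizer:

    class DecodeStreamSketch:
        """Pure-Python transcription of step_decode_stream, for illustration only."""

        def __init__(self, skip_special_tokens):
            self.skip_special_tokens = skip_special_tokens
            self.ids = []          # token ids kept to produce valid chunks
            self.prefix = ""       # text already returned to the caller
            self.prefix_index = 0
            self.read_index = 0

        def step(self, tokenizer, id):
            self.ids.append(id)
            string = tokenizer.decode(self.ids, skip_special_tokens=self.skip_special_tokens)
            # Emit only once the decode has grown past the already-returned prefix and
            # does not end in U+FFFD, i.e. we are not mid-way through a UTF-8 sequence.
            if len(string) > len(self.prefix) and not string.endswith("\ufffd"):
                if not string.startswith(self.prefix):
                    raise ValueError("invalid prefix")
                new_text = string[len(self.prefix):]
                new_prefix_index = len(self.ids) - self.prefix_index
                self.ids = self.ids[self.read_index:]
                self.prefix = tokenizer.decode(self.ids, skip_special_tokens=self.skip_special_tokens)
                self.read_index = self.prefix_index
                self.prefix_index = new_prefix_index
                return new_text
            return None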