Mirror of https://github.com/mii443/tokenizers.git
Decode stream python (#1678)
* Python binding for decode stream. The API differs from the Rust one because Python cannot handle lifetimes properly.
* Clippy fixes.
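In short, the new Python API wraps the ids/prefix bookkeeping in a `DecodeStream` object and takes the tokenizer as an explicit argument on every `step` call instead of borrowing it. A minimal usage sketch, mirroring the tests added further down in this diff (the toy vocabulary is only illustrative):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from tokenizers.models import BPE

# Toy tokenizer, set up the same way as in the decode test below:
# ids 0..4 map to "my", "name", "is", "john", "pair".
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

stream = DecodeStream(skip_special_tokens=False)
for token_id in [0, 1, 2, 3]:
    chunk = stream.step(tokenizer, token_id)  # str chunk, or None if nothing can be emitted yet
    if chunk is not None:
        print(chunk, end="")
# Prints "my name is john", one chunk per step ("my", " name", " is", " john").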
@@ -12,3 +12,4 @@ Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
 Sequence = decoders.Sequence
+DecodeStream = decoders.DecodeStream
@@ -1,4 +1,12 @@
 # Generated content DO NOT EDIT
+class DecodeStream:
+    """
+    Class needed for streaming decode
+
+    """
+    def __init__(self, skip_special_tokens):
+        pass
+
 class Decoder:
     """
     Base class for all decoders
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use crate::pre_tokenizers::from_string;
+use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
     m.add_class::<PySequenceDecoder>()?;
+    m.add_class::<PyDecodeStream>()?;
     Ok(())
 }
 
+/// Class needed for streaming decode
+///
+#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
+#[derive(Clone)]
+pub struct PyDecodeStream {
+    /// Regular decode option that is kept throughout.
+    skip_special_tokens: bool,
+    /// A temporary buffer of the necessary token_ids needed
+    /// to produce valid string chunks.
+    /// This typically contains 3 parts:
+    /// - read
+    /// - prefix
+    /// - rest
+    ///
+    /// Read is the bit necessary to surround the prefix
+    /// so decoding the whole ids produces a valid prefix.
+    /// Prefix is the previously produced string, kept around to trim off of
+    /// the next valid chunk
+    ids: Vec<u32>,
+    /// The previously returned chunk that needs to be discarded from the
+    /// decoding of the current ids to produce the next chunk
+    prefix: String,
+    /// The index within the ids corresponding to the prefix so we can drain
+    /// correctly
+    prefix_index: usize,
+    /// We need to keep 2 prefixes.
+    /// Prefix is the second one that was already emitted to discard the part
+    /// of the text of all the ids
+    /// read is the prefix kept only for starting side effects of the prefix
+    read_index: usize,
+}
+
+#[pymethods]
+impl PyDecodeStream {
+    #[new]
+    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
+    fn new(skip_special_tokens: bool) -> Self {
+        PyDecodeStream {
+            skip_special_tokens,
+            ids: vec![],
+            prefix: "".to_string(),
+            prefix_index: 0,
+            read_index: 0,
+        }
+    }
+
+    #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
+    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
+        ToPyResult(tk::tokenizer::step_decode_stream(
+            &tokenizer.tokenizer,
+            id,
+            self.skip_special_tokens,
+            &mut self.ids,
+            &mut self.prefix,
+            &mut self.prefix_index,
+            &mut self.read_index,
+        ))
+        .into()
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::{Arc, RwLock};
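The read/prefix/rest comments above describe the core trick: keep a rolling window of token ids, decode the whole window on every step, and only emit the text that extends past what was already returned. Below is a deliberately simplified Python sketch of that idea (a hypothetical helper, not the actual `tk::tokenizer::step_decode_stream`, which additionally drains already-emitted ids via `prefix_index`/`read_index` so the buffer stays small):

def naive_stream_step(tokenizer, state, token_id, skip_special_tokens=False):
    """Simplified streaming decode step. `state` is a dict: {"ids": [], "prefix": ""}."""
    state["ids"].append(token_id)
    text = tokenizer.decode(state["ids"], skip_special_tokens=skip_special_tokens)
    # A trailing replacement character means the decoder is mid-way through a
    # multi-byte character (compare the ByteFallback test below, where the first
    # byte of "é" yields None), so nothing can be emitted yet.
    if text.endswith("\ufffd") or len(text) <= len(state["prefix"]):
        return None
    chunk = text[len(state["prefix"]):]  # only the part not returned before
    state["prefix"] = text
    return chunk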
@@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
 #[derive(Clone, Serialize)]
 #[serde(transparent)]
 pub struct PyTokenizer {
-    tokenizer: Tokenizer,
+    pub(crate) tokenizer: Tokenizer,
 }
 
 impl PyTokenizer {
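This visibility change is what lets `PyDecodeStream::step` in decoders.rs reach the inner tokenizer through `&tokenizer.tokenizer`; the field remains private outside the crate.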
@@ -9,6 +9,7 @@ from tokenizers.models import BPE, Model, Unigram
 from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
 from tokenizers.normalizers import Strip, Lowercase, Sequence
+from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace
 
 
 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
@@ -365,6 +366,37 @@ class TestTokenizer:
         output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
         assert output == ["my name is john", "pair"]
 
+        # Can decode stream
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 0) == "my"
+        assert stream.step(tokenizer, 1) == " name"
+        assert stream.step(tokenizer, 2) == " is"
+        assert stream.step(tokenizer, 3) == " john"
+
+    def test_decode_stream(self):
+        vocab = [
+            ("<unk>", 0.0),
+            ("<0x20>", -0.1),
+            ("<0xC3>", -0.2),
+            ("<0xA9>", -0.3),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
+        tokenizer.decoder = ByteFallback()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == " "
+        assert stream.step(tokenizer, 2) == None
+        assert stream.step(tokenizer, 3) == "é"
+
+        vocab = [
+            ("<unk>", 0.0),
+            ("▁This", -0.1),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+        tokenizer.decoder = DecoderMetaspace()
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, 1) == "This"
+        assert stream.step(tokenizer, 1) == " This"
+
     def test_get_vocab(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
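As a side note on the byte-fallback expectations in `test_decode_stream`: "é" is the two-byte UTF-8 sequence 0xC3 0xA9, so after only `<0xC3>` the stream has nothing complete to emit (hence `None`), and the character appears once `<0xA9>` arrives. A quick check:

# "é" (U+00E9) encodes to the UTF-8 bytes 0xC3 0xA9; a lone 0xC3 is incomplete
# and decodes to the replacement character, which is why step() returns None there.
assert "é".encode("utf-8") == b"\xc3\xa9"
assert b"\xc3".decode("utf-8", errors="replace") == "\ufffd"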