Upgrade pyo3 to 0.16 (#956)

* Upgrade pyo3 to 0.15

Rebase-conflicts-fixed-by: H. Vetinari <h.vetinari@gmx.com>

* Upgrade pyo3 to 0.16

Rebase-conflicts-fixed-by: H. Vetinari <h.vetinari@gmx.com>
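
For orientation, most of the per-file churn below is the same mechanical pyo3 0.16 migration: `name=Foo` becomes `name = "Foo"` (plus `subclass` where Python-side subclassing is needed), `#[text_signature = ...]` moves inside `#[pyo3(...)]`, and the removed `#[pyproto]` impls become plain dunder methods inside `#[pymethods]`. A minimal sketch of the before/after shape, using a hypothetical `Thing` class rather than any type from this crate:

```rust
use pyo3::prelude::*;

// pyo3 0.12 style (what this diff removes):
//   #[pyclass(dict, module = "example", name=Thing)]
//   #[text_signature = "(self, value)"]
//   ...plus separate #[pyproto] impls for __repr__ / __len__.
//
// pyo3 0.16 style (what this diff adds): string-literal names, the pyo3(...)
// attribute wrapper, and dunder methods declared directly in #[pymethods].
#[pyclass(dict, module = "example", name = "Thing", subclass)]
#[pyo3(text_signature = "(self, value)")]
pub struct Thing {
    value: usize,
}

#[pymethods]
impl Thing {
    #[new]
    fn new(value: usize) -> Self {
        Thing { value }
    }

    fn __repr__(&self) -> String {
        format!("Thing(value={})", self.value)
    }

    fn __len__(&self) -> usize {
        self.value
    }
}
```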

* Install Python before running cargo clippy

* Fix clippy warnings

* Use `PyArray_Check` instead of downcasting to `PyArray1<u8>`
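
For context, the new code asks numpy's C API whether the object is an ndarray at all, instead of downcasting to a concrete `PyArray1<u8>`. A minimal sketch of that check, wrapped in a hypothetical standalone helper (`is_numpy_array` is not part of this PR):

```rust
use numpy::npyffi;
use pyo3::prelude::*;
use pyo3::AsPyPointer;

/// Hypothetical helper: true if `ob` is any numpy ndarray, regardless of dtype.
fn is_numpy_array(ob: &PyAny) -> bool {
    // SAFETY: PyArray_Check only inspects the object's type; it never touches the data.
    unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 }
}
```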

* Enable pyo3's `auto-initialize` feature to fix `cargo test --no-default-features`
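
For context, `auto-initialize` lets pyo3 start an embedded Python interpreter on demand, which is what allows the plain Rust test binaries built by `cargo test --no-default-features` to call into Python without being loaded as an extension module. A rough sketch of a test that relies on it (a hypothetical test, assuming only the dev-dependency added in this diff):

```rust
// Cargo.toml, [dev-dependencies] (as added in this diff):
//   pyo3 = { version = "0.16.2", features = ["auto-initialize"] }
use pyo3::prelude::*;

#[test]
fn python_is_available_in_plain_cargo_test() {
    // With `auto-initialize`, `with_gil` initializes the interpreter on first use,
    // so this works inside an ordinary Rust test binary.
    Python::with_gil(|py| {
        assert!(!py.version().is_empty());
    });
}
```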

* Fix some test cases

Why did these test cases change?
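
One plausible answer (an assumption, not something this commit states): in pyo3 0.16, `PyType::name()` is fallible and yields the unqualified class name, which is why the assertions below drop the `tokenizers.*` prefix and gain an `.unwrap()`. A small standalone illustration:

```rust
use pyo3::prelude::*;
use pyo3::types::PyString;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // `get_type().name()` returns a PyResult<&str> with the short type name,
        // e.g. "str" here (and "Metaspace", "BPE", ... in the updated tests below).
        let name = PyString::new(py, "abc").get_type().name()?;
        assert_eq!(name, "str");
        Ok(())
    })
}
```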

* Refactor and add SAFETY comments to `PyArrayUnicode`

Replace deprecated `PyUnicode_FromUnicode` with `PyUnicode_FromKindAndData`
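
For reference, the replacement API takes an explicit text "kind"; the diff passes `PyUnicode_4BYTE_KIND` because numpy's `U` dtype stores UCS-4. A minimal standalone sketch of the same call (the buffer contents here are made up for illustration):

```rust
use pyo3::prelude::*;

fn main() {
    Python::with_gil(|py| {
        // Four UCS-4 code points spelling "abcd", standing in for one row of a numpy 'U' array.
        let buf: [u32; 4] = [0x61, 0x62, 0x63, 0x64];
        // SAFETY: `buf` is a valid UCS-4 buffer of exactly `buf.len()` code points.
        let obj = unsafe {
            let ptr = pyo3::ffi::PyUnicode_FromKindAndData(
                pyo3::ffi::PyUnicode_4BYTE_KIND as _,
                buf.as_ptr() as *const _,
                buf.len() as isize,
            );
            PyObject::from_owned_ptr(py, ptr)
        };
        assert_eq!(obj.extract::<String>(py).unwrap(), "abcd");
    });
}
```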

Co-authored-by: messense <messense@icloud.com>
Author: h-vetinari
Date: 2022-05-06 00:48:40 +11:00
Committed by: GitHub
Parent: 6533bf0fad
Commit: 519cc13be0
19 changed files with 620 additions and 620 deletions

@@ -63,6 +63,13 @@ jobs:
 toolchain: stable
 components: rustfmt, clippy
+- name: Install Python
+uses: actions/setup-python@v2
+with:
+python-version: 3.9
+architecture: "x64"
 - name: Cache Cargo Registry
 uses: actions/cache@v1
 with:
@@ -88,12 +95,6 @@ jobs:
 command: clippy
 args: --manifest-path ./bindings/python/Cargo.toml --all-targets --all-features -- -D warnings
-- name: Install Python
-uses: actions/setup-python@v2
-with:
-python-version: 3.9
-architecture: "x64"
 - name: Install
 working-directory: ./bindings/python
 run: |

@@ -11,4 +11,3 @@ rustflags = [
 "-C", "link-arg=dynamic_lookup",
 "-C", "link-arg=-mmacosx-version-min=10.11",
 ]

File diff suppressed because it is too large.

@@ -14,8 +14,8 @@ serde = { version = "1.0", features = [ "rc", "derive" ]}
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.7.1"
-pyo3 = "0.12"
-numpy = "0.12"
+pyo3 = "0.16.2"
+numpy = "0.16.2"
 ndarray = "0.13"
 onig = { version = "6.0", default-features = false }
 itertools = "0.9"
@@ -26,8 +26,7 @@ path = "../../tokenizers"
 [dev-dependencies]
 tempfile = "3.1"
+pyo3 = { version = "0.16.2", features = ["auto-initialize"] }
 [features]
 default = ["pyo3/extension-module"]

@@ -21,7 +21,7 @@ use super::error::ToPyResult;
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of
 /// a Decoder will return an instance of this class when instantiated.
-#[pyclass(dict, module = "tokenizers.decoders", name=Decoder)]
+#[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
 pub struct PyDecoder {
 #[serde(flatten)]
@@ -97,7 +97,7 @@ impl PyDecoder {
 ///
 /// Returns:
 /// :obj:`str`: The decoded string
-#[text_signature = "(self, tokens)"]
+#[pyo3(text_signature = "(self, tokens)")]
 fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
 ToPyResult(self.decoder.decode(tokens)).into()
 }
@@ -141,8 +141,8 @@ macro_rules! setter {
 ///
 /// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.ByteLevel`
 /// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
-#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=ByteLevel)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteLevel")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyByteLevelDec {}
 #[pymethods]
 impl PyByteLevelDec {
@@ -161,8 +161,8 @@ impl PyByteLevelDec {
 /// cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether to cleanup some tokenization artifacts. Mainly spaces before punctuation,
 /// and some abbreviated english forms.
-#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=WordPiece)]
-#[text_signature = "(self, prefix=\"##\", cleanup=True)"]
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "WordPiece")]
+#[pyo3(text_signature = "(self, prefix=\"##\", cleanup=True)")]
 pub struct PyWordPieceDec {}
 #[pymethods]
 impl PyWordPieceDec {
@@ -203,8 +203,8 @@ impl PyWordPieceDec {
 /// add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether to add a space to the first word if there isn't already one. This
 /// lets us treat `hello` exactly like `say hello`.
-#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=Metaspace)]
-#[text_signature = "(self, replacement = \"\", add_prefix_space = True)"]
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Metaspace")]
+#[pyo3(text_signature = "(self, replacement = \"\", add_prefix_space = True)")]
 pub struct PyMetaspaceDec {}
 #[pymethods]
 impl PyMetaspaceDec {
@@ -244,8 +244,8 @@ impl PyMetaspaceDec {
 /// suffix (:obj:`str`, `optional`, defaults to :obj:`</w>`):
 /// The suffix that was used to caracterize an end-of-word. This suffix will
 /// be replaced by whitespaces during the decoding
-#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=BPEDecoder)]
-#[text_signature = "(self, suffix=\"</w>\")"]
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "BPEDecoder")]
+#[pyo3(text_signature = "(self, suffix=\"</w>\")")]
 pub struct PyBPEDecoder {}
 #[pymethods]
 impl PyBPEDecoder {
@@ -276,8 +276,8 @@ impl PyBPEDecoder {
 /// cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether to cleanup some tokenization artifacts.
 /// Mainly spaces before punctuation, and some abbreviated english forms.
-#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)]
-#[text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)"]
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "CTC")]
+#[pyo3(text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)")]
 pub struct PyCTCDecoder {}
 #[pymethods]
 impl PyCTCDecoder {
@@ -420,8 +420,8 @@ mod test {
 let py_meta = py_dec.get_as_subtype().unwrap();
 let gil = Python::acquire_gil();
 assert_eq!(
-"tokenizers.decoders.Metaspace",
-py_meta.as_ref(gil.python()).get_type().name()
+"Metaspace",
+py_meta.as_ref(gil.python()).get_type().name().unwrap()
 );
 }

@@ -1,7 +1,6 @@
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use pyo3::{PyObjectProtocol, PySequenceProtocol};
 use tk::tokenizer::{Offsets, PaddingDirection};
 use tk::utils::truncation::TruncationDirection;
 use tokenizers as tk;
@@ -9,7 +8,7 @@ use tokenizers as tk;
 use crate::error::{deprecation_warning, PyError};
 /// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`.
-#[pyclass(dict, module = "tokenizers", name=Encoding)]
+#[pyclass(dict, module = "tokenizers", name = "Encoding")]
 #[repr(transparent)]
 pub struct PyEncoding {
 pub encoding: tk::tokenizer::Encoding,
@@ -21,24 +20,6 @@ impl From<tk::tokenizer::Encoding> for PyEncoding {
 }
 }
-#[pyproto]
-impl PyObjectProtocol for PyEncoding {
-fn __repr__(&self) -> PyResult<String> {
-Ok(format!(
-"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
-attention_mask, special_tokens_mask, overflowing])",
-self.encoding.get_ids().len()
-))
-}
-}
-#[pyproto]
-impl PySequenceProtocol for PyEncoding {
-fn __len__(&self) -> PyResult<usize> {
-Ok(self.encoding.len())
-}
-}
 #[pymethods]
 impl PyEncoding {
 #[new]
@@ -73,6 +54,18 @@ impl PyEncoding {
 }
 }
+fn __repr__(&self) -> PyResult<String> {
+Ok(format!(
+"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
+attention_mask, special_tokens_mask, overflowing])",
+self.encoding.get_ids().len()
+))
+}
+fn __len__(&self) -> PyResult<usize> {
+Ok(self.encoding.len())
+}
 /// Merge the list of encodings into one final :class:`~tokenizers.Encoding`
 ///
 /// Args:
@@ -86,7 +79,7 @@ impl PyEncoding {
 /// :class:`~tokenizers.Encoding`: The resulting Encoding
 #[staticmethod]
 #[args(growing_offsets = true)]
-#[text_signature = "(encodings, growing_offsets=True)"]
+#[pyo3(text_signature = "(encodings, growing_offsets=True)")]
 fn merge(encodings: Vec<PyRef<PyEncoding>>, growing_offsets: bool) -> PyEncoding {
 tk::tokenizer::Encoding::merge(
 encodings.into_iter().map(|e| e.encoding.clone()),
@@ -108,7 +101,7 @@ impl PyEncoding {
 ///
 /// Set the given sequence index for the whole range of tokens contained in this
 /// :class:`~tokenizers.Encoding`.
-#[text_signature = "(self, sequence_id)"]
+#[pyo3(text_signature = "(self, sequence_id)")]
 fn set_sequence_id(&mut self, sequence_id: usize) {
 self.encoding.set_sequence_id(sequence_id);
 }
@@ -270,7 +263,7 @@ impl PyEncoding {
 /// Returns:
 /// :obj:`Tuple[int, int]`: The range of tokens: :obj:`(first, last + 1)`
 #[args(sequence_index = 0)]
-#[text_signature = "(self, word_index, sequence_index=0)"]
+#[pyo3(text_signature = "(self, word_index, sequence_index=0)")]
 fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
 self.encoding.word_to_tokens(word_index, sequence_index)
 }
@@ -286,7 +279,7 @@ impl PyEncoding {
 /// Returns:
 /// :obj:`Tuple[int, int]`: The range of characters (span) :obj:`(first, last + 1)`
 #[args(sequence_index = 0)]
-#[text_signature = "(self, word_index, sequence_index=0)"]
+#[pyo3(text_signature = "(self, word_index, sequence_index=0)")]
 fn word_to_chars(&self, word_index: u32, sequence_index: usize) -> Option<Offsets> {
 self.encoding.word_to_chars(word_index, sequence_index)
 }
@@ -302,7 +295,7 @@ impl PyEncoding {
 ///
 /// Returns:
 /// :obj:`int`: The sequence id of the given token
-#[text_signature = "(self, token_index)"]
+#[pyo3(text_signature = "(self, token_index)")]
 fn token_to_sequence(&self, token_index: usize) -> Option<usize> {
 self.encoding.token_to_sequence(token_index)
 }
@@ -319,7 +312,7 @@ impl PyEncoding {
 ///
 /// Returns:
 /// :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
-#[text_signature = "(self, token_index)"]
+#[pyo3(text_signature = "(self, token_index)")]
 fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
 let (_, offsets) = self.encoding.token_to_chars(token_index)?;
 Some(offsets)
@@ -337,7 +330,7 @@ impl PyEncoding {
 ///
 /// Returns:
 /// :obj:`int`: The index of the word in the relevant input sequence.
-#[text_signature = "(self, token_index)"]
+#[pyo3(text_signature = "(self, token_index)")]
 fn token_to_word(&self, token_index: usize) -> Option<u32> {
 let (_, word_idx) = self.encoding.token_to_word(token_index)?;
 Some(word_idx)
@@ -354,7 +347,7 @@ impl PyEncoding {
 /// Returns:
 /// :obj:`int`: The index of the token that contains this char in the encoded sequence
 #[args(sequence_index = 0)]
-#[text_signature = "(self, char_pos, sequence_index=0)"]
+#[pyo3(text_signature = "(self, char_pos, sequence_index=0)")]
 fn char_to_token(&self, char_pos: usize, sequence_index: usize) -> Option<usize> {
 self.encoding.char_to_token(char_pos, sequence_index)
 }
@@ -370,7 +363,7 @@ impl PyEncoding {
 /// Returns:
 /// :obj:`int`: The index of the word that contains this char in the input sequence
 #[args(sequence_index = 0)]
-#[text_signature = "(self, char_pos, sequence_index=0)"]
+#[pyo3(text_signature = "(self, char_pos, sequence_index=0)")]
 fn char_to_word(&self, char_pos: usize, sequence_index: usize) -> Option<u32> {
 self.encoding.char_to_word(char_pos, sequence_index)
 }
@@ -393,7 +386,9 @@ impl PyEncoding {
 /// pad_token (:obj:`str`, defaults to `[PAD]`):
 /// The pad token to use
 #[args(kwargs = "**")]
-#[text_signature = "(self, length, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]')"]
+#[pyo3(
+text_signature = "(self, length, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]')"
+)]
 fn pad(&mut self, length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
 let mut pad_id = 0;
 let mut pad_type_id = 0;
@@ -445,7 +440,7 @@ impl PyEncoding {
 /// Truncate direction
 #[args(stride = "0")]
 #[args(direction = "\"right\"")]
-#[text_signature = "(self, max_length, stride=0, direction='right')"]
+#[pyo3(text_signature = "(self, max_length, stride=0, direction='right')")]
 fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
 let tdir = match direction {
 "left" => Ok(TruncationDirection::Left),

@@ -37,7 +37,7 @@ impl<T> ToPyResult<T> {
 pub(crate) fn deprecation_warning(version: &str, message: &str) -> PyResult<()> {
 let gil = pyo3::Python::acquire_gil();
 let python = gil.python();
-let deprecation_warning = python.import("builtins")?.get("DeprecationWarning")?;
+let deprecation_warning = python.import("builtins")?.getattr("DeprecationWarning")?;
 let full_message = format!("Deprecated in {}: {}", version, message);
 pyo3::PyErr::warn(python, deprecation_warning, &full_message, 0)
 }

@@ -126,7 +126,7 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Tokenizers Module
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
-env_logger::init_from_env("TOKENIZERS_LOG");
+let _ = env_logger::try_init_from_env("TOKENIZERS_LOG");
 // Register the fork callback
 #[cfg(target_family = "unix")]

@@ -24,7 +24,7 @@ use super::error::{deprecation_warning, ToPyResult};
 /// will contain and manage the learned vocabulary.
 ///
 /// This class cannot be constructed directly. Please use one of the concrete models.
-#[pyclass(module = "tokenizers.models", name=Model)]
+#[pyclass(module = "tokenizers.models", name = "Model", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyModel {
 #[serde(flatten)]
@@ -132,7 +132,7 @@ impl PyModel {
 ///
 /// Returns:
 /// A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
-#[text_signature = "(self, sequence)"]
+#[pyo3(text_signature = "(self, sequence)")]
 fn tokenize(&self, sequence: &str) -> PyResult<Vec<PyToken>> {
 Ok(ToPyResult(self.model.read().unwrap().tokenize(sequence))
 .into_py()?
@@ -149,7 +149,7 @@ impl PyModel {
 ///
 /// Returns:
 /// :obj:`int`: The ID associated to the token
-#[text_signature = "(self, tokens)"]
+#[pyo3(text_signature = "(self, tokens)")]
 fn token_to_id(&self, token: &str) -> Option<u32> {
 self.model.read().unwrap().token_to_id(token)
 }
@@ -162,7 +162,7 @@ impl PyModel {
 ///
 /// Returns:
 /// :obj:`str`: The token associated to the ID
-#[text_signature = "(self, id)"]
+#[pyo3(text_signature = "(self, id)")]
 fn id_to_token(&self, id: u32) -> Option<String> {
 self.model.read().unwrap().id_to_token(id)
 }
@@ -182,7 +182,7 @@ impl PyModel {
 ///
 /// Returns:
 /// :obj:`List[str]`: The list of saved files
-#[text_signature = "(self, folder, prefix)"]
+#[pyo3(text_signature = "(self, folder, prefix)")]
 fn save<'a>(
 &self,
 folder: &str,
@@ -248,8 +250,10 @@ impl PyModel {
 ///
 /// fuse_unk (:obj:`bool`, `optional`):
 /// Whether to fuse any subsequent unknown tokens into a single one
-#[pyclass(extends=PyModel, module = "tokenizers.models", name=BPE)]
-#[text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"]
+#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
+#[pyo3(
+text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
+)]
 pub struct PyBPE {}
 impl PyBPE {
@@ -437,7 +439,7 @@ impl PyBPE {
 /// A :obj:`Tuple` with the vocab and the merges:
 /// The vocabulary and merges loaded into memory
 #[staticmethod]
-#[text_signature = "(self, vocab, merges)"]
+#[pyo3(text_signature = "(self, vocab, merges)")]
 fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> {
 BPE::read_file(vocab, merges).map_err(|e| {
 exceptions::PyException::new_err(format!(
@@ -469,7 +471,7 @@ impl PyBPE {
 /// :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files
 #[classmethod]
 #[args(kwargs = "**")]
-#[text_signature = "(cls, vocab, merge, **kwargs)"]
+#[pyo3(text_signature = "(cls, vocab, merge, **kwargs)")]
 fn from_file(
 _cls: &PyType,
 py: Python,
@@ -502,8 +504,8 @@ impl PyBPE {
 ///
 /// max_input_chars_per_word (:obj:`int`, `optional`):
 /// The maximum number of characters to authorize in a single word.
-#[pyclass(extends=PyModel, module = "tokenizers.models", name=WordPiece)]
-#[text_signature = "(self, vocab, unk_token, max_input_chars_per_word)"]
+#[pyclass(extends=PyModel, module = "tokenizers.models", name = "WordPiece")]
+#[pyo3(text_signature = "(self, vocab, unk_token, max_input_chars_per_word)")]
 pub struct PyWordPiece {}
 impl PyWordPiece {
@@ -613,7 +615,7 @@ impl PyWordPiece {
 /// Returns:
 /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
 #[staticmethod]
-#[text_signature = "(vocab)"]
+#[pyo3(text_signature = "(vocab)")]
 fn read_file(vocab: &str) -> PyResult<Vocab> {
 WordPiece::read_file(vocab).map_err(|e| {
 exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
@@ -639,7 +641,7 @@ impl PyWordPiece {
 /// :class:`~tokenizers.models.WordPiece`: An instance of WordPiece loaded from file
 #[classmethod]
 #[args(kwargs = "**")]
-#[text_signature = "(vocab, **kwargs)"]
+#[pyo3(text_signature = "(vocab, **kwargs)")]
 fn from_file(
 _cls: &PyType,
 py: Python,
@@ -663,8 +665,8 @@ impl PyWordPiece {
 ///
 /// unk_token (:obj:`str`, `optional`):
 /// The unknown token to be used by the model.
-#[pyclass(extends=PyModel, module = "tokenizers.models", name=WordLevel)]
-#[text_signature = "(self, vocab, unk_token)"]
+#[pyclass(extends=PyModel, module = "tokenizers.models", name = "WordLevel")]
+#[pyo3(text_signature = "(self, vocab, unk_token)")]
 pub struct PyWordLevel {}
 #[pymethods]
@@ -725,7 +727,7 @@ impl PyWordLevel {
 /// Returns:
 /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
 #[staticmethod]
-#[text_signature = "(vocab)"]
+#[pyo3(text_signature = "(vocab)")]
 fn read_file(vocab: &str) -> PyResult<Vocab> {
 WordLevel::read_file(vocab).map_err(|e| {
 exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
@@ -751,7 +753,7 @@ impl PyWordLevel {
 /// :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file
 #[classmethod]
 #[args(unk_token = "None")]
-#[text_signature = "(vocab, unk_token)"]
+#[pyo3(text_signature = "(vocab, unk_token)")]
 fn from_file(
 _cls: &PyType,
 py: Python,
@@ -773,8 +775,8 @@ impl PyWordLevel {
 /// Args:
 /// vocab (:obj:`List[Tuple[str, float]]`, `optional`):
 /// A list of vocabulary items and their relative score [("am", -0.2442),...]
-#[pyclass(extends=PyModel, module = "tokenizers.models", name=Unigram)]
-#[text_signature = "(self, vocab)"]
+#[pyclass(extends=PyModel, module = "tokenizers.models", name = "Unigram")]
+#[pyo3(text_signature = "(self, vocab)")]
 pub struct PyUnigram {}
 #[pymethods]
@@ -809,8 +811,8 @@ mod test {
 let py_bpe = py_model.get_as_subtype().unwrap();
 let gil = Python::acquire_gil();
 assert_eq!(
-"tokenizers.models.BPE",
-py_bpe.as_ref(gil.python()).get_type().name()
+"BPE",
+py_bpe.as_ref(gil.python()).get_type().name().unwrap()
 );
 }

@@ -3,7 +3,6 @@ use std::sync::{Arc, RwLock};
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use pyo3::PySequenceProtocol;
 use crate::error::ToPyResult;
 use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -43,7 +42,7 @@ impl PyNormalizedStringMut<'_> {
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of a
 /// Normalizer will return an instance of this class when instantiated.
-#[pyclass(dict, module = "tokenizers.normalizers", name=Normalizer)]
+#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyNormalizer {
 #[serde(flatten)]
@@ -144,7 +143,7 @@ impl PyNormalizer {
 /// normalized (:class:`~tokenizers.NormalizedString`):
 /// The normalized string on which to apply this
 /// :class:`~tokenizers.normalizers.Normalizer`
-#[text_signature = "(self, normalized)"]
+#[pyo3(text_signature = "(self, normalized)")]
 fn normalize(&self, mut normalized: PyNormalizedStringMut) -> PyResult<()> {
 normalized.normalize_with(&self.normalizer)
 }
@@ -162,7 +161,7 @@ impl PyNormalizer {
 ///
 /// Returns:
 /// :obj:`str`: A string after normalization
-#[text_signature = "(self, sequence)"]
+#[pyo3(text_signature = "(self, sequence)")]
 fn normalize_str(&self, sequence: &str) -> PyResult<String> {
 let mut normalized = NormalizedString::from(sequence);
 ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?;
@@ -217,8 +216,10 @@ macro_rules! setter {
 ///
 /// lowercase (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether to lowercase.
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=BertNormalizer)]
-#[text_signature = "(self, clean_text=True, handle_chinese_chars=True, strip_accents=None, lowercase=True)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "BertNormalizer")]
+#[pyo3(
+text_signature = "(self, clean_text=True, handle_chinese_chars=True, strip_accents=None, lowercase=True)"
+)]
 pub struct PyBertNormalizer {}
 #[pymethods]
 impl PyBertNormalizer {
@@ -287,8 +288,8 @@ impl PyBertNormalizer {
 }
 /// NFD Unicode Normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFD)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "NFD")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyNFD {}
 #[pymethods]
 impl PyNFD {
@@ -299,8 +300,8 @@ impl PyNFD {
 }
 /// NFKD Unicode Normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKD)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "NFKD")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyNFKD {}
 #[pymethods]
 impl PyNFKD {
@@ -311,8 +312,8 @@ impl PyNFKD {
 }
 /// NFC Unicode Normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFC)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "NFC")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyNFC {}
 #[pymethods]
 impl PyNFC {
@@ -323,8 +324,8 @@ impl PyNFC {
 }
 /// NFKC Unicode Normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKC)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "NFKC")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyNFKC {}
 #[pymethods]
 impl PyNFKC {
@@ -340,7 +341,7 @@ impl PyNFKC {
 /// Args:
 /// normalizers (:obj:`List[Normalizer]`):
 /// A list of Normalizer to be run as a sequence
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Sequence)]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
 pub struct PySequence {}
 #[pymethods]
 impl PySequence {
@@ -363,18 +364,15 @@ impl PySequence {
 fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
 PyTuple::new(py, &[PyList::empty(py)])
 }
-}
-#[pyproto]
-impl PySequenceProtocol for PySequence {
 fn __len__(&self) -> usize {
 0
 }
 }
 /// Lowercase Normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Lowercase)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Lowercase")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyLowercase {}
 #[pymethods]
 impl PyLowercase {
@@ -385,8 +383,8 @@ impl PyLowercase {
 }
 /// Strip normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Strip)]
-#[text_signature = "(self, left=True, right=True)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Strip")]
+#[pyo3(text_signature = "(self, left=True, right=True)")]
 pub struct PyStrip {}
 #[pymethods]
 impl PyStrip {
@@ -418,8 +416,8 @@ impl PyStrip {
 }
 /// StripAccents normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=StripAccents)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "StripAccents")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyStripAccents {}
 #[pymethods]
 impl PyStripAccents {
@@ -430,8 +428,8 @@ impl PyStripAccents {
 }
 /// Nmt normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Nmt)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Nmt")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyNmt {}
 #[pymethods]
 impl PyNmt {
@@ -443,8 +441,8 @@ impl PyNmt {
 /// Precompiled normalizer
 /// Don't use manually it is used for compatiblity for SentencePiece.
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Precompiled)]
-#[text_signature = "(self, precompiled_charsmap)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")]
+#[pyo3(text_signature = "(self, precompiled_charsmap)")]
 pub struct PyPrecompiled {}
 #[pymethods]
 impl PyPrecompiled {
@@ -466,8 +464,8 @@ impl PyPrecompiled {
 }
 /// Replace normalizer
-#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Replace)]
-#[text_signature = "(self, pattern, content)"]
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Replace")]
+#[pyo3(text_signature = "(self, pattern, content)")]
 pub struct PyReplace {}
 #[pymethods]
 impl PyReplace {
@@ -630,8 +628,8 @@ mod test {
 let py_nfc = py_norm.get_as_subtype().unwrap();
 let gil = Python::acquire_gil();
 assert_eq!(
-"tokenizers.normalizers.NFC",
-py_nfc.as_ref(gil.python()).get_type().name()
+"NFC",
+py_nfc.as_ref(gil.python()).get_type().name().unwrap()
 );
 }

@@ -28,7 +28,12 @@ use super::utils::*;
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of a
 /// PreTokenizer will return an instance of this class when instantiated.
-#[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
+#[pyclass(
+dict,
+module = "tokenizers.pre_tokenizers",
+name = "PreTokenizer",
+subclass
+)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyPreTokenizer {
 #[serde(flatten)]
@@ -146,7 +151,7 @@ impl PyPreTokenizer {
 /// pretok (:class:`~tokenizers.PreTokenizedString):
 /// The pre-tokenized string on which to apply this
 /// :class:`~tokenizers.pre_tokenizers.PreTokenizer`
-#[text_signature = "(self, pretok)"]
+#[pyo3(text_signature = "(self, pretok)")]
 fn pre_tokenize(&self, pretok: &mut PyPreTokenizedString) -> PyResult<()> {
 ToPyResult(self.pretok.pre_tokenize(&mut pretok.pretok)).into()
 }
@@ -166,7 +171,7 @@ impl PyPreTokenizer {
 /// Returns:
 /// :obj:`List[Tuple[str, Offsets]]`:
 /// A list of tuple with the pre-tokenized parts and their offsets
-#[text_signature = "(self, sequence)"]
+#[pyo3(text_signature = "(self, sequence)")]
 fn pre_tokenize_str(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
 let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
@@ -231,8 +236,8 @@ macro_rules! setter {
 /// use_regex (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Set this to :obj:`False` to prevent this `pre_tokenizer` from using
 /// the GPT2 specific regexp for spliting on whitespace.
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
-#[text_signature = "(self, add_prefix_space=True, use_regex=True)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "ByteLevel")]
+#[pyo3(text_signature = "(self, add_prefix_space=True, use_regex=True)")]
 pub struct PyByteLevel {}
 #[pymethods]
 impl PyByteLevel {
@@ -281,7 +286,7 @@ impl PyByteLevel {
 /// Returns:
 /// :obj:`List[str]`: A list of characters that compose the alphabet
 #[staticmethod]
-#[text_signature = "()"]
+#[pyo3(text_signature = "()")]
 fn alphabet() -> Vec<String> {
 ByteLevel::alphabet()
 .into_iter()
@@ -291,8 +296,8 @@ impl PyByteLevel {
 }
 /// This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Whitespace)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Whitespace")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyWhitespace {}
 #[pymethods]
 impl PyWhitespace {
@@ -303,8 +308,8 @@ impl PyWhitespace {
 }
 /// This pre-tokenizer simply splits on the whitespace. Works like `.split()`
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=WhitespaceSplit)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "WhitespaceSplit")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyWhitespaceSplit {}
 #[pymethods]
 impl PyWhitespaceSplit {
@@ -331,8 +336,8 @@ impl PyWhitespaceSplit {
 ///
 /// invert (:obj:`bool`, `optional`, defaults to :obj:`False`):
 /// Whether to invert the pattern.
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Split)]
-#[text_signature = "(self, pattern, behavior, invert=False)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Split")]
+#[pyo3(text_signature = "(self, pattern, behavior, invert=False)")]
 pub struct PySplit {}
 #[pymethods]
 impl PySplit {
@@ -361,7 +366,7 @@ impl PySplit {
 /// Args:
 /// delimiter: str:
 /// The delimiter char that will be used to split input
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=CharDelimiterSplit)]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "CharDelimiterSplit")]
 pub struct PyCharDelimiterSplit {}
 #[pymethods]
 impl PyCharDelimiterSplit {
@@ -392,8 +397,8 @@ impl PyCharDelimiterSplit {
 ///
 /// This pre-tokenizer splits tokens on spaces, and also on punctuation.
 /// Each occurence of a punctuation character will be treated separately.
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=BertPreTokenizer)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "BertPreTokenizer")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyBertPreTokenizer {}
 #[pymethods]
 impl PyBertPreTokenizer {
@@ -410,8 +415,8 @@ impl PyBertPreTokenizer {
 /// The behavior to use when splitting.
 /// Choices: "removed", "isolated" (default), "merged_with_previous", "merged_with_next",
 /// "contiguous"
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Punctuation)]
-#[text_signature = "(self, behavior=\"isolated\")"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Punctuation")]
+#[pyo3(text_signature = "(self, behavior=\"isolated\")")]
 pub struct PyPunctuation {}
 #[pymethods]
 impl PyPunctuation {
@@ -423,8 +428,8 @@ impl PyPunctuation {
 }
 /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Sequence)]
-#[text_signature = "(self, pretokenizers)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Sequence")]
+#[pyo3(text_signature = "(self, pretokenizers)")]
 pub struct PySequence {}
 #[pymethods]
 impl PySequence {
@@ -464,8 +469,8 @@ impl PySequence {
 /// add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether to add a space to the first word if there isn't already one. This
 /// lets us treat `hello` exactly like `say hello`.
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Metaspace)]
-#[text_signature = "(self, replacement=\"_\", add_prefix_space=True)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Metaspace")]
+#[pyo3(text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
 pub struct PyMetaspace {}
 #[pymethods]
 impl PyMetaspace {
@@ -514,8 +519,8 @@ impl PyMetaspace {
 /// If set to False, digits will grouped as follows::
 ///
 /// "Call 123 please" -> "Call ", "123", " please"
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Digits)]
-#[text_signature = "(self, individual_digits=False)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Digits")]
+#[pyo3(text_signature = "(self, individual_digits=False)")]
 pub struct PyDigits {}
 #[pymethods]
 impl PyDigits {
@@ -540,8 +545,8 @@ impl PyDigits {
 /// It roughly follows https://github.com/google/sentencepiece/blob/master/data/Scripts.txt
 /// Actually Hiragana and Katakana are fused with Han, and 0x30FC is Han too.
 /// This mimicks SentencePiece Unigram implementation.
-#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=UnicodeScripts)]
-#[text_signature = "(self)"]
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "UnicodeScripts")]
+#[pyo3(text_signature = "(self)")]
 pub struct PyUnicodeScripts {}
 #[pymethods]
 impl PyUnicodeScripts {
@@ -704,8 +709,8 @@ mod test {
 let py_wsp = py_norm.get_as_subtype().unwrap();
 let gil = Python::acquire_gil();
 assert_eq!(
-"tokenizers.pre_tokenizers.Whitespace",
-py_wsp.as_ref(gil.python()).get_type().name()
+"Whitespace",
+py_wsp.as_ref(gil.python()).get_type().name().unwrap()
 );
 }

@@ -20,7 +20,12 @@ use tokenizers as tk;
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of
 /// a PostProcessor will return an instance of this class when instantiated.
-#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
+#[pyclass(
+dict,
+module = "tokenizers.processors",
+name = "PostProcessor",
+subclass
+)]
 #[derive(Clone, Deserialize, Serialize)]
 pub struct PyPostProcessor {
 #[serde(flatten)]
@@ -100,7 +105,7 @@ impl PyPostProcessor {
 ///
 /// Returns:
 /// :obj:`int`: The number of tokens to add
-#[text_signature = "(self, is_pair)"]
+#[pyo3(text_signature = "(self, is_pair)")]
 fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
 self.processor.added_tokens(is_pair)
 }
@@ -120,7 +125,7 @@ impl PyPostProcessor {
 /// Return:
 /// :class:`~tokenizers.Encoding`: The final encoding
 #[args(pair = "None", add_special_tokens = "true")]
-#[text_signature = "(self, encoding, pair=None, add_special_tokens=True)"]
+#[pyo3(text_signature = "(self, encoding, pair=None, add_special_tokens=True)")]
 fn process(
 &self,
 encoding: &PyEncoding,
@@ -149,8 +154,8 @@ impl PyPostProcessor {
 ///
 /// cls (:obj:`Tuple[str, int]`):
 /// A tuple with the string representation of the CLS token, and its id
-#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=BertProcessing)]
-#[text_signature = "(self, sep, cls)"]
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "BertProcessing")]
+#[pyo3(text_signature = "(self, sep, cls)")]
 pub struct PyBertProcessing {}
 #[pymethods]
 impl PyBertProcessing {
@@ -191,8 +196,8 @@ impl PyBertProcessing {
 /// add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
 /// Whether the add_prefix_space option was enabled during pre-tokenization. This
 /// is relevant because it defines the way the offsets are trimmed out.
-#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=RobertaProcessing)]
-#[text_signature = "(self, sep, cls, trim_offsets=True, add_prefix_space=True)"]
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "RobertaProcessing")]
+#[pyo3(text_signature = "(self, sep, cls, trim_offsets=True, add_prefix_space=True)")]
 pub struct PyRobertaProcessing {}
 #[pymethods]
 impl PyRobertaProcessing {
@@ -226,8 +231,8 @@ impl PyRobertaProcessing {
 /// Args:
 /// trim_offsets (:obj:`bool`):
 /// Whether to trim the whitespaces from the produced offsets.
-#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=ByteLevel)]
-#[text_signature = "(self, trim_offsets=True)"]
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "ByteLevel")]
+#[pyo3(text_signature = "(self, trim_offsets=True)")]
 pub struct PyByteLevel {}
 #[pymethods]
 impl PyByteLevel {
@@ -378,8 +383,8 @@ impl FromPyObject<'_> for PyTemplate {
 ///
 /// The given dict expects the provided :obj:`ids` and :obj:`tokens` lists to have
 /// the same length.
-#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=TemplateProcessing)]
-#[text_signature = "(self, single, pair, special_tokens)"]
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "TemplateProcessing")]
+#[pyo3(text_signature = "(self, single, pair, special_tokens)")]
 pub struct PyTemplateProcessing {}
 #[pymethods]
 impl PyTemplateProcessing {
@@ -428,8 +433,8 @@ mod test {
 let py_bert = py_proc.get_as_subtype().unwrap();
 let gil = Python::acquire_gil();
 assert_eq!(
-"tokenizers.processors.BertProcessing",
-py_bert.as_ref(gil.python()).get_type().name()
+"BertProcessing",
+py_bert.as_ref(gil.python()).get_type().name().unwrap()
 );
 }

@@ -1,7 +1,7 @@
 use pyo3::prelude::*;
 use tk::Token;
-#[pyclass(module = "tokenizers", name=Token)]
+#[pyclass(module = "tokenizers", name = "Token")]
 #[derive(Clone)]
 pub struct PyToken {
 token: Token,

@ -1,12 +1,12 @@
use std::collections::{hash_map::DefaultHasher, HashMap}; use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use numpy::PyArray1; use numpy::{npyffi, PyArray1};
use pyo3::class::basic::CompareOp; use pyo3::class::basic::CompareOp;
use pyo3::exceptions; use pyo3::exceptions;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::*; use pyo3::types::*;
use pyo3::PyObjectProtocol; use pyo3::AsPyPointer;
use tk::models::bpe::BPE; use tk::models::bpe::BPE;
use tk::tokenizer::{ use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
@ -55,8 +55,10 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
/// lowercasing the text, the token could be extract from the input ``"I saw a lion /// lowercasing the text, the token could be extract from the input ``"I saw a lion
/// Yesterday"``. /// Yesterday"``.
/// ///
#[pyclass(dict, module = "tokenizers", name=AddedToken)] #[pyclass(dict, module = "tokenizers", name = "AddedToken")]
#[text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True)"] #[pyo3(
text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True)"
)]
pub struct PyAddedToken { pub struct PyAddedToken {
pub content: String, pub content: String,
pub is_special_token: bool, pub is_special_token: bool,
@ -199,10 +201,8 @@ impl PyAddedToken {
fn get_normalized(&self) -> bool { fn get_normalized(&self) -> bool {
self.get_token().normalized self.get_token().normalized
} }
}
#[pyproto] fn __str__(&self) -> PyResult<&str> {
impl PyObjectProtocol for PyAddedToken {
fn __str__(&'p self) -> PyResult<&'p str> {
Ok(&self.content) Ok(&self.content)
} }
@ -259,38 +259,54 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
struct PyArrayUnicode(Vec<String>); struct PyArrayUnicode(Vec<String>);
impl FromPyObject<'_> for PyArrayUnicode { impl FromPyObject<'_> for PyArrayUnicode {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &PyAny) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<u8>>()?; // SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code
let arr = array.as_array_ptr(); if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
let (type_num, elsize, alignment, data) = unsafe { return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
}
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
// SAFETY Getting all the metadata about the numpy array to check its sanity
let (type_num, elsize, alignment, data, nd, flags) = unsafe {
let desc = (*arr).descr; let desc = (*arr).descr;
( (
(*desc).type_num, (*desc).type_num,
(*desc).elsize as usize, (*desc).elsize as usize,
(*desc).alignment as usize, (*desc).alignment as usize,
(*arr).data, (*arr).data,
(*arr).nd,
(*arr).flags,
) )
}; };
let n_elem = array.shape()[0];
// type_num == 19 => Unicode if nd != 1 {
if type_num != 19 { return Err(exceptions::PyTypeError::new_err(
"Expected a 1 dimensional np.array",
));
}
if flags & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) == 0 {
return Err(exceptions::PyTypeError::new_err(
"Expected a contiguous np.array",
));
}
if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 {
return Err(exceptions::PyTypeError::new_err( return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='U']", "Expected a np.array[dtype='U']",
)); ));
} }
// SAFETY Looking at the raw numpy data to create new owned Rust strings via copies (so it's safe afterwards).
unsafe { unsafe {
let n_elem = *(*arr).dimensions as usize;
let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem); let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem);
let seq = (0..n_elem) let seq = (0..n_elem)
.map(|i| { .map(|i| {
let bytes = &all_bytes[i * elsize..(i + 1) * elsize]; let bytes = &all_bytes[i * elsize..(i + 1) * elsize];
let unicode = pyo3::ffi::PyUnicode_FromUnicode( let unicode = pyo3::ffi::PyUnicode_FromKindAndData(
pyo3::ffi::PyUnicode_4BYTE_KIND as _,
bytes.as_ptr() as *const _, bytes.as_ptr() as *const _,
elsize as isize / alignment as isize, elsize as isize / alignment as isize,
); );
let gil = Python::acquire_gil(); let py = ob.py();
let py = gil.python();
let obj = PyObject::from_owned_ptr(py, unicode); let obj = PyObject::from_owned_ptr(py, unicode);
let s = obj.cast_as::<PyString>(py)?; let s = obj.cast_as::<PyString>(py)?;
Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned()) Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
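
// Illustrative sketch (not from this commit): numpy's `dtype='U'` stores each element as a
// fixed-width, NUL-padded UCS-4 buffer of `elsize` bytes (4 bytes per code point, native
// byte order); that is the layout handed to `PyUnicode_FromKindAndData` above. The same
// decoding expressed in plain Rust, assuming a native-order buffer (names and data are
// hypothetical):
fn decode_ucs4_elements(all_bytes: &[u8], elsize: usize) -> Vec<String> {
    all_bytes
        .chunks(elsize)
        .map(|element| {
            element
                .chunks_exact(4) // elsize is always a multiple of 4 for 'U' dtypes
                .map(|cp| u32::from_ne_bytes([cp[0], cp[1], cp[2], cp[3]]))
                .take_while(|&cp| cp != 0) // shorter strings are NUL-padded up to elsize
                .filter_map(char::from_u32)
                .collect()
        })
        .collect()
}
// A dtype='U2' array holding ["ab", "c"] has elsize == 8, so
// decode_ucs4_elements(&data, 8) yields ["ab", "c"].
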
@ -310,26 +326,13 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
struct PyArrayStr(Vec<String>); struct PyArrayStr(Vec<String>);
impl FromPyObject<'_> for PyArrayStr { impl FromPyObject<'_> for PyArrayStr {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &PyAny) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<u8>>()?; let array = ob.downcast::<PyArray1<PyObject>>()?;
let arr = array.as_array_ptr(); let seq = array
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) }; .readonly()
let n_elem = array.shape()[0]; .as_array()
if type_num != 17 {
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='O']",
));
}
unsafe {
let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem);
let seq = objects
.iter() .iter()
.map(|obj| { .map(|obj| {
let gil = Python::acquire_gil(); let s = obj.cast_as::<PyString>(ob.py())?;
let py = gil.python();
let s = obj.cast_as::<PyString>(py)?;
Ok(s.to_string_lossy().into_owned()) Ok(s.to_string_lossy().into_owned())
}) })
.collect::<PyResult<Vec<_>>>()?; .collect::<PyResult<Vec<_>>>()?;
@ -337,7 +340,6 @@ impl FromPyObject<'_> for PyArrayStr {
Ok(Self(seq)) Ok(Self(seq))
} }
} }
}
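
// Illustrative sketch (not from this commit): for object-dtype arrays the new code above no
// longer walks raw pointers; downcasting to `PyArray1<PyObject>` and taking a `readonly()`
// view gives a checked, borrowed iterator. The same idea as a standalone helper (the
// function name is hypothetical):
use numpy::PyArray1;
use pyo3::prelude::*;
use pyo3::types::PyString;

fn object_array_to_strings(ob: &PyAny) -> PyResult<Vec<String>> {
    let array = ob.downcast::<PyArray1<PyObject>>()?;
    array
        .readonly()
        .as_array()
        .iter()
        .map(|obj| {
            // Each element is an arbitrary Python object; reject anything that is not a str.
            let s = obj.cast_as::<PyString>(ob.py())?;
            Ok(s.to_string_lossy().into_owned())
        })
        .collect()
}
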
impl From<PyArrayStr> for tk::InputSequence<'_> { impl From<PyArrayStr> for tk::InputSequence<'_> {
fn from(s: PyArrayStr) -> Self { fn from(s: PyArrayStr) -> Self {
s.0.into() s.0.into()
@ -438,8 +440,8 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
/// model (:class:`~tokenizers.models.Model`): /// model (:class:`~tokenizers.models.Model`):
/// The core algorithm that this :obj:`Tokenizer` should be using. /// The core algorithm that this :obj:`Tokenizer` should be using.
/// ///
#[pyclass(dict, module = "tokenizers", name=Tokenizer)] #[pyclass(dict, module = "tokenizers", name = "Tokenizer")]
#[text_signature = "(self, model)"] #[pyo3(text_signature = "(self, model)")]
#[derive(Clone)] #[derive(Clone)]
pub struct PyTokenizer { pub struct PyTokenizer {
tokenizer: Tokenizer, tokenizer: Tokenizer,
@ -502,7 +504,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :class:`~tokenizers.Tokenizer`: The new tokenizer /// :class:`~tokenizers.Tokenizer`: The new tokenizer
#[staticmethod] #[staticmethod]
#[text_signature = "(json)"] #[pyo3(text_signature = "(json)")]
fn from_str(json: &str) -> PyResult<Self> { fn from_str(json: &str) -> PyResult<Self> {
let tokenizer: PyResult<_> = ToPyResult(json.parse()).into(); let tokenizer: PyResult<_> = ToPyResult(json.parse()).into();
Ok(Self::new(tokenizer?)) Ok(Self::new(tokenizer?))
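
// Illustrative sketch (not from this commit): the two attribute changes repeated throughout
// this file are (1) `#[pyclass]` names must now be string literals and (2) `text_signature`
// moves under `#[pyo3(...)]`, both on structs and on (static) methods. A minimal class
// combining them (all names hypothetical):
use pyo3::prelude::*;

#[pyclass(module = "example", name = "Example")]
#[pyo3(text_signature = "(self, value)")]
struct ExampleSketch {
    value: String,
}

#[pymethods]
impl ExampleSketch {
    #[new]
    fn new(value: String) -> Self {
        Self { value }
    }

    #[staticmethod]
    #[pyo3(text_signature = "(value)")]
    fn from_value(value: String) -> PyResult<Self> {
        Ok(Self { value })
    }
}
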
@ -518,7 +520,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :class:`~tokenizers.Tokenizer`: The new tokenizer /// :class:`~tokenizers.Tokenizer`: The new tokenizer
#[staticmethod] #[staticmethod]
#[text_signature = "(path)"] #[pyo3(text_signature = "(path)")]
fn from_file(path: &str) -> PyResult<Self> { fn from_file(path: &str) -> PyResult<Self> {
let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into(); let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into();
Ok(Self::new(tokenizer?)) Ok(Self::new(tokenizer?))
@ -533,7 +535,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :class:`~tokenizers.Tokenizer`: The new tokenizer /// :class:`~tokenizers.Tokenizer`: The new tokenizer
#[staticmethod] #[staticmethod]
#[text_signature = "(buffer)"] #[pyo3(text_signature = "(buffer)")]
fn from_buffer(buffer: &PyBytes) -> PyResult<Self> { fn from_buffer(buffer: &PyBytes) -> PyResult<Self> {
let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| { let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| {
exceptions::PyValueError::new_err(format!( exceptions::PyValueError::new_err(format!(
@ -561,7 +563,7 @@ impl PyTokenizer {
/// :class:`~tokenizers.Tokenizer`: The new tokenizer /// :class:`~tokenizers.Tokenizer`: The new tokenizer
#[staticmethod] #[staticmethod]
#[args(revision = "String::from(\"main\")", auth_token = "None")] #[args(revision = "String::from(\"main\")", auth_token = "None")]
#[text_signature = "(identifier, revision=\"main\", auth_token=None)"] #[pyo3(text_signature = "(identifier, revision=\"main\", auth_token=None)")]
fn from_pretrained( fn from_pretrained(
identifier: &str, identifier: &str,
revision: String, revision: String,
@ -590,7 +592,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :obj:`str`: A string representing the serialized Tokenizer /// :obj:`str`: A string representing the serialized Tokenizer
#[args(pretty = false)] #[args(pretty = false)]
#[text_signature = "(self, pretty=False)"] #[pyo3(text_signature = "(self, pretty=False)")]
fn to_str(&self, pretty: bool) -> PyResult<String> { fn to_str(&self, pretty: bool) -> PyResult<String> {
ToPyResult(self.tokenizer.to_string(pretty)).into() ToPyResult(self.tokenizer.to_string(pretty)).into()
} }
@ -604,7 +606,7 @@ impl PyTokenizer {
/// pretty (:obj:`bool`, defaults to :obj:`True`): /// pretty (:obj:`bool`, defaults to :obj:`True`):
/// Whether the JSON file should be pretty formatted. /// Whether the JSON file should be pretty formatted.
#[args(pretty = true)] #[args(pretty = true)]
#[text_signature = "(self, path, pretty=True)"] #[pyo3(text_signature = "(self, path, pretty=True)")]
fn save(&self, path: &str, pretty: bool) -> PyResult<()> { fn save(&self, path: &str, pretty: bool) -> PyResult<()> {
ToPyResult(self.tokenizer.save(path, pretty)).into() ToPyResult(self.tokenizer.save(path, pretty)).into()
} }
@ -612,7 +614,7 @@ impl PyTokenizer {
/// Return the number of special tokens that would be added for single/pair sentences. /// Return the number of special tokens that would be added for single/pair sentences.
/// :param is_pair: Boolean indicating if the input would be a single sentence or a pair /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair
/// :return: /// :return:
#[text_signature = "(self, is_pair)"] #[pyo3(text_signature = "(self, is_pair)")]
fn num_special_tokens_to_add(&self, is_pair: bool) -> usize { fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
self.tokenizer self.tokenizer
.get_post_processor() .get_post_processor()
@ -628,7 +630,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :obj:`Dict[str, int]`: The vocabulary /// :obj:`Dict[str, int]`: The vocabulary
#[args(with_added_tokens = true)] #[args(with_added_tokens = true)]
#[text_signature = "(self, with_added_tokens=True)"] #[pyo3(text_signature = "(self, with_added_tokens=True)")]
fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> { fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
self.tokenizer.get_vocab(with_added_tokens) self.tokenizer.get_vocab(with_added_tokens)
} }
@ -642,7 +644,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :obj:`int`: The size of the vocabulary /// :obj:`int`: The size of the vocabulary
#[args(with_added_tokens = true)] #[args(with_added_tokens = true)]
#[text_signature = "(self, with_added_tokens=True)"] #[pyo3(text_signature = "(self, with_added_tokens=True)")]
fn get_vocab_size(&self, with_added_tokens: bool) -> usize { fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
self.tokenizer.get_vocab_size(with_added_tokens) self.tokenizer.get_vocab_size(with_added_tokens)
} }
@ -664,7 +666,9 @@ impl PyTokenizer {
/// direction (:obj:`str`, defaults to :obj:`right`): /// direction (:obj:`str`, defaults to :obj:`right`):
/// Truncate direction /// Truncate direction
#[args(kwargs = "**")] #[args(kwargs = "**")]
#[text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"] #[pyo3(
text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"
)]
fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> { fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut params = TruncationParams { let mut params = TruncationParams {
max_length, max_length,
@ -714,7 +718,7 @@ impl PyTokenizer {
} }
/// Disable truncation /// Disable truncation
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn no_truncation(&mut self) { fn no_truncation(&mut self) {
self.tokenizer.with_truncation(None); self.tokenizer.with_truncation(None);
} }
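
// Illustrative sketch (not from this commit): `enable_truncation`/`enable_padding` keep the
// `#[args(kwargs = "**")]` pattern, where optional keyword arguments arrive as an
// `Option<&PyDict>` and are matched by name. A trimmed-down version of that pattern
// (struct, field and method names are hypothetical):
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;

#[derive(Default)]
struct TruncationSketch {
    max_length: usize,
    stride: usize,
    direction: String,
}

#[pyclass]
struct ConfigSketch {
    truncation: Option<TruncationSketch>,
}

#[pymethods]
impl ConfigSketch {
    #[args(kwargs = "**")]
    #[pyo3(text_signature = "(self, max_length, stride=0, direction='right')")]
    fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
        let mut params = TruncationSketch {
            max_length,
            ..Default::default()
        };
        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                match key.extract::<&str>()? {
                    "stride" => params.stride = value.extract()?,
                    "direction" => params.direction = value.extract()?,
                    other => {
                        return Err(PyValueError::new_err(format!("unknown kwarg: {}", other)))
                    }
                }
            }
        }
        self.truncation = Some(params);
        Ok(())
    }
}
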
@ -764,7 +768,9 @@ impl PyTokenizer {
/// If specified, the length at which to pad. If not specified, we pad using the size of /// If specified, the length at which to pad. If not specified, we pad using the size of
/// the longest sequence in a batch. /// the longest sequence in a batch.
#[args(kwargs = "**")] #[args(kwargs = "**")]
#[text_signature = "(self, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]', length=None, pad_to_multiple_of=None)"] #[pyo3(
text_signature = "(self, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]', length=None, pad_to_multiple_of=None)"
)]
fn enable_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { fn enable_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut params = PaddingParams::default(); let mut params = PaddingParams::default();
@ -822,7 +828,7 @@ impl PyTokenizer {
} }
/// Disable padding /// Disable padding
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn no_padding(&mut self) { fn no_padding(&mut self) {
self.tokenizer.with_padding(None); self.tokenizer.with_padding(None);
} }
@ -891,7 +897,9 @@ impl PyTokenizer {
/// :class:`~tokenizers.Encoding`: The encoded result /// :class:`~tokenizers.Encoding`: The encoded result
/// ///
#[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")] #[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
#[text_signature = "(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)"] #[pyo3(
text_signature = "(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)"
)]
fn encode( fn encode(
&self, &self,
sequence: &PyAny, sequence: &PyAny,
@ -956,7 +964,7 @@ impl PyTokenizer {
/// A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch /// A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
/// ///
#[args(is_pretokenized = "false", add_special_tokens = "true")] #[args(is_pretokenized = "false", add_special_tokens = "true")]
#[text_signature = "(self, input, is_pretokenized=False, add_special_tokens=True)"] #[pyo3(text_signature = "(self, input, is_pretokenized=False, add_special_tokens=True)")]
fn encode_batch( fn encode_batch(
&self, &self,
input: Vec<&PyAny>, input: Vec<&PyAny>,
@ -999,7 +1007,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :obj:`str`: The decoded string /// :obj:`str`: The decoded string
#[args(skip_special_tokens = true)] #[args(skip_special_tokens = true)]
#[text_signature = "(self, ids, skip_special_tokens=True)"] #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> { fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into() ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
} }
@ -1016,7 +1024,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :obj:`List[str]`: A list of decoded strings /// :obj:`List[str]`: A list of decoded strings
#[args(skip_special_tokens = true)] #[args(skip_special_tokens = true)]
#[text_signature = "(self, sequences, skip_special_tokens=True)"] #[pyo3(text_signature = "(self, sequences, skip_special_tokens=True)")]
fn decode_batch( fn decode_batch(
&self, &self,
sequences: Vec<Vec<u32>>, sequences: Vec<Vec<u32>>,
@ -1036,7 +1044,7 @@ impl PyTokenizer {
/// ///
/// Returns: /// Returns:
/// :obj:`Optional[int]`: An optional id, :obj:`None` if out of vocabulary /// :obj:`Optional[int]`: An optional id, :obj:`None` if out of vocabulary
#[text_signature = "(self, token)"] #[pyo3(text_signature = "(self, token)")]
fn token_to_id(&self, token: &str) -> Option<u32> { fn token_to_id(&self, token: &str) -> Option<u32> {
self.tokenizer.token_to_id(token) self.tokenizer.token_to_id(token)
} }
@ -1049,7 +1057,7 @@ impl PyTokenizer {
/// ///
/// Returns: /// Returns:
/// :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary /// :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary
#[text_signature = "(self, id)"] #[pyo3(text_signature = "(self, id)")]
fn id_to_token(&self, id: u32) -> Option<String> { fn id_to_token(&self, id: u32) -> Option<String> {
self.tokenizer.id_to_token(id) self.tokenizer.id_to_token(id)
} }
@ -1066,7 +1074,7 @@ impl PyTokenizer {
/// ///
/// Returns: /// Returns:
/// :obj:`int`: The number of tokens that were created in the vocabulary /// :obj:`int`: The number of tokens that were created in the vocabulary
#[text_signature = "(self, tokens)"] #[pyo3(text_signature = "(self, tokens)")]
fn add_tokens(&mut self, tokens: &PyList) -> PyResult<usize> { fn add_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
let tokens = tokens let tokens = tokens
.into_iter() .into_iter()
@ -1103,7 +1111,7 @@ impl PyTokenizer {
/// ///
/// Returns: /// Returns:
/// :obj:`int`: The number of tokens that were created in the vocabulary /// :obj:`int`: The number of tokens that were created in the vocabulary
#[text_signature = "(self, tokens)"] #[pyo3(text_signature = "(self, tokens)")]
fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> { fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
let tokens = tokens let tokens = tokens
.into_iter() .into_iter()
@ -1137,7 +1145,7 @@ impl PyTokenizer {
/// trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`): /// trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
/// An optional trainer that should be used to train our Model /// An optional trainer that should be used to train our Model
#[args(trainer = "None")] #[args(trainer = "None")]
#[text_signature = "(self, files, trainer = None)"] #[pyo3(text_signature = "(self, files, trainer = None)")]
fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> { fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
let mut trainer = let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
@ -1173,7 +1181,7 @@ impl PyTokenizer {
/// The total number of sequences in the iterator. This is used to /// The total number of sequences in the iterator. This is used to
/// provide meaningful progress tracking /// provide meaningful progress tracking
#[args(trainer = "None", length = "None")] #[args(trainer = "None", length = "None")]
#[text_signature = "(self, iterator, trainer=None, length=None)"] #[pyo3(text_signature = "(self, iterator, trainer=None, length=None)")]
fn train_from_iterator( fn train_from_iterator(
&mut self, &mut self,
py: Python, py: Python,
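
// Illustrative sketch (not from this commit): the `iterator` passed to `train_from_iterator`
// is an arbitrary Python iterable; the basic way to walk it from Rust is `PyAny::iter`,
// extracting each item under the GIL (the real implementation does more than this; the
// helper name here is hypothetical):
use pyo3::prelude::*;

fn collect_lines(iterator: &PyAny) -> PyResult<Vec<String>> {
    let mut lines = Vec::new();
    for item in iterator.iter()? {
        // Each `item` is a PyResult<&PyAny>; extraction fails cleanly on non-str items.
        let line: String = item?.extract()?;
        lines.push(line);
    }
    Ok(lines)
}
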
@ -1239,7 +1247,7 @@ impl PyTokenizer {
/// Returns: /// Returns:
/// :class:`~tokenizers.Encoding`: The final post-processed encoding /// :class:`~tokenizers.Encoding`: The final post-processed encoding
#[args(pair = "None", add_special_tokens = true)] #[args(pair = "None", add_special_tokens = true)]
#[text_signature = "(self, encoding, pair=None, add_special_tokens=True)"] #[pyo3(text_signature = "(self, encoding, pair=None, add_special_tokens=True)")]
fn post_process( fn post_process(
&self, &self,
encoding: &PyEncoding, encoding: &PyEncoding,

View File

@ -15,7 +15,7 @@ use tokenizers as tk;
/// ///
/// This class is not supposed to be instantiated directly. Instead, any implementation of a /// This class is not supposed to be instantiated directly. Instead, any implementation of a
/// Trainer will return an instance of this class when instantiated. /// Trainer will return an instance of this class when instantiated.
#[pyclass(name=Trainer, module = "tokenizers.trainers", name=Trainer)] #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)]
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
pub struct PyTrainer { pub struct PyTrainer {
#[serde(flatten)] #[serde(flatten)]
@ -164,7 +164,7 @@ macro_rules! setter {
/// ///
/// end_of_word_suffix (:obj:`str`, `optional`): /// end_of_word_suffix (:obj:`str`, `optional`):
/// A suffix to be used for every subword that is an end-of-word. /// A suffix to be used for every subword that is an end-of-word.
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name=BpeTrainer)] #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "BpeTrainer")]
pub struct PyBpeTrainer {} pub struct PyBpeTrainer {}
#[pymethods] #[pymethods]
impl PyBpeTrainer { impl PyBpeTrainer {
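
// Illustrative sketch (not from this commit): the base class now opts into inheritance with
// `subclass`, and concrete trainers keep pointing at it with `extends = ...`; a child's
// `#[new]` can return `(Self, Base)` so both layers get initialised. Minimal shape of the
// pattern (names hypothetical):
use pyo3::prelude::*;

#[pyclass(module = "example", name = "Base", subclass)]
struct BaseSketch {}

#[pyclass(extends = BaseSketch, module = "example", name = "Child")]
struct ChildSketch {}

#[pymethods]
impl ChildSketch {
    #[new]
    fn new() -> (Self, BaseSketch) {
        // pyo3 builds the Python object from the child and base parts of the tuple.
        (ChildSketch {}, BaseSketch {})
    }
}
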
@ -367,8 +367,10 @@ impl PyBpeTrainer {
/// ///
/// end_of_word_suffix (:obj:`str`, `optional`): /// end_of_word_suffix (:obj:`str`, `optional`):
/// A suffix to be used for every subword that is an end-of-word. /// A suffix to be used for every subword that is an end-of-word.
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name=WordPieceTrainer)] #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "WordPieceTrainer")]
#[text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet= [],continuing_subword_prefix=\"##\", end_of_word_suffix=None)"] #[pyo3(
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet= [],continuing_subword_prefix=\"##\", end_of_word_suffix=None)"
)]
pub struct PyWordPieceTrainer {} pub struct PyWordPieceTrainer {}
#[pymethods] #[pymethods]
impl PyWordPieceTrainer { impl PyWordPieceTrainer {
@ -557,7 +559,7 @@ impl PyWordPieceTrainer {
/// ///
/// special_tokens (:obj:`List[Union[str, AddedToken]]`): /// special_tokens (:obj:`List[Union[str, AddedToken]]`):
/// A list of special tokens the model should know of. /// A list of special tokens the model should know of.
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name=WordLevelTrainer)] #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "WordLevelTrainer")]
pub struct PyWordLevelTrainer {} pub struct PyWordLevelTrainer {}
#[pymethods] #[pymethods]
impl PyWordLevelTrainer { impl PyWordLevelTrainer {
@ -713,8 +715,10 @@ impl PyWordLevelTrainer {
/// n_sub_iterations (:obj:`int`): /// n_sub_iterations (:obj:`int`):
/// The number of iterations of the EM algorithm to perform before /// The number of iterations of the EM algorithm to perform before
/// pruning the vocabulary. /// pruning the vocabulary.
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name=UnigramTrainer)] #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "UnigramTrainer")]
#[text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens=[], shrinking_factor=0.75, unk_token=None, max_piece_length=16, n_sub_iterations=2)"] #[pyo3(
text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens=[], shrinking_factor=0.75, unk_token=None, max_piece_length=16, n_sub_iterations=2)"
)]
pub struct PyUnigramTrainer {} pub struct PyUnigramTrainer {}
#[pymethods] #[pymethods]
impl PyUnigramTrainer { impl PyUnigramTrainer {
@ -864,8 +868,8 @@ mod tests {
let py_bpe = py_trainer.get_as_subtype().unwrap(); let py_bpe = py_trainer.get_as_subtype().unwrap();
let gil = Python::acquire_gil(); let gil = Python::acquire_gil();
assert_eq!( assert_eq!(
"tokenizers.trainers.BpeTrainer", "BpeTrainer",
py_bpe.as_ref(gil.python()).get_type().name() py_bpe.as_ref(gil.python()).get_type().name().unwrap()
); );
} }
} }
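
// Illustrative sketch (not from this commit): `PyType::name()` now returns a `PyResult` and,
// for these generated classes, the unqualified name, which is why the expected string above
// drops the "tokenizers.trainers." prefix and gains `.unwrap()`. A tiny self-contained check
// against a built-in type:
use pyo3::prelude::*;
use pyo3::types::PyDict;

fn main() -> PyResult<()> {
    // Needed when embedding Python without pyo3's `auto-initialize` feature.
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        let ty = py.get_type::<PyDict>();
        assert_eq!("dict", ty.name()?);
        Ok(())
    })
}
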

View File

@ -1,5 +1,5 @@
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{AsPyPointer, PyNativeType}; use pyo3::AsPyPointer;
use std::collections::VecDeque; use std::collections::VecDeque;
/// A simple iterator that can be instantiated with a specified length. /// A simple iterator that can be instantiated with a specified length.

View File

@ -4,7 +4,6 @@ use crate::error::ToPyResult;
use pyo3::exceptions; use pyo3::exceptions;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::*; use pyo3::types::*;
use pyo3::{PyMappingProtocol, PyObjectProtocol};
use tk::normalizer::{char_to_bytes, NormalizedString, Range, SplitDelimiterBehavior}; use tk::normalizer::{char_to_bytes, NormalizedString, Range, SplitDelimiterBehavior};
use tk::pattern::Pattern; use tk::pattern::Pattern;
@ -192,7 +191,7 @@ fn slice(
/// Args: /// Args:
/// sequence: str: /// sequence: str:
/// The string sequence used to initialize this NormalizedString /// The string sequence used to initialize this NormalizedString
#[pyclass(module = "tokenizers", name=NormalizedString)] #[pyclass(module = "tokenizers", name = "NormalizedString")]
#[derive(Clone)] #[derive(Clone)]
pub struct PyNormalizedString { pub struct PyNormalizedString {
pub(crate) normalized: NormalizedString, pub(crate) normalized: NormalizedString,
@ -217,91 +216,91 @@ impl PyNormalizedString {
} }
/// Runs the NFD normalization /// Runs the NFD normalization
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn nfd(&mut self) { fn nfd(&mut self) {
self.normalized.nfd(); self.normalized.nfd();
} }
/// Runs the NFKD normalization /// Runs the NFKD normalization
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn nfkd(&mut self) { fn nfkd(&mut self) {
self.normalized.nfkd(); self.normalized.nfkd();
} }
/// Runs the NFC normalization /// Runs the NFC normalization
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn nfc(&mut self) { fn nfc(&mut self) {
self.normalized.nfc(); self.normalized.nfc();
} }
/// Runs the NFKC normalization /// Runs the NFKC normalization
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn nfkc(&mut self) { fn nfkc(&mut self) {
self.normalized.nfkc(); self.normalized.nfkc();
} }
/// Lowercase the string /// Lowercase the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn lowercase(&mut self) { fn lowercase(&mut self) {
self.normalized.lowercase(); self.normalized.lowercase();
} }
/// Uppercase the string /// Uppercase the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn uppercase(&mut self) { fn uppercase(&mut self) {
self.normalized.uppercase(); self.normalized.uppercase();
} }
/// Prepend the given sequence to the string /// Prepend the given sequence to the string
#[text_signature = "(self, s)"] #[pyo3(text_signature = "(self, s)")]
fn prepend(&mut self, s: &str) { fn prepend(&mut self, s: &str) {
self.normalized.prepend(s); self.normalized.prepend(s);
} }
/// Append the given sequence to the string /// Append the given sequence to the string
#[text_signature = "(self, s)"] #[pyo3(text_signature = "(self, s)")]
fn append(&mut self, s: &str) { fn append(&mut self, s: &str) {
self.normalized.append(s); self.normalized.append(s);
} }
/// Strip the left of the string /// Strip the left of the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn lstrip(&mut self) { fn lstrip(&mut self) {
self.normalized.lstrip(); self.normalized.lstrip();
} }
/// Strip the right of the string /// Strip the right of the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn rstrip(&mut self) { fn rstrip(&mut self) {
self.normalized.rstrip(); self.normalized.rstrip();
} }
/// Strip both ends of the string /// Strip both ends of the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn strip(&mut self) { fn strip(&mut self) {
self.normalized.strip(); self.normalized.strip();
} }
/// Clears the string /// Clears the string
#[text_signature = "(self)"] #[pyo3(text_signature = "(self)")]
fn clear(&mut self) { fn clear(&mut self) {
self.normalized.clear(); self.normalized.clear();
} }
/// Slice the string using the given range /// Slice the string using the given range
#[text_signature = "(self, range)"] #[pyo3(text_signature = "(self, range)")]
fn slice(&self, range: PyRange) -> PyResult<Option<PyNormalizedString>> { fn slice(&self, range: PyRange) -> PyResult<Option<PyNormalizedString>> {
slice(&self.normalized, &range) slice(&self.normalized, &range)
} }
/// Filter each character of the string using the given func /// Filter each character of the string using the given func
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn filter(&mut self, func: &PyAny) -> PyResult<()> { fn filter(&mut self, func: &PyAny) -> PyResult<()> {
filter(&mut self.normalized, func) filter(&mut self.normalized, func)
} }
/// Calls the given function for each character of the string /// Calls the given function for each character of the string
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn for_each(&self, func: &PyAny) -> PyResult<()> { fn for_each(&self, func: &PyAny) -> PyResult<()> {
for_each(&self.normalized, func) for_each(&self.normalized, func)
} }
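
// Illustrative sketch (not from this commit): `filter`, `for_each` and `map` all take a
// Python callable as `&PyAny` and invoke it once per character. The call-and-extract step on
// its own, passing the character as a 1-char string (function name hypothetical):
use pyo3::prelude::*;

fn keep_chars(text: &str, func: &PyAny) -> PyResult<String> {
    let mut kept = String::new();
    for c in text.chars() {
        // `call1` takes the arguments as a tuple; the Python return value is read back as bool.
        let keep: bool = func.call1((c.to_string(),))?.extract()?;
        if keep {
            kept.push(c);
        }
    }
    Ok(kept)
}
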
@ -310,7 +309,7 @@ impl PyNormalizedString {
/// ///
/// Replaces each character of the string using the returned value. Each /// Replaces each character of the string using the returned value. Each
/// returned value **must** be a str of length 1 (ie a character). /// returned value **must** be a str of length 1 (ie a character).
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn map(&mut self, func: &PyAny) -> PyResult<()> { fn map(&mut self, func: &PyAny) -> PyResult<()> {
map(&mut self.normalized, func) map(&mut self.normalized, func)
} }
@ -328,7 +327,7 @@ impl PyNormalizedString {
/// ///
/// Returns: /// Returns:
/// A list of NormalizedString, representing each split /// A list of NormalizedString, representing each split
#[text_signature = "(self, pattern, behavior)"] #[pyo3(text_signature = "(self, pattern, behavior)")]
fn split( fn split(
&mut self, &mut self,
pattern: PyPattern, pattern: PyPattern,
@ -349,14 +348,11 @@ impl PyNormalizedString {
/// ///
/// content: str: /// content: str:
/// The content to be used as replacement /// The content to be used as replacement
#[text_signature = "(self, pattern, content)"] #[pyo3(text_signature = "(self, pattern, content)")]
fn replace(&mut self, pattern: PyPattern, content: &str) -> PyResult<()> { fn replace(&mut self, pattern: PyPattern, content: &str) -> PyResult<()> {
ToPyResult(self.normalized.replace(pattern, content)).into() ToPyResult(self.normalized.replace(pattern, content)).into()
} }
}
#[pyproto]
impl PyObjectProtocol<'p> for PyNormalizedString {
fn __repr__(&self) -> String { fn __repr__(&self) -> String {
format!( format!(
r#"NormalizedString(original="{}", normalized="{}")"#, r#"NormalizedString(original="{}", normalized="{}")"#,
@ -365,14 +361,11 @@ impl PyObjectProtocol<'p> for PyNormalizedString {
) )
} }
fn __str__(&'p self) -> &'p str { fn __str__(&self) -> &str {
self.normalized.get() self.normalized.get()
} }
}
#[pyproto] fn __getitem__(&self, range: PyRange<'_>) -> PyResult<Option<PyNormalizedString>> {
impl PyMappingProtocol<'p> for PyNormalizedString {
fn __getitem__(&self, range: PyRange<'p>) -> PyResult<Option<PyNormalizedString>> {
slice(&self.normalized, &range) slice(&self.normalized, &range)
} }
} }
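
// Illustrative sketch (not from this commit): mapping-protocol methods like `__getitem__`
// (and `__len__`) also move straight into `#[pymethods]`, replacing the old
// `PyMappingProtocol` impl. A toy sequence wrapper (names hypothetical):
use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;

#[pyclass]
struct CharSeqSketch {
    items: Vec<char>,
}

#[pymethods]
impl CharSeqSketch {
    fn __len__(&self) -> usize {
        self.items.len()
    }

    fn __getitem__(&self, idx: isize) -> PyResult<String> {
        let len = self.items.len() as isize;
        // Mirror Python semantics for negative indices.
        let i = if idx < 0 { idx + len } else { idx };
        if i < 0 || i >= len {
            return Err(PyIndexError::new_err("index out of range"));
        }
        Ok(self.items[i as usize].to_string())
    }
}
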
@ -389,7 +382,7 @@ impl From<PyNormalizedString> for NormalizedString {
} }
} }
#[pyclass(module = "tokenizers", name=NormalizedStringRefMut)] #[pyclass(module = "tokenizers", name = "NormalizedStringRefMut")]
#[derive(Clone)] #[derive(Clone)]
pub struct PyNormalizedStringRefMut { pub struct PyNormalizedStringRefMut {
inner: RefMutContainer<NormalizedString>, inner: RefMutContainer<NormalizedString>,

View File

@ -147,8 +147,8 @@ fn to_encoding(
/// Args: /// Args:
/// sequence: str: /// sequence: str:
/// The string sequence used to initialize this PreTokenizedString /// The string sequence used to initialize this PreTokenizedString
#[pyclass(module = "tokenizers", name=PreTokenizedString)] #[pyclass(module = "tokenizers", name = "PreTokenizedString")]
#[text_signature = "(self, sequence)"] #[pyo3(text_signature = "(self, sequence)")]
pub struct PyPreTokenizedString { pub struct PyPreTokenizedString {
pub(crate) pretok: tk::PreTokenizedString, pub(crate) pretok: tk::PreTokenizedString,
} }
@ -182,7 +182,7 @@ impl PyPreTokenizedString {
/// just return it directly. /// just return it directly.
/// In order for the offsets to be tracked accurately, any returned `NormalizedString` /// In order for the offsets to be tracked accurately, any returned `NormalizedString`
/// should come from calling either `.split` or `.slice` on the received one. /// should come from calling either `.split` or `.slice` on the received one.
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn split(&mut self, func: &PyAny) -> PyResult<()> { fn split(&mut self, func: &PyAny) -> PyResult<()> {
split(&mut self.pretok, func) split(&mut self.pretok, func)
} }
@ -194,7 +194,7 @@ impl PyPreTokenizedString {
/// The function used to normalize each underlying split. This function /// The function used to normalize each underlying split. This function
/// does not need to return anything, just calling the methods on the provided /// does not need to return anything, just calling the methods on the provided
/// NormalizedString allows its modification. /// NormalizedString allows its modification.
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn normalize(&mut self, func: &PyAny) -> PyResult<()> { fn normalize(&mut self, func: &PyAny) -> PyResult<()> {
normalize(&mut self.pretok, func) normalize(&mut self.pretok, func)
} }
@ -205,7 +205,7 @@ impl PyPreTokenizedString {
/// func: Callable[[str], List[Token]]: /// func: Callable[[str], List[Token]]:
/// The function used to tokenize each underlying split. This function must return /// The function used to tokenize each underlying split. This function must return
/// a list of Token generated from the input str. /// a list of Token generated from the input str.
#[text_signature = "(self, func)"] #[pyo3(text_signature = "(self, func)")]
fn tokenize(&mut self, func: &PyAny) -> PyResult<()> { fn tokenize(&mut self, func: &PyAny) -> PyResult<()> {
tokenize(&mut self.pretok, func) tokenize(&mut self.pretok, func)
} }
@ -224,7 +224,7 @@ impl PyPreTokenizedString {
/// Returns: /// Returns:
/// An Encoding /// An Encoding
#[args(type_id = "0", word_idx = "None")] #[args(type_id = "0", word_idx = "None")]
#[text_signature = "(self, type_id=0, word_idx=None)"] #[pyo3(text_signature = "(self, type_id=0, word_idx=None)")]
fn to_encoding(&self, type_id: u32, word_idx: Option<u32>) -> PyResult<PyEncoding> { fn to_encoding(&self, type_id: u32, word_idx: Option<u32>) -> PyResult<PyEncoding> {
to_encoding(&self.pretok, type_id, word_idx) to_encoding(&self.pretok, type_id, word_idx)
} }
@ -249,7 +249,7 @@ impl PyPreTokenizedString {
offset_referential = "PyOffsetReferential(OffsetReferential::Original)", offset_referential = "PyOffsetReferential(OffsetReferential::Original)",
offset_type = "PyOffsetType(OffsetType::Char)" offset_type = "PyOffsetType(OffsetType::Char)"
)] )]
#[text_signature = "(self, offset_referential=\"original\", offset_type=\"char\")"] #[pyo3(text_signature = "(self, offset_referential=\"original\", offset_type=\"char\")")]
fn get_splits( fn get_splits(
&self, &self,
offset_referential: PyOffsetReferential, offset_referential: PyOffsetReferential,
@ -259,7 +259,7 @@ impl PyPreTokenizedString {
} }
} }
#[pyclass(module = "tokenizers", name=PreTokenizedString)] #[pyclass(module = "tokenizers", name = "PreTokenizedString")]
#[derive(Clone)] #[derive(Clone)]
pub struct PyPreTokenizedStringRefMut { pub struct PyPreTokenizedStringRefMut {
inner: RefMutContainer<PreTokenizedString>, inner: RefMutContainer<PreTokenizedString>,

View File

@ -3,8 +3,8 @@ use pyo3::exceptions;
use pyo3::prelude::*; use pyo3::prelude::*;
/// Instantiate a new Regex with the given pattern /// Instantiate a new Regex with the given pattern
#[pyclass(module = "tokenizers", name=Regex)] #[pyclass(module = "tokenizers", name = "Regex")]
#[text_signature = "(self, pattern)"] #[pyo3(text_signature = "(self, pattern)")]
pub struct PyRegex { pub struct PyRegex {
pub inner: Regex, pub inner: Regex,
pub pattern: String, pub pattern: String,