mirror of https://github.com/mii443/tokenizers.git

Python - Expose PostProcessors
bindings/python/Cargo.lock
generated
1
bindings/python/Cargo.lock
generated
@@ -462,6 +462,7 @@ dependencies = [
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
bindings/python/src/error.rs
@@ -9,6 +9,9 @@ impl PyError {
     pub fn from(s: &str) -> Self {
         PyError(String::from(s))
     }
+    pub fn into_pyerr(self) -> PyErr {
+        exceptions::Exception::py_err(format!("{}", self))
+    }
 }
 impl Display for PyError {
     fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
bindings/python/src/lib.rs
@@ -3,6 +3,7 @@ mod encoding;
 mod error;
 mod models;
 mod pre_tokenizers;
+mod processors;
 mod token;
 mod tokenizer;
 mod trainers;
@@ -46,6 +47,14 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     Ok(())
 }

+/// Processors Module
+#[pymodule]
+fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<processors::PostProcessor>()?;
+    m.add_class::<processors::BertProcessing>()?;
+    Ok(())
+}
+
 /// Tokenizers Module
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -53,6 +62,7 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pymodule!(models))?;
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
     m.add_wrapped(wrap_pymodule!(decoders))?;
+    m.add_wrapped(wrap_pymodule!(processors))?;
     m.add_wrapped(wrap_pymodule!(trainers))?;
     Ok(())
 }
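Once the submodule is wrapped into the top-level extension, the new classes become reachable from Python. A minimal sketch of what the registration above exposes (class names come from the bindings in this diff; the session itself is illustrative):

import tokenizers

# Each #[pymodule] function registered via wrap_pymodule! shows up as an
# attribute of the top-level extension module.
print(tokenizers.processors.PostProcessor)
print(tokenizers.processors.BertProcessing)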
bindings/python/src/processors.rs (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use super::error::PyError;
|
||||
use super::utils::Container;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use tk::utils::TruncationStrategy;
|
||||
|
||||
#[pyclass(dict)]
|
||||
pub struct PostProcessor {
|
||||
pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct BertProcessing {}
|
||||
#[pymethods]
|
||||
impl BertProcessing {
|
||||
#[staticmethod]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(
|
||||
sep: (String, u32),
|
||||
cls: (String, u32),
|
||||
kwargs: Option<&PyDict>,
|
||||
) -> PyResult<PostProcessor> {
|
||||
let mut max_len = 512;
|
||||
let mut trunc_strategy = tk::utils::TruncationStrategy::LongestFirst;
|
||||
let mut trunc_stride = 0;
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
let key: &str = key.extract()?;
|
||||
match key {
|
||||
"max_len" => max_len = val.extract()?,
|
||||
"trunc_stride" => trunc_stride = val.extract()?,
|
||||
"trunc_strategy" => {
|
||||
let strategy: &str = val.extract()?;
|
||||
trunc_strategy = match strategy {
|
||||
"longest_first" => Ok(TruncationStrategy::LongestFirst),
|
||||
"only_first" => Ok(TruncationStrategy::OnlyFirst),
|
||||
"only_second" => Ok(TruncationStrategy::OnlySecond),
|
||||
other => Err(PyError(format!(
|
||||
"Unknown `trunc_strategy`: `{}`. Use \
|
||||
one of `longest_first`, `only_first` or `only_second`",
|
||||
other
|
||||
))
|
||||
.into_pyerr()),
|
||||
}?;
|
||||
}
|
||||
_ => println!("Ignored unknown kwargs option {}", key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PostProcessor {
|
||||
processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
|
||||
max_len,
|
||||
trunc_strategy,
|
||||
trunc_stride,
|
||||
sep,
|
||||
cls,
|
||||
))),
|
||||
})
|
||||
}
|
||||
}
|
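From Python, the optional settings arrive through keyword arguments. A hedged sketch of a call (the tuples are (token, id) pairs; the ids 101/102 are the conventional BERT vocabulary values and are an assumption here, not part of this commit):

from tokenizers import processors

post_processor = processors.BertProcessing.new(
    ("[SEP]", 102),
    ("[CLS]", 101),
    max_len=128,                      # defaults to 512 when omitted
    trunc_strategy="longest_first",   # or "only_first" / "only_second"
    trunc_stride=0,
)

Passing any other keyword is not an error: the loop above simply prints that the option was ignored.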
bindings/python/src/tokenizer.rs
@@ -9,6 +9,7 @@ use super::encoding::Encoding;
 use super::error::ToPyResult;
 use super::models::Model;
 use super::pre_tokenizers::PreTokenizer;
+use super::processors::PostProcessor;
 use super::trainers::Trainer;

 #[pyclass(dict)]
@@ -69,6 +70,17 @@ impl Tokenizer {
         }
     }

+    fn with_post_processor(&mut self, processor: &mut PostProcessor) -> PyResult<()> {
+        if let Some(processor) = processor.processor.to_pointer() {
+            self.tokenizer.with_post_processor(processor);
+            Ok(())
+        } else {
+            Err(exceptions::Exception::py_err(
+                "The Processor is already being used in another Tokenizer",
+            ))
+        }
+    }
+
     fn encode(&self, sentence: &str, pair: Option<&str>) -> PyResult<Encoding> {
         ToPyResult(
             self.tokenizer
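On the Python side the new method hangs off Tokenizer; a minimal sketch of wiring it up (the models.BPE.empty() setup is an assumption for illustration, not part of this diff):

from tokenizers import Tokenizer, models, processors

tokenizer = Tokenizer(models.BPE.empty())
tokenizer.with_post_processor(
    processors.BertProcessing.new(("[SEP]", 102), ("[CLS]", 101))
)

# The Container is moved out of the PostProcessor wrapper on attach, so
# handing the same PostProcessor to a second Tokenizer raises the
# "already being used in another Tokenizer" exception above.
encoding = tokenizer.encode("Hello world!", "How are you?")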
bindings/python/tokenizers/__init__.py
@@ -1,3 +1,3 @@
 __version__ = "0.0.2"

-from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers
+from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors
tokenizers/src/processors/bert.rs
@@ -1,7 +1,7 @@
 use crate::tokenizer::{Encoding, PostProcessor, Result};
 use crate::utils::{truncate_encodings, TruncationStrategy};

-struct BertProcessing {
+pub struct BertProcessing {
     max_len: usize,
     trunc_strategy: TruncationStrategy,
     trunc_stride: usize,
@@ -10,6 +10,24 @@ struct BertProcessing {
     cls: (String, u32),
 }

+impl BertProcessing {
+    pub fn new(
+        max_len: usize,
+        trunc_strategy: TruncationStrategy,
+        trunc_stride: usize,
+        sep: (String, u32),
+        cls: (String, u32),
+    ) -> Self {
+        BertProcessing {
+            max_len,
+            trunc_strategy,
+            trunc_stride,
+            sep,
+            cls,
+        }
+    }
+}
+
 impl PostProcessor for BertProcessing {
     fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
         let special_token_len = if pair_encoding.is_some() { 3 } else { 2 };
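The 2-vs-3 count is the number of special tokens the BERT layout inserts, which the processor reserves out of max_len before truncating. A plain-Python sketch of the layout behind that arithmetic (illustrative only, not the Rust implementation):

def bert_layout(a, b=None):
    # Single sequence: [CLS] A [SEP]         -> 2 special tokens
    # Sequence pair:   [CLS] A [SEP] B [SEP] -> 3 special tokens
    tokens = ["[CLS]"] + a + ["[SEP]"]
    if b is not None:
        tokens += b + ["[SEP]"]
    return tokens

assert bert_layout(["hello"]) == ["[CLS]", "hello", "[SEP]"]
assert bert_layout(["a"], ["b"]) == ["[CLS]", "a", "[SEP]", "b", "[SEP]"]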