Python - Add WordPiece model

This commit is contained in:
Anthony MOI
2019-12-09 12:49:44 -05:00
parent 5eba30835d
commit d60d24a378

View File

@ -5,8 +5,8 @@ use super::utils::Container;
use pyo3::exceptions;
use pyo3::prelude::*;
/// Represents any Model to be used with a Tokenizer
/// This class is to be constructed from specific models
/// A Model represents some tokenization algorithm like BPE or Word
/// This class cannot be constructed directly. Please use one of the concrete models.
#[pyclass]
pub struct Model {
pub model: Container<dyn tk::tokenizer::Model + Sync>,
@ -15,9 +15,9 @@ pub struct Model {
#[pymethods]
impl Model {
#[new]
fn new(_obj: &PyRawObject) -> PyResult<Self> {
fn new(obj: &PyRawObject) -> PyResult<()> {
Err(exceptions::Exception::py_err(
"Cannot create a Model directly",
"Cannot create a Model directly. Use a concrete subclass",
))
}
}
@ -29,6 +29,10 @@ pub struct BPE {}
#[pymethods]
impl BPE {
/// from_files(vocab, merges, /)
/// --
///
/// Instantiate a new BPE model using the provided vocab and merges files
#[staticmethod]
fn from_files(vocab: &str, merges: &str) -> PyResult<Model> {
match tk::models::bpe::BPE::from_files(vocab, merges) {
@ -44,6 +48,10 @@ impl BPE {
}
}
/// empty()
/// --
///
/// Instantiate a new BPE model with empty vocab and merges
#[staticmethod]
fn empty() -> Model {
Model {
@ -51,3 +59,37 @@ impl BPE {
}
}
}
/// WordPiece Model
#[pyclass]
pub struct WordPiece {}

#[pymethods]
impl WordPiece {
    /// from_files(vocab, /)
    /// --
    ///
    /// Instantiate a new WordPiece model using the provided vocabulary file
    ///
    /// Raises an Exception describing the underlying failure if the model
    /// cannot be initialized from the given file.
    #[staticmethod]
    fn from_files(vocab: &str) -> PyResult<Model> {
        // TODO: Parse kwargs for these
        let unk_token = String::from("[UNK]");
        let max_input_chars_per_word = Some(100);

        match tk::models::wordpiece::WordPiece::from_files(
            vocab,
            unk_token,
            max_input_chars_per_word,
        ) {
            Ok(wordpiece) => Ok(Model {
                model: Container::Owned(Box::new(wordpiece)),
            }),
            // Carry the failure details in the raised exception instead of printing
            // them to stdout, so the Python caller can actually see (and handle) them.
            Err(e) => Err(exceptions::Exception::py_err(format!(
                "Error while initializing WordPiece: {:?}",
                e
            ))),
        }
    }
}