Python - expose get_vocab on Tokenizer

This commit is contained in:
Anthony MOI
2020-03-27 11:53:18 -04:00
parent e191008751
commit a2a6d80017
2 changed files with 23 additions and 15 deletions

View File

@ -4,6 +4,7 @@ use pyo3::exceptions;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::*; use pyo3::types::*;
use pyo3::PyObjectProtocol; use pyo3::PyObjectProtocol;
use std::collections::HashMap;
use super::decoders::Decoder; use super::decoders::Decoder;
use super::encoding::Encoding; use super::encoding::Encoding;
@ -87,20 +88,13 @@ impl Tokenizer {
.map_or(0, |p| p.as_ref().added_tokens(is_pair))) .map_or(0, |p| p.as_ref().added_tokens(is_pair)))
} }
#[args(kwargs = "**")] #[args(with_added_tokens = true)]
fn get_vocab_size(&self, kwargs: Option<&PyDict>) -> PyResult<usize> { fn get_vocab(&self, with_added_tokens: bool) -> PyResult<HashMap<String, u32>> {
let mut with_added_tokens = true; Ok(self.tokenizer.get_vocab(with_added_tokens))
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"with_added_tokens" => with_added_tokens = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
} }
#[args(with_added_tokens = true)]
fn get_vocab_size(&self, with_added_tokens: bool) -> PyResult<usize> {
Ok(self.tokenizer.get_vocab_size(with_added_tokens)) Ok(self.tokenizer.get_vocab_size(with_added_tokens))
} }

View File

@ -258,12 +258,26 @@ class Tokenizer:
:return: :return:
""" """
pass pass
def get_vocab_size(self, with_added_tokens: Optional[bool]) -> int: def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]:
""" Returns the vocabulary
Args:
with_added_tokens: boolean:
Whether to include the added tokens in the vocabulary
Returns:
The vocabulary
"""
pass
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
""" Returns the size of the vocabulary """ Returns the size of the vocabulary
Args: Args:
with_added_tokens: (`optional`) boolean: with_added_tokens: boolean:
Whether to include the added tokens in the vocabulary's size Whether to include the added tokens in the vocabulary's size
Returns:
The size of the vocabulary
""" """
pass pass
def enable_truncation(self, max_length: int, stride: Optional[int], strategy: Optional[str]): def enable_truncation(self, max_length: int, stride: Optional[int], strategy: Optional[str]):