mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-02 07:19:24 +00:00
Python - expost get_vocab
on Tokenizer
This commit is contained in:
@ -4,6 +4,7 @@ use pyo3::exceptions;
|
|||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::*;
|
use pyo3::types::*;
|
||||||
use pyo3::PyObjectProtocol;
|
use pyo3::PyObjectProtocol;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::decoders::Decoder;
|
use super::decoders::Decoder;
|
||||||
use super::encoding::Encoding;
|
use super::encoding::Encoding;
|
||||||
@ -87,20 +88,13 @@ impl Tokenizer {
|
|||||||
.map_or(0, |p| p.as_ref().added_tokens(is_pair)))
|
.map_or(0, |p| p.as_ref().added_tokens(is_pair)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[args(kwargs = "**")]
|
#[args(with_added_tokens = true)]
|
||||||
fn get_vocab_size(&self, kwargs: Option<&PyDict>) -> PyResult<usize> {
|
fn get_vocab(&self, with_added_tokens: bool) -> PyResult<HashMap<String, u32>> {
|
||||||
let mut with_added_tokens = true;
|
Ok(self.tokenizer.get_vocab(with_added_tokens))
|
||||||
|
|
||||||
if let Some(kwargs) = kwargs {
|
|
||||||
for (key, value) in kwargs {
|
|
||||||
let key: &str = key.extract()?;
|
|
||||||
match key {
|
|
||||||
"with_added_tokens" => with_added_tokens = value.extract()?,
|
|
||||||
_ => println!("Ignored unknown kwarg option {}", key),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[args(with_added_tokens = true)]
|
||||||
|
fn get_vocab_size(&self, with_added_tokens: bool) -> PyResult<usize> {
|
||||||
Ok(self.tokenizer.get_vocab_size(with_added_tokens))
|
Ok(self.tokenizer.get_vocab_size(with_added_tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,12 +258,26 @@ class Tokenizer:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
def get_vocab_size(self, with_added_tokens: Optional[bool]) -> int:
|
def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]:
|
||||||
|
""" Returns the vocabulary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
with_added_tokens: boolean:
|
||||||
|
Whether to include the added tokens in the vocabulary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The vocabulary
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
|
||||||
""" Returns the size of the vocabulary
|
""" Returns the size of the vocabulary
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
with_added_tokens: (`optional`) boolean:
|
with_added_tokens: boolean:
|
||||||
Whether to include the added tokens in the vocabulary's size
|
Whether to include the added tokens in the vocabulary's size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The size of the vocabulary
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
def enable_truncation(self, max_length: int, stride: Optional[int], strategy: Optional[str]):
|
def enable_truncation(self, max_length: int, stride: Optional[int], strategy: Optional[str]):
|
||||||
|
Reference in New Issue
Block a user