Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 00:35:35 +00:00)
Python - Add bindings for new AddedToken options
This commit is contained in:
@ -1,6 +1,8 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use super::utils::Container;
|
||||
use crate::tokenizer::AddedToken;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
|
||||
@ -28,7 +30,27 @@ impl BpeTrainer {
|
||||
"vocab_size" => builder = builder.vocab_size(val.extract()?),
|
||||
"min_frequency" => builder = builder.min_frequency(val.extract()?),
|
||||
"show_progress" => builder = builder.show_progress(val.extract()?),
|
||||
"special_tokens" => builder = builder.special_tokens(val.extract()?),
|
||||
"special_tokens" => {
|
||||
builder = builder.special_tokens(
|
||||
val.cast_as::<PyList>()?
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
} else if let Ok(token) = token.cast_as::<AddedToken>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
Err(exceptions::Exception::py_err(
|
||||
"special_tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?,
|
||||
);
|
||||
}
|
||||
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
|
||||
"initial_alphabet" => {
|
||||
let alphabet: Vec<String> = val.extract()?;
|
||||
@ -74,7 +96,27 @@ impl WordPieceTrainer {
|
||||
"vocab_size" => builder = builder.vocab_size(val.extract()?),
|
||||
"min_frequency" => builder = builder.min_frequency(val.extract()?),
|
||||
"show_progress" => builder = builder.show_progress(val.extract()?),
|
||||
"special_tokens" => builder = builder.special_tokens(val.extract()?),
|
||||
"special_tokens" => {
|
||||
builder = builder.special_tokens(
|
||||
val.cast_as::<PyList>()?
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
} else if let Ok(token) = token.cast_as::<AddedToken>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
Err(exceptions::Exception::py_err(
|
||||
"special_tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?,
|
||||
);
|
||||
}
|
||||
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
|
||||
"initial_alphabet" => {
|
||||
let alphabet: Vec<String> = val.extract()?;
|
||||
|
Reference in New Issue
Block a user