pyo3 v0.18 migration (#1173)

* pyo v0.18 migration

* Fix formatting issues of black
This commit is contained in:
mert-kurttutan
2023-03-08 11:27:47 +01:00
committed by GitHub
parent 3138657565
commit 5c18ec5ff5
15 changed files with 138 additions and 82 deletions

View File

@ -283,7 +283,7 @@ impl PyBpeTrainer {
}
#[new]
#[args(kwargs = "**")]
#[pyo3(signature = (**kwargs))]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::bpe::BpeTrainer::builder();
if let Some(kwargs) = kwargs {
@ -295,7 +295,7 @@ impl PyBpeTrainer {
"show_progress" => builder = builder.show_progress(val.extract()?),
"special_tokens" => {
builder = builder.special_tokens(
val.cast_as::<PyList>()?
val.downcast::<PyList>()?
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
@ -489,7 +489,7 @@ impl PyWordPieceTrainer {
}
#[new]
#[args(kwargs = "**")]
#[pyo3(signature = (** kwargs))]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
if let Some(kwargs) = kwargs {
@ -501,7 +501,7 @@ impl PyWordPieceTrainer {
"show_progress" => builder = builder.show_progress(val.extract()?),
"special_tokens" => {
builder = builder.special_tokens(
val.cast_as::<PyList>()?
val.downcast::<PyList>()?
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
@ -629,7 +629,7 @@ impl PyWordLevelTrainer {
}
#[new]
#[args(kwargs = "**")]
#[pyo3(signature = (**kwargs))]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
@ -648,7 +648,7 @@ impl PyWordLevelTrainer {
}
"special_tokens" => {
builder.special_tokens(
val.cast_as::<PyList>()?
val.downcast::<PyList>()?
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
@ -797,7 +797,7 @@ impl PyUnigramTrainer {
}
#[new]
#[args(kwargs = "**")]
#[pyo3(signature = (**kwargs))]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::unigram::UnigramTrainer::builder();
if let Some(kwargs) = kwargs {
@ -821,7 +821,7 @@ impl PyUnigramTrainer {
)
}
"special_tokens" => builder.special_tokens(
val.cast_as::<PyList>()?
val.downcast::<PyList>()?
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {