pyo3 v0.18 migration (#1173)

* pyo3 v0.18 migration

* Fix black formatting issues
Author: mert-kurttutan
Date: 2023-03-08 11:27:47 +01:00 (committed by GitHub)
Parent: 3138657565
Commit: 5c18ec5ff5
15 changed files with 138 additions and 82 deletions
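
The hunks below apply two mechanical renames from the pyo3 0.18 migration guide: `#[args(...)]` becomes `#[pyo3(signature = (...))]`, and `cast_as` becomes `downcast`. As a minimal sketch of the first rename (the function is illustrative, not part of this commit), the new attribute lists every parameter in call order, with defaults written as plain Rust expressions:

    use pyo3::prelude::*;

    // pyo3 0.17 spelling, deprecated in 0.18:
    //     #[args(pretty = false)]
    // pyo3 0.18 spelling: every parameter appears, defaults are expressions.
    #[pyfunction]
    #[pyo3(signature = (pretty = false))]
    fn to_json_sketch(pretty: bool) -> String {
        if pretty { "{\n}".into() } else { "{}".into() }
    }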

@@ -128,7 +128,7 @@ impl From<tk::AddedToken> for PyAddedToken {
 #[pymethods]
 impl PyAddedToken {
     #[new]
-    #[args(kwargs = "**")]
+    #[pyo3(signature = (content=None, **kwargs))]
     fn __new__(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
         let mut token = PyAddedToken::from(content.unwrap_or(""), None);
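
Under the new attribute, a trailing **kwargs must be declared in the signature alongside the named parameters. A self-contained sketch of the constructor pattern above, assuming pyo3 0.18 (the type name is illustrative):

    use pyo3::prelude::*;
    use pyo3::types::PyDict;

    #[pyclass]
    struct TokenSketch {
        content: String,
    }

    #[pymethods]
    impl TokenSketch {
        #[new]
        #[pyo3(signature = (content = None, **kwargs))]
        fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
            let _ = kwargs; // extra keyword arguments arrive in this dict
            Ok(Self {
                content: content.unwrap_or("").to_string(),
            })
        }
    }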
@@ -308,7 +308,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
                 );
                 let py = ob.py();
                 let obj = PyObject::from_owned_ptr(py, unicode);
-                let s = obj.cast_as::<PyString>(py)?;
+                let s = obj.downcast::<PyString>(py)?;
                 Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
             })
             .collect::<PyResult<Vec<_>>>()?;
@@ -332,7 +332,7 @@ impl FromPyObject<'_> for PyArrayStr {
             .as_array()
             .iter()
             .map(|obj| {
-                let s = obj.cast_as::<PyString>(ob.py())?;
+                let s = obj.downcast::<PyString>(ob.py())?;
                 Ok(s.to_string_lossy().into_owned())
             })
             .collect::<PyResult<Vec<_>>>()?;
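
These two hunks are the second rename: pyo3 0.18 deprecates `cast_as::<T>()` in favor of `downcast::<T>()`. The returned `PyDowncastError` still converts into a `PyErr` through `?`, so the surrounding code is unchanged. A minimal sketch:

    use pyo3::prelude::*;
    use pyo3::types::PyString;

    fn extract_str(py: Python<'_>, obj: PyObject) -> PyResult<String> {
        // pyo3 0.17: let s = obj.cast_as::<PyString>(py)?;
        let s = obj.downcast::<PyString>(py)?;
        Ok(s.to_string_lossy().into_owned())
    }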
@@ -562,7 +562,7 @@ impl PyTokenizer {
     /// Returns:
     ///     :class:`~tokenizers.Tokenizer`: The new tokenizer
     #[staticmethod]
-    #[args(revision = "String::from(\"main\")", auth_token = "None")]
+    #[pyo3(signature = (identifier, revision = String::from("main"), auth_token = None))]
     #[pyo3(text_signature = "(identifier, revision=\"main\", auth_token=None)")]
     fn from_pretrained(
         identifier: &str,
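
Note the change in how defaults are spelled: `#[args]` took them as escaped strings containing Rust expressions, while `signature` takes ordinary expressions, and required parameters such as `identifier` must now be listed too. A hedged sketch of the same shape (names are illustrative):

    use pyo3::prelude::*;

    #[pyfunction]
    #[pyo3(signature = (identifier, revision = String::from("main"), auth_token = None))]
    fn locate_sketch(identifier: &str, revision: String, auth_token: Option<&str>) -> String {
        let _ = auth_token; // would gate authenticated downloads in real code
        format!("{identifier}@{revision}")
    }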
@@ -591,7 +591,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :obj:`str`: A string representing the serialized Tokenizer
-    #[args(pretty = false)]
+    #[pyo3(signature = (pretty = false))]
     #[pyo3(text_signature = "(self, pretty=False)")]
     fn to_str(&self, pretty: bool) -> PyResult<String> {
         ToPyResult(self.tokenizer.to_string(pretty)).into()
@@ -605,7 +605,7 @@ impl PyTokenizer {
     ///
     ///     pretty (:obj:`bool`, defaults to :obj:`True`):
     ///         Whether the JSON file should be pretty formatted.
-    #[args(pretty = true)]
+    #[pyo3(signature = (path, pretty = true))]
     #[pyo3(text_signature = "(self, path, pretty=True)")]
     fn save(&self, path: &str, pretty: bool) -> PyResult<()> {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
@@ -629,7 +629,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :obj:`Dict[str, int]`: The vocabulary
-    #[args(with_added_tokens = true)]
+    #[pyo3(signature = (with_added_tokens = true))]
     #[pyo3(text_signature = "(self, with_added_tokens=True)")]
     fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
         self.tokenizer.get_vocab(with_added_tokens)
@@ -643,7 +643,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :obj:`int`: The size of the vocabulary
-    #[args(with_added_tokens = true)]
+    #[pyo3(signature = (with_added_tokens = true))]
     #[pyo3(text_signature = "(self, with_added_tokens=True)")]
     fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
         self.tokenizer.get_vocab_size(with_added_tokens)
@@ -665,7 +665,7 @@ impl PyTokenizer {
     ///
     ///     direction (:obj:`str`, defaults to :obj:`right`):
     ///         Truncate direction
-    #[args(kwargs = "**")]
+    #[pyo3(signature = (max_length, **kwargs))]
     #[pyo3(
         text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"
     )]
@@ -767,7 +767,7 @@ impl PyTokenizer {
     ///     length (:obj:`int`, `optional`):
     ///         If specified, the length at which to pad. If not specified we pad using the size of
     ///         the longest sequence in a batch.
-    #[args(kwargs = "**")]
+    #[pyo3(signature = (**kwargs))]
     #[pyo3(
         text_signature = "(self, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]', length=None, pad_to_multiple_of=None)"
     )]
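
For catch-all methods like enable_truncation and enable_padding, the opaque `#[args(kwargs = "**")]` marker becomes an explicit signature: any required positional (such as `max_length` above) is listed before `**kwargs`. A sketch of the kwargs-only form, assuming pyo3 0.18:

    use pyo3::prelude::*;
    use pyo3::types::PyDict;

    #[pyfunction]
    #[pyo3(signature = (**kwargs))]
    fn padding_sketch(kwargs: Option<&PyDict>) -> usize {
        // a real implementation would read direction, pad_id, ... from kwargs
        kwargs.map_or(0, |d| d.len())
    }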
@@ -896,7 +896,7 @@ impl PyTokenizer {
     /// Returns:
     ///     :class:`~tokenizers.Encoding`: The encoded result
     ///
-    #[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
+    #[pyo3(signature = (sequence, pair = None, is_pretokenized = false, add_special_tokens = true))]
     #[pyo3(
         text_signature = "(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)"
     )]
@@ -963,7 +963,7 @@ impl PyTokenizer {
     /// Returns:
     ///     A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
     ///
-    #[args(is_pretokenized = "false", add_special_tokens = "true")]
+    #[pyo3(signature = (input, is_pretokenized = false, add_special_tokens = true))]
     #[pyo3(text_signature = "(self, input, is_pretokenized=False, add_special_tokens=True)")]
     fn encode_batch(
         &self,
@@ -1006,7 +1006,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :obj:`str`: The decoded string
-    #[args(skip_special_tokens = true)]
+    #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
         ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
@@ -1023,7 +1023,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :obj:`List[str]`: A list of decoded strings
-    #[args(skip_special_tokens = true)]
+    #[pyo3(signature = (sequences, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, sequences, skip_special_tokens=True)")]
     fn decode_batch(
         &self,
@@ -1144,7 +1144,7 @@ impl PyTokenizer {
     ///
     ///     trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
     ///         An optional trainer that should be used to train our Model
-    #[args(trainer = "None")]
+    #[pyo3(signature = (files, trainer = None))]
     #[pyo3(text_signature = "(self, files, trainer = None)")]
     fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
         let mut trainer =
@@ -1180,7 +1180,7 @@ impl PyTokenizer {
     ///     length (:obj:`int`, `optional`):
     ///         The total number of sequences in the iterator. This is used to
     ///         provide meaningful progress tracking
-    #[args(trainer = "None", length = "None")]
+    #[pyo3(signature = (iterator, trainer = None, length = None))]
     #[pyo3(text_signature = "(self, iterator, trainer=None, length=None)")]
     fn train_from_iterator(
         &mut self,
@@ -1246,7 +1246,7 @@ impl PyTokenizer {
     ///
     /// Returns:
     ///     :class:`~tokenizers.Encoding`: The final post-processed encoding
-    #[args(pair = "None", add_special_tokens = true)]
+    #[pyo3(signature = (encoding, pair = None, add_special_tokens = true))]
     #[pyo3(text_signature = "(self, encoding, pair=None, add_special_tokens=True)")]
     fn post_process(
         &self,
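
As a quick way to confirm that defaults declared in `signature` are applied on the Python side, a hedged test harness (not part of this commit; assumes pyo3 0.18 with the auto-initialize feature):

    use pyo3::prelude::*;

    #[pyfunction]
    #[pyo3(signature = (pretty = false))]
    fn to_json_sketch(pretty: bool) -> String {
        if pretty { "{\n}".into() } else { "{}".into() }
    }

    fn main() -> PyResult<()> {
        Python::with_gil(|py| {
            let module = PyModule::new(py, "sketch")?;
            module.add_function(wrap_pyfunction!(to_json_sketch, module)?)?;
            // no argument passed: the declared default pretty = false applies
            let out: String = module.getattr("to_json_sketch")?.call0()?.extract()?;
            assert_eq!(out, "{}");
            Ok(())
        })
    }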