mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Python - Extract single pre-tokenized inputs from np.array
This commit is contained in:
73
bindings/python/Cargo.lock
generated
73
bindings/python/Cargo.lock
generated
@ -299,6 +299,14 @@ dependencies = [
|
|||||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "maybe-uninit"
|
name = "maybe-uninit"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
@ -317,6 +325,44 @@ dependencies = [
|
|||||||
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.43"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num_cpus"
|
name = "num_cpus"
|
||||||
version = "1.13.0"
|
version = "1.13.0"
|
||||||
@ -331,6 +377,19 @@ name = "number_prefix"
|
|||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "numpy"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "onig"
|
name = "onig"
|
||||||
version = "6.0.0"
|
version = "6.0.0"
|
||||||
@ -500,6 +559,11 @@ dependencies = [
|
|||||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rawpointer"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon"
|
name = "rayon"
|
||||||
version = "1.3.1"
|
version = "1.3.1"
|
||||||
@ -680,6 +744,8 @@ version = "0.8.1"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)",
|
"libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rayon 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rayon 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)",
|
"serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
@ -788,11 +854,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
"checksum libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)" = "a2f02823cf78b754822df5f7f268fb59822e7296276d3e069d8e8cb26a14bd10"
|
"checksum libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)" = "a2f02823cf78b754822df5f7f268fb59822e7296276d3e069d8e8cb26a14bd10"
|
||||||
"checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c"
|
"checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c"
|
||||||
"checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
|
"checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
|
||||||
|
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
|
||||||
"checksum maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00"
|
"checksum maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00"
|
||||||
"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
|
"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
|
||||||
"checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f"
|
"checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f"
|
||||||
|
"checksum ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09"
|
||||||
|
"checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
|
||||||
|
"checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
|
||||||
|
"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
|
||||||
"checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
"checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
||||||
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
|
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
|
||||||
|
"checksum numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1afa393812b0e4c33d8a46054028ce39d79f954dfa98298f25691abf00b24e39"
|
||||||
"checksum onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bd91ccd8a02fce2f7e8a86655aec67bc6c171e6f8e704118a0e8c4b866a05a8a"
|
"checksum onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bd91ccd8a02fce2f7e8a86655aec67bc6c171e6f8e704118a0e8c4b866a05a8a"
|
||||||
"checksum onig_sys 69.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3814583fad89f3c60ae0701d80e87e1fd3028741723deda72d0d4a0ecf0cb0db"
|
"checksum onig_sys 69.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3814583fad89f3c60ae0701d80e87e1fd3028741723deda72d0d4a0ecf0cb0db"
|
||||||
"checksum parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a4893845fa2ca272e647da5d0e46660a314ead9c2fdd9a883aabc32e481a8733"
|
"checksum parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a4893845fa2ca272e647da5d0e46660a314ead9c2fdd9a883aabc32e481a8733"
|
||||||
@ -812,6 +884,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
"checksum rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
"checksum rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||||
|
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||||
"checksum rayon 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "62f02856753d04e03e26929f820d0a0a337ebe71f849801eea335d464b349080"
|
"checksum rayon 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "62f02856753d04e03e26929f820d0a0a337ebe71f849801eea335d464b349080"
|
||||||
"checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "<none>"
|
"checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "<none>"
|
||||||
"checksum rayon-core 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e92e15d89083484e11353891f1af602cc661426deb9564c298b270c726973280"
|
"checksum rayon-core 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e92e15d89083484e11353891f1af602cc661426deb9564c298b270c726973280"
|
||||||
|
@ -14,6 +14,8 @@ serde = { version = "1.0", features = [ "rc", "derive" ]}
|
|||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
env_logger = "0.7.1"
|
env_logger = "0.7.1"
|
||||||
|
numpy = "0.11"
|
||||||
|
ndarray = "0.13"
|
||||||
|
|
||||||
[dependencies.pyo3]
|
[dependencies.pyo3]
|
||||||
version = "0.11"
|
version = "0.11"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use numpy::PyArray1;
|
||||||
use pyo3::exceptions;
|
use pyo3::exceptions;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::*;
|
use pyo3::types::*;
|
||||||
@ -188,6 +189,90 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct PyArrayUnicode(Vec<String>);
|
||||||
|
impl FromPyObject<'_> for PyArrayUnicode {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let array = ob.downcast::<PyArray1<u8>>()?;
|
||||||
|
let arr = array.as_array_ptr();
|
||||||
|
let (type_num, elsize, alignment, data) = unsafe {
|
||||||
|
let desc = (*arr).descr;
|
||||||
|
(
|
||||||
|
(*desc).type_num,
|
||||||
|
(*desc).elsize as usize,
|
||||||
|
(*desc).alignment as usize,
|
||||||
|
(*arr).data,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let n_elem = array.shape()[0];
|
||||||
|
|
||||||
|
// type_num == 19 => Unicode
|
||||||
|
if type_num != 19 {
|
||||||
|
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem);
|
||||||
|
|
||||||
|
let seq = (0..n_elem)
|
||||||
|
.map(|i| {
|
||||||
|
let bytes = &all_bytes[i * elsize..(i + 1) * elsize];
|
||||||
|
let unicode = pyo3::ffi::PyUnicode_FromUnicode(
|
||||||
|
bytes.as_ptr() as *const _,
|
||||||
|
elsize as isize / alignment as isize,
|
||||||
|
);
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
let obj = PyObject::from_owned_ptr(py, unicode);
|
||||||
|
let s = obj.cast_as::<PyString>(py)?;
|
||||||
|
Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
|
||||||
|
})
|
||||||
|
.collect::<PyResult<Vec<_>>>()?;
|
||||||
|
|
||||||
|
Ok(Self(seq))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<PyArrayUnicode> for tk::InputSequence<'_> {
|
||||||
|
fn from(s: PyArrayUnicode) -> Self {
|
||||||
|
s.0.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PyArrayStr(Vec<String>);
|
||||||
|
impl FromPyObject<'_> for PyArrayStr {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let array = ob.downcast::<PyArray1<u8>>()?;
|
||||||
|
let arr = array.as_array_ptr();
|
||||||
|
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) };
|
||||||
|
let n_elem = array.shape()[0];
|
||||||
|
|
||||||
|
if type_num != 17 {
|
||||||
|
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem);
|
||||||
|
|
||||||
|
let seq = objects
|
||||||
|
.into_iter()
|
||||||
|
.map(|obj| {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
let s = obj.cast_as::<PyString>(py)?;
|
||||||
|
Ok(s.to_string()?.into_owned())
|
||||||
|
})
|
||||||
|
.collect::<PyResult<Vec<_>>>()?;
|
||||||
|
|
||||||
|
Ok(Self(seq))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<PyArrayStr> for tk::InputSequence<'_> {
|
||||||
|
fn from(s: PyArrayStr) -> Self {
|
||||||
|
s.0.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
||||||
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
|
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
|
||||||
fn extract(ob: &'s PyAny) -> PyResult<Self> {
|
fn extract(ob: &'s PyAny) -> PyResult<Self> {
|
||||||
@ -195,11 +280,15 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
|
|||||||
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
|
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Ok(s) = ob.downcast::<PyList>() {
|
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
|
||||||
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
Ok(Self(seq.into()))
|
||||||
|
} else if let Ok(seq) = ob.extract::<PyArrayStr>() {
|
||||||
|
Ok(Self(seq.into()))
|
||||||
|
} else if let Ok(s) = ob.downcast::<PyList>() {
|
||||||
|
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
|
||||||
Ok(Self(seq.into()))
|
Ok(Self(seq.into()))
|
||||||
} else if let Ok(s) = ob.downcast::<PyTuple>() {
|
} else if let Ok(s) = ob.downcast::<PyTuple>() {
|
||||||
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
|
||||||
Ok(Self(seq.into()))
|
Ok(Self(seq.into()))
|
||||||
} else {
|
} else {
|
||||||
Err(err)
|
Err(err)
|
||||||
|
Reference in New Issue
Block a user