Making the regex in ByteLevel optional. (#939)
* Making the regex in ByteLevel optional.
* Changed the stub.
* Better stub.
* Typo fix.
* Remove bad comments.
Node bindings (Neon), JS constructor for the ByteLevel pre-tokenizer:

@@ -125,12 +125,15 @@ declare_types! {
         }
     }

-    /// byte_level(addPrefixSpace: bool = true)
+    /// byte_level(addPrefixSpace: bool = true, useRegex: bool = true)
     fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
         if let Some(add_prefix_space) = cx.extract_opt::<bool>(0)? {
             byte_level = byte_level.add_prefix_space(add_prefix_space);
         }
+        if let Some(use_regex) = cx.extract_opt::<bool>(1)? {
+            byte_level = byte_level.use_regex(use_regex);
+        }

         let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
         let guard = cx.lock();
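The binding reads the new option with `cx.extract_opt::<bool>(1)`, i.e. as an optional second positional argument, mirroring how `addPrefixSpace` is read at index 0; both simply forward to the core builder. A minimal sketch of the pure-Rust equivalent of a JS call like `byteLevel(true, false)` (assuming the `tokenizers` crate; `build_byte_level` is a hypothetical helper, not part of the patch):

    use tokenizers::pre_tokenizers::byte_level::ByteLevel;

    // Hypothetical helper showing what the binding constructs.
    fn build_byte_level() -> ByteLevel {
        ByteLevel::default()
            .add_prefix_space(true) // JS argument 0
            .use_regex(false)       // JS argument 1, introduced by this change
    }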
Python type stub, signature updated to match:

@@ -102,7 +102,7 @@ class ByteLevel(PreTokenizer):
     lets us treat `hello` exactly like `say hello`.
     """

-    def __init__(self, add_prefix_space=True):
+    def __init__(self, add_prefix_space=True, use_regex=True):
         pass
     @staticmethod
     def alphabet():
Python bindings (pyo3):

@@ -229,7 +229,7 @@ macro_rules! setter {
 /// Whether to add a space to the first word if there isn't already one. This
 /// lets us treat `hello` exactly like `say hello`.
 #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
-#[text_signature = "(self, add_prefix_space=True)"]
+#[text_signature = "(self, add_prefix_space=True, use_regex=True)"]
 pub struct PyByteLevel {}
 #[pymethods]
 impl PyByteLevel {
@@ -243,13 +243,28 @@ impl PyByteLevel {
         setter!(self_, ByteLevel, add_prefix_space, add_prefix_space);
     }

+    #[getter]
+    fn get_use_regex(self_: PyRef<Self>) -> bool {
+        getter!(self_, ByteLevel, use_regex)
+    }
+
+    #[setter]
+    fn set_use_regex(self_: PyRef<Self>, use_regex: bool) {
+        setter!(self_, ByteLevel, use_regex, use_regex);
+    }
+
     #[new]
-    #[args(add_prefix_space = "true", _kwargs = "**")]
-    fn new(add_prefix_space: bool, _kwargs: Option<&PyDict>) -> (Self, PyPreTokenizer) {
+    #[args(add_prefix_space = "true", use_regex = "true", _kwargs = "**")]
+    fn new(
+        add_prefix_space: bool,
+        use_regex: bool,
+        _kwargs: Option<&PyDict>,
+    ) -> (Self, PyPreTokenizer) {
         (
             PyByteLevel {},
             ByteLevel::default()
                 .add_prefix_space(add_prefix_space)
+                .use_regex(use_regex)
                 .into(),
         )
     }
Core library, the ByteLevel pre-tokenizer itself:

@@ -53,6 +53,15 @@ pub struct ByteLevel {
     pub add_prefix_space: bool,
     /// Whether the post processing step should trim offsets to avoid including whitespaces.
     pub trim_offsets: bool,
+
+    /// Whether to use the standard GPT2 regex for whitespace splitting
+    /// Set it to False if you want to use your own splitting.
+    #[serde(default = "default_true")]
+    pub use_regex: bool,
+}
+
+fn default_true() -> bool {
+    true
 }

 impl Default for ByteLevel {
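The `#[serde(default = "default_true")]` attribute is what makes the new field backward compatible: pre-tokenizers serialized before this change carry no `use_regex` key, and serde substitutes `default_true()` instead of failing (the tests at the bottom of this diff exercise exactly that). A standalone sketch of the same pattern, with a hypothetical `Config` struct and assuming `serde`/`serde_json` as dependencies:

    use serde::Deserialize;

    fn default_true() -> bool {
        true
    }

    #[derive(Debug, Deserialize)]
    struct Config {
        name: String,
        // Absent in old payloads; serde calls `default_true` instead of erroring.
        #[serde(default = "default_true")]
        use_regex: bool,
    }

    fn main() {
        // An old-style payload without the field still deserializes.
        let old: Config = serde_json::from_str(r#"{"name": "ByteLevel"}"#).unwrap();
        assert!(old.use_regex);

        // New payloads can opt out explicitly.
        let new: Config =
            serde_json::from_str(r#"{"name": "ByteLevel", "use_regex": false}"#).unwrap();
        assert!(!new.use_regex);
    }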
@@ -60,15 +69,17 @@ impl Default for ByteLevel {
         Self {
             add_prefix_space: true,
             trim_offsets: true,
+            use_regex: true,
         }
     }
 }

 impl ByteLevel {
-    pub fn new(add_prefix_space: bool, trim_offsets: bool) -> Self {
+    pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self {
         Self {
             add_prefix_space,
             trim_offsets,
+            use_regex,
         }
     }

@@ -87,6 +98,12 @@ impl ByteLevel {
         self.trim_offsets = v;
         self
     }
+
+    #[must_use]
+    pub fn use_regex(mut self, v: bool) -> Self {
+        self.use_regex = v;
+        self
+    }
 }

 /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
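Like `add_prefix_space` and `trim_offsets` before it, `use_regex` is a consuming builder: it takes `mut self` and returns the modified value, so dropping the return value silently discards the setting. `#[must_use]` turns that mistake into a compiler warning. A sketch (assuming the `tokenizers` crate):

    use tokenizers::pre_tokenizers::byte_level::ByteLevel;

    fn main() {
        // The setting is lost here; #[must_use] makes the compiler warn
        // about the discarded return value instead of staying silent.
        ByteLevel::default().use_regex(false);

        // Correct: keep the value returned by the chain.
        let byte_level = ByteLevel::default().use_regex(false);
        let _ = byte_level;
    }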
@@ -99,7 +116,11 @@ impl PreTokenizer for ByteLevel {
             if self.add_prefix_space && !normalized.get().starts_with(' ') {
                 normalized.prepend(" ");
             }
-            normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
+            if self.use_regex {
+                normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
+            } else {
+                Ok(vec![normalized])
+            }
         })?;
         pretokenized.normalize(|normalized| {
             let s = normalized.get();
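The behavioral change is confined to this split step: with `use_regex` left at `true`, input is split on the library's GPT-2-style pattern as before; with `false`, the (optionally prefix-spaced) input passes through as a single piece, and only the byte-to-unicode mapping in the following `normalize` closure applies. A sketch of the contrast, reusing the sentence from the new test further down (import paths are assumed from the crate's re-exports and may differ between versions):

    use tokenizers::pre_tokenizers::byte_level::ByteLevel;
    use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

    fn main() {
        let input = "Hello my friend, how is your day going?";

        // Default path: one split per word / punctuation run.
        let mut with_regex: PreTokenizedString = input.into();
        ByteLevel::default().pre_tokenize(&mut with_regex).unwrap();

        // Opt-out path: no splitting, one byte-level-mapped piece.
        let mut without_regex: PreTokenizedString = input.into();
        ByteLevel::default()
            .use_regex(false)
            .pre_tokenize(&mut without_regex)
            .unwrap();

        let count = |p: &PreTokenizedString| {
            p.get_splits(OffsetReferential::Original, OffsetType::Byte).len()
        };
        assert!(count(&with_regex) > 1);
        assert_eq!(count(&without_regex), 1);
    }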
@@ -247,6 +268,21 @@ mod tests {
         );
     }

+    #[test]
+    fn pre_tokenization_no_regex() {
+        let bytelevel = ByteLevel::default().use_regex(false);
+        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
+        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
+        );
+    }
+
     #[test]
     fn decoding() {
         let bytelevel = ByteLevel::default().add_prefix_space(false);
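The `Ġ` characters in the expected split are not in the input: ByteLevel remaps every raw byte to a printable unicode character, and the space byte 0x20 lands on `Ġ` (U+0120), which is why the unsplit piece still differs from the original sentence. A sketch of that mapping scheme as GPT-2's byte encoder defines it (ByteLevel's alphabet follows the same construction):

    // Printable bytes map to themselves; every other byte gets a codepoint
    // from U+0100 upward, assigned in increasing byte order.
    fn bytes_to_unicode() -> Vec<char> {
        let mut table = Vec::with_capacity(256);
        let mut shift = 0u32;
        for b in 0u32..256 {
            let printable = (0x21..=0x7E).contains(&b)
                || (0xA1..=0xAC).contains(&b)
                || (0xAE..=0xFF).contains(&b);
            if printable {
                table.push(char::from_u32(b).unwrap());
            } else {
                table.push(char::from_u32(0x100 + shift).unwrap());
                shift += 1;
            }
        }
        table
    }

    fn main() {
        // Bytes 0x00..=0x1F are the first 32 non-printables, so the space
        // byte 0x20 is the 33rd and maps to U+0100 + 32 = U+0120 ('Ġ').
        assert_eq!(bytes_to_unicode()[b' ' as usize], 'Ġ');
    }

This mapping runs in the `normalize` step regardless of `use_regex`, which is why disabling the regex changes only the splitting, not the byte-level alphabet.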
@@ -513,4 +549,27 @@ mod tests {
             vec!["Hello there dear friend! [PA D]"]
         );
     }
+
+    #[test]
+    fn deserialization() {
+        // Before use_regex
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
+        )
+        .unwrap();
+        assert!(byte_level.use_regex);
+
+        // Loading works, new future BC test.
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
+        )
+        .unwrap();
+        assert!(byte_level.use_regex);
+
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
+        )
+        .unwrap();
+        assert!(!byte_level.use_regex);
+    }
 }