Merge pull request #197 from huggingface/remove-normalized

Remove NormalizedString from Encoding
Anthony MOI
2020-03-18 16:52:21 -04:00
committed by GitHub
41 changed files with 1043 additions and 776 deletions

View File

@ -60,7 +60,22 @@ jobs:
args: --manifest-path ./tokenizers/Cargo.toml --all-targets --all-features -- -D warnings
- name: Run Tests
if: matrix.os != 'windows-latest'
shell: bash
working-directory: ./tokenizers
run: make test
# Skip integration tests for now on Windows
- name: Run lib Tests on Windows
if: matrix.os == 'windows-latest'
uses: actions-rs/cargo@v1
with:
command: test
args: --verbose --manifest-path ./tokenizers/Cargo.toml
args: --verbose --manifest-path ./tokenizers/Cargo.toml --lib
- name: Run doc Tests on Windows
if: matrix.os == 'windows-latest'
uses: actions-rs/cargo@v1
with:
command: test
args: --verbose --manifest-path ./tokenizers/Cargo.toml --doc

.gitignore vendored
View File

@ -7,6 +7,7 @@ target
Cargo.lock
/data
tokenizers/data
/docs
__pycache__
@ -16,4 +17,4 @@ pip-wheel-metadata
/bindings/python/build
/bindings/python/dist
.vscode
.vscode

View File

@ -44,17 +44,6 @@ export interface RawEncoding {
*/
getTypeIds(): number[];
/**
* Returns the original string
*
* @param [begin] The index from which to start (can be negative).
* @param [end] The index (excluded) to which to stop (can be negative).
* Stopping at the end of the string if not provided.
* @returns The full original string if no parameter is provided,
* otherwise the original string between `begin` and `end`
*/
getOriginalString(begin?: number, end?: number): string;
/**
* Pad the current Encoding at the given length
*

View File

@ -22,7 +22,6 @@ describe("RawEncoding", () => {
expect(typeof encoding.getIds).toBe("function");
expect(typeof encoding.getLength).toBe("function");
expect(typeof encoding.getOffsets).toBe("function");
expect(typeof encoding.getOriginalString).toBe("function");
expect(typeof encoding.getOverflowing).toBe("function");
expect(typeof encoding.getSpecialTokensMask).toBe("function");
expect(typeof encoding.getTokens).toBe("function");
@ -31,109 +30,6 @@ describe("RawEncoding", () => {
expect(typeof encoding.truncate).toBe("function");
});
describe("getOriginalString", () => {
it("returns the full original string when no params", () => {
const original = encoding.getOriginalString();
expect(original).toEqual(originalString);
});
it("accepts `undefined` as first parameter", () => {
const original = encoding.getOriginalString(undefined);
expect(original).toEqual(originalString);
});
it("accepts `undefined` as second parameter", () => {
const original = encoding.getOriginalString(0, undefined);
expect(original).toEqual(originalString);
});
it("throws an error when `begin` is out of range", () => {
expect(() => encoding.getOriginalString(1000)).toThrow();
});
it("returns the original string starting at the specified index", () => {
const original = encoding.getOriginalString(3);
expect(original).toEqual("name is john");
});
it("throws an error when `end` is out of range", () => {
expect(() => encoding.getOriginalString(0, 1000)).toThrow();
});
it("returns the original string between the two specified indexes", () => {
const original = encoding.getOriginalString(3, 7);
expect(original).toEqual("name");
});
describe("with only a negative `begin`", () => {
it("returns the original string counting from the end when in the range", () => {
const original = encoding.getOriginalString(-4);
expect(original).toEqual("john");
});
it("throws an error when out of range", () => {
expect(() => encoding.getOriginalString(-1000)).toThrow();
});
});
describe("with a positive `begin` and a negative `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(3, -5);
expect(original).toEqual("name is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => encoding.getOriginalString(7, -10)).toThrow();
});
it("throws an error when `begin` is out of range", () => {
expect(() => encoding.getOriginalString(1000, -10)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => encoding.getOriginalString(7, -1000)).toThrow();
});
});
describe("with a negative `begin` and a positive `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(-7, 10);
expect(original).toEqual("is");
});
it("throws an error when resulting `begin` index is upper than `end`", () => {
expect(() => encoding.getOriginalString(-3, 5)).toThrow();
});
it("throws an error when `end` is out of range", () => {
expect(() => encoding.getOriginalString(-5, 1000)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => encoding.getOriginalString(-1000, 10)).toThrow();
});
});
describe("with negatives `begin` and `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(-7, -5);
expect(original).toEqual("is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => encoding.getOriginalString(-5, -10)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => encoding.getOriginalString(-1000, -10)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => encoding.getOriginalString(-10, -1000)).toThrow();
});
});
});
describe("truncate", () => {
it("accepts `undefined` as second parameter", () => {
expect(encoding.truncate(10, undefined)).toBeUndefined();

View File

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-empty-function */
import { promisify } from "util";
@ -112,7 +113,7 @@ describe("Tokenizer", () => {
[2, 6],
[6, 8],
[8, 12],
[12, 16]
[0, 4]
]);
expect(encoding.getOverflowing()).toEqual([]);
expect(encoding.getSpecialTokensMask()).toEqual([0, 0, 0, 0, 0]);

bindings/node/lib/bindings/utils.d.ts vendored Normal file
View File

@ -0,0 +1,11 @@
/**
* Returns a subpart of a string according to specified indexes, and respecting unicode characters
*
* @param text The text for which to return a subpart
* @param [start] The index from which to start (can be negative).
* @param [end] The index (excluded) at which to stop (can be negative).
* Stops at the end of the string if not provided.
* @returns The full string if no start/end indexes are provided,
* otherwise the part of `text` between `start` (included) and `end` (excluded)
*/
export function slice(text: string, start?: number, end?: number): string;

View File

@ -0,0 +1,5 @@
const native = require("./native");
module.exports = {
slice: native.utils_slice
};

View File

@ -0,0 +1,107 @@
import { slice } from "./utils";
describe("slice", () => {
const text = "My name is John 👋";
const sliceText = slice.bind({}, text);
it("returns the full text when no params", () => {
const sliced = sliceText();
expect(sliced).toEqual(text);
});
it("accepts `undefined` as second parameter", () => {
const original = sliceText(undefined);
expect(original).toEqual(text);
});
it("accepts `undefined` as third parameter", () => {
const original = sliceText(0, undefined);
expect(original).toEqual(text);
});
it("throws an error when `begin` is out of range", () => {
expect(() => sliceText(1000)).toThrow();
});
it("returns slice starting at the specified index", () => {
const original = sliceText(3);
expect(original).toEqual("name is John 👋");
});
it("throws an error when `end` is out of range", () => {
expect(() => sliceText(0, 1000)).toThrow();
});
it("returns the text between the two specified indexes", () => {
const original = sliceText(3, 7);
expect(original).toEqual("name");
});
describe("with only a negative `begin`", () => {
it("returns the original string counting from the end when in the range", () => {
const original = sliceText(-1);
expect(original).toEqual("👋");
});
it("throws an error when out of range", () => {
expect(() => sliceText(-1000)).toThrow();
});
});
describe("with a positive `begin` and a negative `end`", () => {
it("returns correct slice when resulting range is valid", () => {
const original = sliceText(3, -7);
expect(original).toEqual("name is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => sliceText(7, -12)).toThrow();
});
it("throws an error when `begin` is out of range", () => {
expect(() => sliceText(1000, -12)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => sliceText(7, -1000)).toThrow();
});
});
describe("with a negative `begin` and a positive `end`", () => {
it("returns correct slice when resulting range is valid", () => {
const original = sliceText(-9, 10);
expect(original).toEqual("is");
});
it("throws an error when resulting `begin` index is upper than `end`", () => {
expect(() => sliceText(-3, 5)).toThrow();
});
it("throws an error when `end` is out of range", () => {
expect(() => sliceText(-5, 1000)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => sliceText(-1000, 10)).toThrow();
});
});
describe("with negatives `begin` and `end`", () => {
it("returns correct slice when resulting range is valid", () => {
const original = sliceText(-9, -7);
expect(original).toEqual("is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => sliceText(-5, -10)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => sliceText(-1000, -10)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => sliceText(-10, -1000)).toThrow();
});
});
});

View File

@ -5,7 +5,6 @@ export class Encoding {
private _ids?: number[];
private _length?: number;
private _offsets?: [number, number][];
private _originalString?: string;
private _overflowing?: Encoding[];
private _specialTokensMask?: number[];
private _tokens?: string[];
@ -103,27 +102,6 @@ export class Encoding {
return (this._typeIds = this.rawEncoding.getTypeIds());
}
/**
* Returns the original string
*
* @param [begin] The index from which to start (can be negative).
* @param [end] The index (excluded) to which to stop (can be negative).
* Stopping at the end of the string if not provided.
* @returns The full original string if no parameter is provided,
* otherwise the original string between `begin` and `end`
*/
getOriginalString(begin?: number, end?: number): string {
if (begin === undefined && end === undefined) {
if (this._originalString !== undefined) {
return this._originalString;
} else {
return (this._originalString = this.rawEncoding.getOriginalString());
}
}
return this.rawEncoding.getOriginalString(begin, end);
}
/**
* Pad the current Encoding at the given length
*
@ -153,7 +131,6 @@ export class Encoding {
"_ids",
"_length",
"_offsets",
"_originalString",
"_overflowing",
"_specialTokensMask",
"_tokens",

View File

@ -1,5 +1,6 @@
// export * from "./bindings";
export * from "./implementations/tokenizers";
export * from "./bindings/enums";
export { slice } from "./bindings/utils";
export { PaddingOptions, TruncationOptions } from "./bindings/tokenizer";
export { Encoding } from "./implementations/encoding";

View File

@ -0,0 +1,102 @@
/// A Container type
///
/// Provides an interface to allow transfer of ownership between Node and Rust.
/// It either contains a Box with full ownership of the content, or a pointer to the content.
///
/// The main goal here is to allow Node to call into Rust to initialize some objects. Later
/// these objects may need to be used by Rust, which will expect to take ownership. Since Node
/// does not allow any sort of ownership transfer, it will keep a reference to this object
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
/// and just keep a pointer in the Node object.
pub enum Container<T: ?Sized> {
Owned(Box<T>),
Pointer(*mut T),
Empty,
}
impl<T> Container<T>
where
T: ?Sized,
{
pub fn from_ref(reference: &Box<T>) -> Self {
let content: *const T = &**reference;
Container::Pointer(content as *mut _)
}
pub fn is_owned(&self) -> bool {
if let Container::Owned(_) = &self {
true
} else {
false
}
}
/// Consumes ourself and returns the boxed element if we have ownership, None otherwise.
pub fn take(self) -> Option<Box<T>> {
match self {
Container::Owned(obj) => Some(obj),
Container::Pointer(_) => None,
Container::Empty => None,
}
}
/// Replace empty content with the newly provided owned one, otherwise do nothing
pub fn to_owned(&mut self, o: Box<T>) {
if let Container::Empty = self {
unsafe {
let new_container = Container::Owned(o);
std::ptr::write(self, new_container);
}
}
}
/// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
pub fn to_pointer(&mut self) -> Option<Box<T>> {
if let Container::Owned(_) = self {
unsafe {
let old_container = std::ptr::read(self);
let ptr = Box::into_raw(old_container.take().unwrap());
let new_container = Container::Pointer(ptr);
std::ptr::write(self, new_container);
Some(Box::from_raw(ptr))
}
} else {
None
}
}
pub fn execute<F, U>(&self, closure: F) -> U
where
F: FnOnce(Option<&Box<T>>) -> U,
{
match self {
Container::Owned(val) => closure(Some(val)),
Container::Pointer(ptr) => unsafe {
let val = Box::from_raw(*ptr);
let res = closure(Some(&val));
// We call this to make sure we don't drop the Box
Box::into_raw(val);
res
},
Container::Empty => closure(None),
}
}
pub fn execute_mut<F, U>(&mut self, closure: F) -> U
where
F: FnOnce(Option<&mut Box<T>>) -> U,
{
match self {
Container::Owned(val) => closure(Some(val)),
Container::Pointer(ptr) => unsafe {
let mut val = Box::from_raw(*ptr);
let res = closure(Some(&mut val));
// We call this to make sure we don't drop the Box
Box::into_raw(val);
res
},
Container::Empty => closure(None),
}
}
}
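A minimal usage sketch of this `Container`, assuming the type above is in scope. The `Normalize` trait and `Lowercase` struct are made-up stand-ins for whatever object the Node side actually creates; this is not part of the bindings themselves.

```rust
// Hypothetical trait and object standing in for something created on the Node side.
trait Normalize {
    fn apply(&self, s: &str) -> String;
}

struct Lowercase;
impl Normalize for Lowercase {
    fn apply(&self, s: &str) -> String {
        s.to_lowercase()
    }
}

fn main() {
    // The Node side starts out owning the object.
    let mut held: Container<dyn Normalize> = Container::Owned(Box::new(Lowercase));
    assert!(held.is_owned());

    // When Rust needs real ownership, `to_pointer` hands the Box over and
    // downgrades the Node-side container to a raw pointer.
    let rust_owned: Box<dyn Normalize> = held.to_pointer().expect("was owned");
    assert!(!held.is_owned());

    // The Node side can still call into the object through the pointer,
    // as long as the Rust-owned Box is still alive.
    let out = held.execute(|n| n.unwrap().apply("HeLLo"));
    assert_eq!(out, "hello");

    drop(rust_owned);
}
```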

View File

@ -1,6 +1,6 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
/// Decoder

View File

@ -2,7 +2,7 @@ extern crate tokenizers as tk;
use tk::tokenizer::PaddingDirection;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
/// Encoding
@ -159,62 +159,6 @@ declare_types! {
Ok(js_overflowings.upcast())
}
method getOriginalString(mut cx) {
// getOriginalString(begin?: number, end?: number)
let this = cx.this();
let len_original = {
let guard = cx.lock();
let len = this.borrow(&guard).encoding.execute(|encoding| {
encoding.unwrap().get_normalized().len_original()
});
len
};
let get_index = |x: i32| -> usize {
if x >= 0 {
x as usize
} else {
(len_original as i32 + x) as usize
}
};
let begin_index = if let Some(begin_arg) = cx.argument_opt(0) {
if begin_arg.downcast::<JsUndefined>().is_err() {
let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
get_index(begin)
} else {
0
}
} else {
0
};
let end_index = if let Some(end_arg) = cx.argument_opt(1) {
if end_arg.downcast::<JsUndefined>().is_err() {
let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
get_index(end)
} else {
len_original
}
} else {
len_original
};
let original = {
let guard = cx.lock();
let original = this.borrow(&guard).encoding.execute(|encoding| {
encoding.unwrap().get_normalized().get_range_original(begin_index..end_index)
});
original
};
if let Some(original) = original {
Ok(cx.string(original).upcast())
} else {
cx.throw_error("Error in offsets")
}
}
method pad(mut cx) {
// pad(length: number, options?: {
// direction?: 'left' | 'right' = 'right',

View File

@ -3,6 +3,7 @@
extern crate neon;
extern crate tokenizers as tk;
mod container;
mod decoders;
mod encoding;
mod models;
@ -31,6 +32,8 @@ register_module!(mut m, {
pre_tokenizers::register(&mut m, "pre_tokenizers")?;
// Trainers
trainers::register(&mut m, "trainers")?;
// Utils
utils::register(&mut m, "utils")?;
Ok(())
});

View File

@ -1,7 +1,7 @@
extern crate tokenizers as tk;
use crate::container::Container;
use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
use crate::utils::Container;
use neon::prelude::*;
use std::path::Path;

View File

@ -1,6 +1,6 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
/// Normalizer

View File

@ -1,6 +1,6 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
/// PreTokenizers

View File

@ -1,6 +1,6 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
/// Processor

View File

@ -1,5 +1,6 @@
extern crate tokenizers as tk;
use crate::container::Container;
use crate::decoders::JsDecoder;
use crate::models::JsModel;
use crate::normalizers::JsNormalizer;
@ -7,7 +8,6 @@ use crate::pre_tokenizers::JsPreTokenizer;
use crate::processors::JsPostProcessor;
use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
use crate::trainers::JsTrainer;
use crate::utils::Container;
use neon::prelude::*;
use tk::tokenizer::{
@ -64,17 +64,43 @@ declare_types! {
let mut with_added_tokens = true;
if let Some(args) = cx.argument_opt(0) {
if args.downcast::<JsUndefined>().is_err() {
with_added_tokens = args.downcast::<JsBoolean>().or_throw(&mut cx)?.value() as bool;
with_added_tokens = args.downcast::<JsBoolean>()
.or_throw(&mut cx)?
.value() as bool;
}
}
let mut this = cx.this();
let guard = cx.lock();
let size = this.borrow_mut(&guard).tokenizer.get_vocab_size(with_added_tokens);
let size = this.borrow_mut(&guard)
.tokenizer
.get_vocab_size(with_added_tokens);
Ok(cx.number(size as f64).upcast())
}
method normalize(mut cx) {
// normalize(sentence: String) -> String
let sentence = cx.argument::<JsString>(0)?.value();
let this = cx.this();
let guard = cx.lock();
let result = {
this.borrow(&guard)
.tokenizer
.normalize(&sentence)
.map(|s| s.get().to_owned())
};
let normalized = result
.map_err(|e| {
cx.throw_error::<_, ()>(format!("{}", e))
.unwrap_err()
})?;
Ok(cx.string(normalized).upcast())
}
method encode(mut cx) {
// encode(
// sentence: String,

View File

@ -1,6 +1,6 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use crate::container::Container;
use neon::prelude::*;
use std::collections::HashSet;

View File

@ -1,102 +1,51 @@
/// A Container type
///
/// Provides an interface to allow transfer of ownership between Node and Rust.
/// It either contains a Box with full ownership of the content, or a pointer to the content.
///
/// The main goal here is to allow Node calling into Rust to initialize some objects. Later
/// these objects may need to be used by Rust who will expect to take ownership. Since Node
/// does not allow any sort of ownership transfer, it will keep a reference to this object
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
/// and just keep a pointer in the Node object.
pub enum Container<T: ?Sized> {
Owned(Box<T>),
Pointer(*mut T),
Empty,
extern crate tokenizers as tk;
use neon::prelude::*;
/// slice(s: string, start?: number, end?: number)
fn slice(mut cx: FunctionContext) -> JsResult<JsString> {
let s = cx.argument::<JsString>(0)?.value();
let len = s.chars().count();
let get_index = |x: i32| -> usize {
if x >= 0 {
x as usize
} else {
(len as i32 + x) as usize
}
};
let begin_index = if let Some(begin_arg) = cx.argument_opt(1) {
if begin_arg.downcast::<JsUndefined>().is_err() {
let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
get_index(begin)
} else {
0
}
} else {
0
};
let end_index = if let Some(end_arg) = cx.argument_opt(2) {
if end_arg.downcast::<JsUndefined>().is_err() {
let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
get_index(end)
} else {
len
}
} else {
len
};
if let Some(slice) = tk::tokenizer::get_range_of(&s, begin_index..end_index) {
Ok(cx.string(slice))
} else {
cx.throw_error("Error in offsets")
}
}
impl<T> Container<T>
where
T: ?Sized,
{
pub fn from_ref(reference: &Box<T>) -> Self {
let content: *const T = &**reference;
Container::Pointer(content as *mut _)
}
pub fn is_owned(&self) -> bool {
if let Container::Owned(_) = &self {
true
} else {
false
}
}
/// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
pub fn take(self) -> Option<Box<T>> {
match self {
Container::Owned(obj) => Some(obj),
Container::Pointer(_) => None,
Container::Empty => None,
}
}
/// Replace an empty content by the new provided owned one, otherwise do nothing
pub fn to_owned(&mut self, o: Box<T>) {
if let Container::Empty = self {
unsafe {
let new_container = Container::Owned(o);
std::ptr::write(self, new_container);
}
}
}
/// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
pub fn to_pointer(&mut self) -> Option<Box<T>> {
if let Container::Owned(_) = self {
unsafe {
let old_container = std::ptr::read(self);
let ptr = Box::into_raw(old_container.take().unwrap());
let new_container = Container::Pointer(ptr);
std::ptr::write(self, new_container);
Some(Box::from_raw(ptr))
}
} else {
None
}
}
pub fn execute<F, U>(&self, closure: F) -> U
where
F: FnOnce(Option<&Box<T>>) -> U,
{
match self {
Container::Owned(val) => closure(Some(val)),
Container::Pointer(ptr) => unsafe {
let val = Box::from_raw(*ptr);
let res = closure(Some(&val));
// We call this to make sure we don't drop the Box
Box::into_raw(val);
res
},
Container::Empty => closure(None),
}
}
pub fn execute_mut<F, U>(&mut self, closure: F) -> U
where
F: FnOnce(Option<&mut Box<T>>) -> U,
{
match self {
Container::Owned(val) => closure(Some(val)),
Container::Pointer(ptr) => unsafe {
let mut val = Box::from_raw(*ptr);
let res = closure(Some(&mut val));
// We call this to make sure we don't drop the Box
Box::into_raw(val);
res
},
Container::Empty => closure(None),
}
}
/// Register everything here
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_slice", prefix), slice)?;
Ok(())
}

View File

@ -6,18 +6,28 @@ a high number of files as it avoids having too many progress bars on screen.
- `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
whitespaces are part of the actual token.
It has been added to `ByteLevelBPETokenizer` and but it is off by default (`trim_offsets=False`).
- `encode` and `encode_batch` no take a new optional argument, specifying whether we should add the
special tokens. This stays activated by default.
It has been added to `ByteLevelBPETokenizer` but it is off by default (`trim_offsets=False`).
([#188](https://github.com/huggingface/tokenizers/pull/188))
- `encode` and `encode_batch` now take a new optional argument, specifying whether we should add the
special tokens. This is activated by default. ([#193](https://github.com/huggingface/tokenizers/pull/193))
- `original_str` and `normalized_str` have been removed from the `Encoding` returned by `encode` and
`encode_batch`. This brings a reduction of 70% of the memory footprint.
([#197](https://github.com/huggingface/tokenizers/pull/197))
## Fixes:
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE ([#193](https://github.com/huggingface/tokenizers/pull/193)):
- when `add_prefix_space=True`
- when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))
## How to migrate:
- Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant. If you are
using `ByteLevelBPETokenizer`, this option is disabled by default (`trim_offsets=False`).
- Access to the `original_str` on the `Encoding` has been removed. The original string is the input
of `encode` so it didn't make sense to keep it here.
- No need to call `original_str.offsets(offsets[N])` to convert offsets to the original string. They
are now relative to the original string by default.
- Access to the `normalized_str` on the `Encoding` has been removed. It can be retrieved by calling
`normalize(sequence)` on the `Tokenizer`.
# v0.6.0

View File

@ -1,115 +1,11 @@
extern crate tokenizers as tk;
use crate::error::PyError;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::PaddingDirection;
fn get_range(item: PyObject, max_len: usize) -> PyResult<std::ops::Range<usize>> {
let gil = Python::acquire_gil();
let py = gil.python();
let slice = if let Ok(index) = item.extract::<isize>(py) {
if index >= max_len as isize || index < -(max_len as isize) {
Err(exceptions::IndexError::py_err("Index out of bounds"))
} else {
Ok(if index == -1 {
PySlice::new(py, index, max_len as isize, 1)
} else {
PySlice::new(py, index, index + 1, 1)
})
}
} else if let Ok(slice) = item.cast_as::<PySlice>(py) {
Ok(slice)
} else if let Ok(offset) = item.cast_as::<PyTuple>(py) {
if offset.len() == 2 {
let start = offset.get_item(0).extract::<isize>()?;
let end = offset.get_item(1).extract::<isize>()?;
Ok(PySlice::new(py, start, end, 1))
} else {
Err(exceptions::TypeError::py_err("Expected Tuple[int, int]"))
}
} else {
Err(exceptions::TypeError::py_err(
"Expected number or slice or Tuple[int, int]",
))
}?;
// Find out range from the slice
let len: std::os::raw::c_long = (max_len as i32) as _;
let PySliceIndices { start, stop, .. } = slice.indices(len)?;
Ok(start as usize..stop as usize)
}
enum IndexableStringType {
Original,
Normalized,
}
#[pyclass(dict)]
pub struct IndexableString {
s: tk::tokenizer::NormalizedString,
t: IndexableStringType,
}
#[pymethods]
impl IndexableString {
fn offsets(&self, item: PyObject) -> PyResult<Option<(usize, usize)>> {
let range = get_range(item, self.s.len())?;
match self.t {
IndexableStringType::Original => Ok(self
.s
.get_original_offsets(range)
.map(|range| (range.start, range.end))),
IndexableStringType::Normalized => Ok(Some((range.start, range.end))),
}
}
}
#[pyproto]
impl PyObjectProtocol for IndexableString {
fn __repr__(&self) -> PyResult<String> {
Ok(match self.t {
IndexableStringType::Original => self.s.get_original().to_owned(),
IndexableStringType::Normalized => self.s.get().to_owned(),
})
}
fn __str__(&self) -> PyResult<String> {
Ok(match self.t {
IndexableStringType::Original => self.s.get_original().to_owned(),
IndexableStringType::Normalized => self.s.get().to_owned(),
})
}
}
#[pyproto]
impl PyMappingProtocol for IndexableString {
fn __getitem__(&self, item: PyObject) -> PyResult<String> {
// Find out the range
let range = get_range(item, self.s.len())?;
// Get the range from the relevant string
let s = match self.t {
IndexableStringType::Original => self.s.get_range_original(range),
IndexableStringType::Normalized => self.s.get_range(range),
};
s.map(|s| s.to_owned())
.ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
}
fn __len__(self) -> PyResult<usize> {
Ok(match self.t {
IndexableStringType::Original => self.s.len_original(),
IndexableStringType::Normalized => self.s.len(),
})
}
}
#[pyclass(dict)]
#[repr(transparent)]
pub struct Encoding {
@ -127,7 +23,7 @@ impl PyObjectProtocol for Encoding {
fn __repr__(&self) -> PyResult<String> {
Ok(format!(
"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
attention_mask, special_tokens_mask, overflowing])",
self.encoding.get_ids().len()
))
}
@ -142,50 +38,6 @@ impl PySequenceProtocol for Encoding {
#[pymethods]
impl Encoding {
#[getter]
fn get_normalized_str(&self) -> IndexableString {
IndexableString {
s: self.encoding.get_normalized().clone(),
t: IndexableStringType::Normalized,
}
}
#[getter]
fn get_original_str(&self) -> IndexableString {
IndexableString {
s: self.encoding.get_normalized().clone(),
t: IndexableStringType::Original,
}
}
#[args(kwargs = "**")]
fn get_range(
&self,
range: (usize, usize),
kwargs: Option<&PyDict>,
) -> PyResult<Option<String>> {
let mut original = false;
if let Some(kwargs) = kwargs {
if let Some(koriginal) = kwargs.get_item("original") {
original = koriginal.extract()?;
}
}
if original {
Ok(self
.encoding
.get_normalized()
.get_range_original(range.0..range.1)
.map(|s| s.to_owned()))
} else {
Ok(self
.encoding
.get_normalized()
.get_range(range.0..range.1)
.map(|s| s.to_owned()))
}
}
#[getter]
fn get_ids(&self) -> Vec<u32> {
self.encoding.get_ids().to_vec()

View File

@ -159,6 +159,15 @@ impl Tokenizer {
self.tokenizer.with_padding(None);
}
fn normalize(&self, sentence: &str) -> PyResult<String> {
ToPyResult(
self.tokenizer
.normalize(sentence)
.map(|s| s.get().to_owned()),
)
.into()
}
#[args(add_special_tokens = true)]
fn encode(
&self,

View File

@ -16,32 +16,9 @@ from typing import Optional, Union, List, Tuple
Offsets = Tuple[int, int]
class IndexableString:
"""
Works almost like a `str`, but allows indexing on offsets
provided on an `Encoding`
"""
def offsets(self, offsets: Tuple[int, int]) -> Optional[Tuple[int, int]]:
""" Convert the Encoding's offsets to the current string.
`Encoding` provides a list of offsets that are actually offsets to the Normalized
version of text. Calling this method with the offsets provided by `Encoding` will make
sure that said offsets can be used to index the `str` directly.
"""
pass
class Encoding:
""" An Encoding as returned by the Tokenizer """
@property
def normalized_str(self) -> IndexableString:
""" The normalized string """
pass
@property
def original_str(self) -> IndexableString:
""" The original string """
pass
@property
def ids(self) -> List[int]:
""" The tokenized ids """
@ -244,6 +221,17 @@ class Tokenizer:
def no_padding(self):
""" Disable padding """
pass
def normalize(self, sequence: str) -> str:
""" Normalize the given sequence
Args:
sequence: str:
The sequence to normalize
Returns:
The normalized string
"""
pass
def encode(
self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
) -> Encoding:

View File

@ -125,6 +125,18 @@ class BaseTokenizer:
"""
return self._tokenizer.add_special_tokens(special_tokens)
def normalize(self, sequence: str) -> str:
""" Normalize the given sequence
Args:
sequence: str:
The sequence to normalize
Returns:
The normalized string
"""
return self._tokenizer.normalize(sequence)
def encode(
self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
) -> Encoding:

View File

@ -6,9 +6,16 @@ a high number of files as it avoids having too many progress bars on screen.
- Improve BPE and WordPiece builders.
- `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
whitespaces are part of the actual token.
whitespaces are part of the actual token. ([#188](https://github.com/huggingface/tokenizers/pull/188))
- `encode` and `encode_batch` now take a new argument, specifying whether we should add the
special tokens.
special tokens. ([#193](https://github.com/huggingface/tokenizers/pull/193))
- The `NormalizedString` has been removed from the `Encoding`. It is now possible to retrieve it
by calling `normalize` on the `Tokenizer`. This brings a reduction of 70% of the memory footprint
([#197](https://github.com/huggingface/tokenizers/pull/197))
- The `NormalizedString` API has been improved. It is now possible to retrieve parts of both strings
using either "normalized" or "original" offsets, as sketched after this list. ([#197](https://github.com/huggingface/tokenizers/pull/197))
- The offsets provided on `Encoding` are now relative to the original string, and not the normalized
one anymore. ([#197](https://github.com/huggingface/tokenizers/pull/197))
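A minimal sketch of the improved `NormalizedString` API, mirroring the `original_range` test added in this PR; it assumes `NormalizedString` and `Range` are in scope via `tokenizers::tokenizer`.

```rust
use tokenizers::tokenizer::{NormalizedString, Range};

fn main() {
    let mut n = NormalizedString::from("Hello_______ World!");
    n.filter(|c| *c != '_').lowercase(); // normalized: "hello world!"

    // Either referential can be used to slice either version of the string.
    assert_eq!(n.get_range(Range::Normalized(6..11)), Some("world"));
    assert_eq!(n.get_range_original(Range::Normalized(6..11)), Some("World"));

    // Offsets convert between the normalized and original referentials.
    assert_eq!(n.convert_offsets(Range::Normalized(6..11)), Some(13..18));
}
```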
## Fixes:
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:

View File

@ -1,5 +1,12 @@
DATA_DIR = data
BENCHMARK_DIR = benches
BENCHMARK_RESOURCES = $(BENCHMARK_DIR)/gpt2-vocab.json $(BENCHMARK_DIR)/gpt2-merges.txt $(BENCHMARK_DIR)/big.txt
TESTS_DIR = tests
dir_guard=@mkdir -p $(@D)
SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt
BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt
TESTS_RESOURCES = $(SHARED_RESOURCES)
.PHONY : build
build :
@ -20,7 +27,7 @@ lint :
cargo clippy --all-targets --all-features -- -D warnings
.PHONY : test
test :
test : $(TESTS_RESOURCES)
cargo test
.PHONY : doc
@ -38,8 +45,10 @@ all-checks : lint test doc
bench : $(BENCHMARK_RESOURCES)
cargo bench -- --verbose
$(BENCHMARK_DIR)/gpt2-% :
$(DATA_DIR)/gpt2-% :
$(dir_guard)
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-$* -O $@
$(BENCHMARK_DIR)/big.txt :
$(DATA_DIR)/big.txt :
$(dir_guard)
wget https://norvig.com/big.txt -O $@

View File

@ -1,2 +0,0 @@
*.txt
*.json

View File

@ -68,13 +68,13 @@ fn iter_bench_encode_batch(
}
fn bench_gpt2(c: &mut Criterion) {
let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
.build()
.unwrap();
let tokenizer = create_gpt2_tokenizer(bpe);
let mut lines: Vec<EncodeInput> = vec![];
let mut batches: Vec<Vec<EncodeInput>> = vec![vec![]];
for line in BufReader::new(File::open(Path::new("benches/big.txt")).unwrap())
for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap())
.lines()
.map(line_to_input)
{
@ -93,7 +93,7 @@ fn bench_gpt2(c: &mut Criterion) {
b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches))
});
let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
.cache_capacity(0)
.build()
.unwrap();

View File

@ -239,7 +239,7 @@ impl PostProcessor for ByteLevel {
None => encoding,
Some(mut pair) => {
process_offsets(&mut pair);
encoding.merge_with(pair);
encoding.merge_with(pair, false);
encoding
}
};
@ -251,7 +251,9 @@ impl PostProcessor for ByteLevel {
#[cfg(test)]
mod tests {
use super::ByteLevel;
use crate::tokenizer::{Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer};
use crate::tokenizer::{
Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer, Range,
};
#[test]
fn pre_tokenization() {
@ -391,13 +393,12 @@ mod tests {
]
);
assert_eq!(input.get(), "iâŃ¢j");
assert_eq!(input.get_range_original(1..4), Some("".into()));
assert_eq!(input.get_range_original(Range::Normalized(1..4)), Some(""));
}
#[test]
fn processor_trims_offsets() {
let start = Encoding::new(
NormalizedString::from(""),
vec![],
vec![],
vec![
@ -412,7 +413,6 @@ mod tests {
vec![],
);
let expected = Encoding::new(
NormalizedString::from(""),
vec![],
vec![],
vec![
@ -434,7 +434,7 @@ mod tests {
);
let mut pair_expected = expected.clone();
pair_expected.merge_with(expected);
pair_expected.merge_with(expected, false);
assert_eq!(
pair_expected,
bytelevel

View File

@ -43,7 +43,6 @@ impl PostProcessor for BertProcessing {
let attention_mask = vec![1; ids.len()];
let mut new_encoding = Encoding::new(
encoding.get_normalized().clone(),
ids,
type_ids,
tokens,
@ -68,7 +67,6 @@ impl PostProcessor for BertProcessing {
let attention_mask = vec![1; ids.len()];
Encoding::new(
encoding.get_normalized().clone(),
ids,
type_ids,
tokens,
@ -91,7 +89,6 @@ impl PostProcessor for BertProcessing {
let pair_attention_mask = vec![1; pair_ids.len()];
let new_pair_encoding = Encoding::new(
encoding.get_normalized().clone(),
pair_ids,
pair_type_ids,
pair_tokens,
@ -112,7 +109,6 @@ impl PostProcessor for BertProcessing {
let pair_attention_mask = vec![1; pair_ids.len()];
Encoding::new(
encoding.get_normalized().clone(),
pair_ids,
pair_type_ids,
pair_tokens,
@ -125,7 +121,7 @@ impl PostProcessor for BertProcessing {
.collect(),
);
new_encoding.merge_with(new_pair_encoding);
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(new_encoding)

View File

@ -43,7 +43,6 @@ impl PostProcessor for RobertaProcessing {
let attention_mask = vec![1; ids.len()];
let mut new_encoding = Encoding::new(
encoding.get_normalized().clone(),
ids,
type_ids,
tokens,
@ -68,7 +67,6 @@ impl PostProcessor for RobertaProcessing {
let pair_attention_mask = vec![1; pair_ids.len()];
let new_pair_encoding = Encoding::new(
encoding.get_normalized().clone(),
pair_ids,
pair_type_ids,
pair_tokens,
@ -78,7 +76,7 @@ impl PostProcessor for RobertaProcessing {
encoding.take_overflowing(),
);
new_encoding.merge_with(new_pair_encoding);
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(new_encoding)

View File

@ -1,26 +1,9 @@
use crate::tokenizer::NormalizedString;
use crate::utils::padding::PaddingDirection;
use rayon::prelude::*;
/// The various possible padding directions.
#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
Left,
Right,
}
impl std::convert::AsRef<str> for PaddingDirection {
fn as_ref(&self) -> &str {
match self {
PaddingDirection::Left => "left",
PaddingDirection::Right => "right",
}
}
}
/// Represents the output of a `Tokenizer`.
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Encoding {
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -32,7 +15,6 @@ pub struct Encoding {
impl Encoding {
#[allow(clippy::too_many_arguments)]
pub fn new(
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -42,7 +24,6 @@ impl Encoding {
overflowing: Vec<Encoding>,
) -> Self {
Encoding {
normalized,
ids,
type_ids,
tokens,
@ -53,10 +34,6 @@ impl Encoding {
}
}
pub fn get_normalized(&self) -> &NormalizedString {
&self.normalized
}
pub fn get_tokens(&self) -> &[String] {
&self.tokens[..]
}
@ -124,7 +101,6 @@ impl Encoding {
}
let o = Encoding {
normalized: self.normalized.clone(),
ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
type_ids: get_current_part(
&prev_encoding.type_ids,
@ -173,7 +149,7 @@ impl Encoding {
}
/// Merge ourself with the given `Encoding`. Happens in place. When `growing_offsets` is true,
/// the offsets of `pair` are shifted by the end of our current last offset; otherwise they are
/// kept unchanged.
pub fn merge_with(&mut self, pair: Encoding) {
pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
// Handle merging the overflowing parts too: Combine them all
// In most of the cases, we expect `pair.overflowing.len() == 0`
let mut overflowings = vec![];
@ -182,33 +158,33 @@ impl Encoding {
for self_o in &self.overflowing {
// 1. The pair itself
let mut n_encoding = self_o.clone();
n_encoding.merge_with(pair.clone());
n_encoding.merge_with(pair.clone(), growing_offsets);
overflowings.push(n_encoding);
// 2. Its overflowings (this should rarely happen...)
for other_o in &pair.overflowing {
let mut n_encoding = self_o.clone();
n_encoding.merge_with(other_o.clone());
n_encoding.merge_with(other_o.clone(), growing_offsets);
overflowings.push(n_encoding);
}
}
// 2. Ourself with all the other overflowings (this should rarely happen too...)
for other_o in &pair.overflowing {
let mut n_encoding = self.clone();
n_encoding.merge_with(other_o.clone());
n_encoding.merge_with(other_o.clone(), growing_offsets);
overflowings.push(n_encoding);
}
// Finish by merging ourself with the other encoding
self.normalized.merge_with(&pair.normalized);
self.ids.extend(pair.ids);
self.type_ids.extend(pair.type_ids);
self.tokens.extend(pair.tokens);
let starting_offset = self
.offsets
.iter()
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
let starting_offset = if growing_offsets {
self.offsets.last().map_or(0, |o| o.1)
} else {
0
};
self.offsets.extend(
pair.offsets
.into_iter()
@ -304,7 +280,6 @@ mod tests {
#[test]
fn merge_encodings() {
let mut a = Encoding {
normalized: NormalizedString::from("Hello "),
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
@ -314,7 +289,6 @@ mod tests {
overflowing: vec![],
};
let b = Encoding {
normalized: NormalizedString::from("World!"),
ids: vec![2],
type_ids: vec![1],
tokens: vec![String::from("World!")],
@ -323,12 +297,11 @@ mod tests {
attention_mask: vec![1],
overflowing: vec![],
};
a.merge_with(b);
a.merge_with(b, true);
assert_eq!(
a,
Encoding {
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2],
type_ids: vec![0, 1],
tokens: vec![String::from("Hello "), String::from("World!")],
@ -343,7 +316,6 @@ mod tests {
#[test]
fn truncate() {
let mut a = Encoding {
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
@ -361,7 +333,6 @@ mod tests {
assert_eq!(
a,
Encoding {
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2],
type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")],
@ -369,7 +340,6 @@ mod tests {
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: vec![Encoding {
normalized: NormalizedString::from("Hello World!"),
ids: vec![3],
type_ids: vec![0],
tokens: vec![String::from("!")],
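A rough sketch of the new `growing_offsets` flag on `Encoding::merge_with`, in the spirit of the `merge_encodings` test above. It assumes `Encoding` is in scope via `tokenizers::tokenizer`, that `Encoding::new` takes (ids, type_ids, tokens, offsets, special_tokens_mask, attention_mask, overflowing) as seen elsewhere in this diff, and that with the flag set the pair's offsets are shifted by the end of the previous last offset.

```rust
use tokenizers::tokenizer::Encoding;

fn main() {
    let mut a = Encoding::new(
        vec![1], vec![0], vec!["Hello ".into()], vec![(0, 6)], vec![0], vec![1], vec![],
    );
    let b = Encoding::new(
        vec![2], vec![1], vec!["World!".into()], vec![(0, 6)], vec![0], vec![1], vec![],
    );

    // With `growing_offsets = true` (used when merging the parts of a single
    // sequence), b's offsets are shifted to continue after a's last offset.
    a.merge_with(b, true);
    assert_eq!(a.get_offsets(), &[(0, 6), (6, 12)]);
}
```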

View File

@ -9,10 +9,9 @@
//! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
//! ...).
pub use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
TruncationStrategy,
};
use crate::utils::iter::ResultShunt;
pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::{
@ -31,6 +30,11 @@ pub use normalizer::*;
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
pub type Offsets = (usize, usize);
/// Takes care of pre-processing strings.
pub trait Normalizer {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
}
/// The `PreTokenizer` is in charge of doing the pre-segmentation step. It splits the given string
/// in multiple substrings, keeping track of the offsets of said substrings from the
/// `NormalizedString`. In some occasions, the `PreTokenizer` might need to modify the given
@ -71,7 +75,7 @@ impl dyn PostProcessor {
match pair_encoding {
None => Ok(encoding),
Some(pair) => {
encoding.merge_with(pair);
encoding.merge_with(pair, false);
Ok(encoding)
}
}
@ -290,99 +294,154 @@ impl Tokenizer {
}
}
pub fn num_added_tokens(&self, is_pair: bool) -> usize {
self.post_processor
.as_ref()
.map_or(0, |p| p.as_ref().added_tokens(is_pair))
/// Normalize the given sentence and return the corresponding normalized string
pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
let mut normalized = self
.split_on_added_tokens(sentence)
.into_iter()
.map(|(sentence, id)| -> Result<NormalizedString> {
if id.is_some() {
Ok(NormalizedString::from(&sentence))
} else {
let mut normalized = self.do_normalize(&sentence)?;
let _ = self.pre_tokenize(&mut normalized)?;
Ok(normalized)
}
})
.collect::<Result<Vec<_>>>()?;
let others = normalized.split_off(1);
let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
for n in others {
normalized.merge_with(&n);
}
Ok(normalized)
}
/// Encode the given sentence
pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let mut encodings = self
.split_on_added_tokens(&sentence)
.into_iter()
.map(|(sentence, id)| -> Result<Encoding> {
// If this is one of our added tokens, lets return an encoding directly
if let Some(id) = id {
return Ok(Encoding::new(
NormalizedString::from(&sentence),
vec![id],
vec![type_id],
vec![sentence.to_owned()],
vec![(0, sentence.len())],
vec![0],
vec![1],
vec![],
));
}
let generate_output =
move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let results = self.split_on_added_tokens(&sentence).into_iter().map(
|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
// If this is one of our added tokens, let's return an encoding directly
if let Some(id) = id {
return Ok((
Encoding::new(
vec![id],
vec![type_id],
vec![sentence.to_owned()],
vec![(0, sentence.len())],
vec![0],
vec![1],
vec![],
),
NormalizedString::from(&sentence),
));
}
// 1. Normalization
let mut normalized = self.normalize(&sentence)?;
// 1. Normalization
let mut normalized = self.do_normalize(&sentence)?;
// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
// 3. Model
let output = self.model.tokenize(pre_tokenized)?;
let length = output.len();
// 3. Model
let output = self.model.tokenize(pre_tokenized)?;
let length = output.len();
let (ids, tokens, offsets) = output.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);
let (ids, tokens, offsets) = output.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);
Ok(Encoding::new(
normalized,
ids,
vec![type_id; length],
tokens,
offsets,
vec![0; length],
vec![1; length],
vec![],
))
})
.collect::<Result<Vec<Encoding>>>()?;
Ok((
Encoding::new(
ids,
vec![type_id; length],
tokens,
offsets,
vec![0; length],
vec![1; length],
vec![],
),
normalized,
))
},
);
if encodings.is_empty() {
return Ok(Encoding::default());
}
let (mut encodings, mut normalized) =
ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().next().unwrap();
if encodings.is_empty() {
return Ok((Encoding::default(), NormalizedString::from("")));
}
for encoding in others {
first.merge_with(encoding);
}
let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().next().unwrap();
Ok(first)
};
for encoding in others {
first.merge_with(encoding, true);
}
let others = normalized.split_off(1);
let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
for n in others {
normalized.merge_with(&n);
}
Ok((first, normalized))
};
let (sentence, pair) = match input {
EncodeInput::Single(s1) => (s1, None),
EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
};
let encoding = generate_output(sentence, 0)?;
let pair_encoding = match pair {
Some(pair) => Some(generate_output(pair, 1)?),
None => None,
let (encoding, normalized) = generate_output(sentence, 0)?;
let (pair_encoding, pair_normalized) = match pair {
Some(pair) => {
let (e, n) = generate_output(pair, 1)?;
(Some(e), Some(n))
}
None => (None, None),
};
// 4. Post processing
self.post_process(encoding, pair_encoding, add_special_tokens)
let mut output = self.post_process(encoding, pair_encoding, add_special_tokens)?;
// 5. Convert offsets back to original string
let mut current_offset = (0, 0);
let mut n_source = &normalized;
output
.get_offsets_mut()
.iter_mut()
.for_each(|(start, end)| {
if (*start, *end) < current_offset {
n_source = &pair_normalized.as_ref().unwrap_or(&normalized);
}
current_offset = (*start, *end);
let (s, e) = n_source
.convert_offsets(Range::Normalized(*start..*end))
.map_or((*start, *end), |range| (range.start, range.end));
*start = s;
*end = e;
});
Ok(output)
}
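A small sketch of step 5 above: offsets produced on the normalized string are remapped to the original string, falling back to the normalized offsets when no conversion is possible. It assumes `NormalizedString` and `Range` are in scope via `tokenizers::tokenizer`.

```rust
use tokenizers::tokenizer::{NormalizedString, Range};

fn main() {
    let mut n = NormalizedString::from("Hello_______ World!");
    n.filter(|c| *c != '_').lowercase(); // normalized: "hello world!"

    // A token found at 6..11 on the normalized string ("world")...
    let (start, end) = (6, 11);
    let (s, e) = n
        .convert_offsets(Range::Normalized(start..end))
        .map_or((start, end), |r| (r.start, r.end));

    // ...is reported at 13..18 on the original string ("World").
    assert_eq!((s, e), (13, 18));
}
```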
/// Encode all the sentences in parallel, using multiple threads
@ -471,7 +530,7 @@ impl Tokenizer {
match file.read_line(&mut buf)? {
0 => break,
b => {
let mut normalized = self.normalize(&buf)?;
let mut normalized = self.do_normalize(&buf)?;
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
trainer.process_tokens(
&mut words,
@ -522,7 +581,7 @@ impl Tokenizer {
}
/// Normalization logic, go through all normalizers
fn normalize(&self, sequence: &str) -> Result<NormalizedString> {
fn do_normalize(&self, sequence: &str) -> Result<NormalizedString> {
let mut normalized = NormalizedString::from(sequence);
if let Some(normalizer) = &self.normalizer {

View File

@ -1,21 +1,63 @@
use super::Result;
use std::cmp::Ordering;
use std::ops::{Bound, RangeBounds};
use unicode_normalization_alignments::UnicodeNormalization;
/// Takes care of pre-processing strings.
pub trait Normalizer {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
/// Represents a Range usable by the NormalizedString to index its content.
/// A Range can use indices relative to either the `Original` or the `Normalized` string
#[derive(Debug, Clone, Copy)]
pub enum Range<T: RangeBounds<usize>> {
Original(T),
Normalized(T),
}
/// A normalized string takes care of keeping both versions of a `String`, and
/// provides necessary alignments to retrieve ranges of both strings.
impl<T> Range<T>
where
T: RangeBounds<usize>,
{
/// Unwrap the underlying range
fn unwrap(self) -> T {
match self {
Range::Original(r) => r,
Range::Normalized(r) => r,
}
}
/// Converts the current Range to a `std::ops::Range<usize>`. This requires the `max_len`
/// of the represented string (in chars, not bytes) in order to cover the case where the
/// original provided range was unbounded
fn into_full_range(self, max_len: usize) -> std::ops::Range<usize> {
let range = self.unwrap();
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(i) => *i,
Bound::Excluded(i) => *i + 1,
};
let end = match range.end_bound() {
Bound::Unbounded => max_len,
Bound::Included(i) => *i + 1,
Bound::Excluded(i) => *i,
};
start..end
}
}
/// A `NormalizedString` takes care of processing an "original" string to modify it and obtain a
/// "normalized" string. It keeps both version of the string, alignments information between both
/// and provides an interface to retrieve ranges of each string, using offsets from any of them.
///
/// It is possible to retrieve a part of the original string, by indexing it with offsets from the
/// normalized one, and the other way around too. It is also possible to convert offsets from one
/// referential to the other one easily.
#[derive(Default, Debug, Clone)]
pub struct NormalizedString {
/// The original version of the string, before any modification
original: String,
/// The normalized version of the string, after all modifications
normalized: String,
/// Mapping from normalized string to original one
/// (pos, changes) where pos is the position in the modified string, and changes an isize
/// representing the number of insertions or deletions
/// Mapping from normalized string to original one: (start, end) for each character of the
/// normalized string
alignments: Vec<(usize, usize)>,
}
@ -26,6 +68,7 @@ impl std::cmp::PartialEq for NormalizedString {
}
impl NormalizedString {
/// Create a NormalizedString from the given str
pub fn from(s: &str) -> Self {
NormalizedString {
original: s.to_owned(),
@ -44,60 +87,68 @@ impl NormalizedString {
&self.original
}
/// Return the range of the original string corresponding to the received range on the
/// normalized string. Returns None if out of bounds
pub fn get_original_offsets(
/// Convert the given offsets range from one referential to the other one:
/// `Original => Normalized` or `Normalized => Original`
pub fn convert_offsets<T: RangeBounds<usize>>(
&self,
range: std::ops::Range<usize>,
range: Range<T>,
) -> Option<std::ops::Range<usize>> {
self.alignments
.get(range)
.map(|alignments| {
if alignments.is_empty() {
None
} else {
let start = alignments[0].0;
let end = alignments[alignments.len() - 1].1;
Some(start..end)
}
})
.flatten()
}
fn get_range_of(&self, s: &str, range: std::ops::Range<usize>) -> Option<String> {
let len = s.chars().count();
if range.start >= len || range.end > len {
None
} else {
Some(
s.chars()
match range {
Range::Original(_) => {
let (mut start, mut end) = (0, 0);
let r = range.into_full_range(self.alignments.last().map_or(0, |(_, e)| *e));
println!("{:?}\t{:?}", r, self.alignments);
self.alignments
.iter()
.enumerate()
.skip(range.start)
.map(|(i, c)| {
if i >= range.start && i < range.end {
Some(c)
} else {
None
.take_while(|(_, alignment)| r.end >= alignment.1)
.for_each(|(i, alignment)| {
println!("{:?}", alignment);
if alignment.0 <= r.start {
start = i;
}
})
.fuse()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<String>(),
)
if alignment.1 <= r.end {
end = i + 1;
}
});
Some(start..end)
}
Range::Normalized(_) => self
.alignments
.get(range.into_full_range(self.alignments.len()))
.map(|alignments| {
if alignments.is_empty() {
None
} else {
let start = alignments[0].0;
let end = alignments[alignments.len() - 1].1;
Some(start..end)
}
})
.flatten(),
}
}
/// Return a range of the normalized string (indexing on char not bytes)
pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<String> {
self.get_range_of(&self.normalized, range)
pub fn get_range<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
match range {
Range::Original(_) => self
.convert_offsets(range)
.map(|r| get_range_of(&self.normalized, r))
.flatten(),
Range::Normalized(r) => get_range_of(&self.normalized, r),
}
}
/// Return a range of the original string, using a range from the normalized string
pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<String> {
self.get_original_offsets(range)
.map(|range| self.get_range_of(&self.original, range))
.flatten()
/// Return a range of the original string (indexing on char not bytes)
pub fn get_range_original<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
match range {
Range::Original(r) => get_range_of(&self.original, r),
Range::Normalized(_) => self
.convert_offsets(range)
.map(|r| get_range_of(&self.original, r))
.flatten(),
}
}
/// Applies transformations to the current normalized version, updating the current
@ -115,19 +166,10 @@ impl NormalizedString {
/// them has a `change` of `1`, but more doesn't make any sense.
/// We treat any value above `1` as `1`.
pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I, initial_offset: usize) {
let mut offset = 0;
let mut remaining_offset = initial_offset;
let mut offset = -(initial_offset as isize);
let (ch, alignments): (Vec<_>, Vec<_>) = dest
.enumerate()
.map(|(index, (c, changes))| {
let changes = if remaining_offset != 0 {
let c = changes - remaining_offset as isize;
remaining_offset = 0;
c
} else {
changes
};
let uof = if offset < 0 {
-offset as usize
} else {
@ -149,24 +191,10 @@ impl NormalizedString {
}
// No changes required here
Ordering::Equal => self.alignments.get(idx).copied(),
// Some characters where removed, so we merge our range with the one from the
// removed characters as the new alignment
// Some characters were removed, nothing to change in alignments
Ordering::Less => {
let uch = -changes as usize;
offset += changes;
self.alignments.get(idx..=idx + uch).map(|alignments| {
let min = alignments
.iter()
.map(|(start, end)| usize::min(*start, *end))
.min()
.unwrap();
let max = alignments
.iter()
.map(|(start, end)| usize::max(*start, *end))
.max()
.unwrap();
(min, max)
})
self.alignments.get(idx).copied()
}
};
@ -421,6 +449,37 @@ impl NormalizedString {
}
}
/// Returns a range of the given string slice, by indexing chars instead of bytes
pub fn get_range_of<T: RangeBounds<usize>>(s: &str, range: T) -> Option<&str> {
let len = s.chars().count();
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(i) => *i,
Bound::Excluded(i) => *i + 1,
};
let end = match range.end_bound() {
Bound::Unbounded => len,
Bound::Included(i) => *i + 1,
Bound::Excluded(i) => *i,
};
if start >= len || end > len || start >= end {
None
} else {
let start_b = s
.char_indices()
.map(|(i, _)| i)
.nth(start as usize)
.unwrap_or(0);
let end_b = s
.char_indices()
.map(|(i, _)| i)
.nth(end as usize)
.unwrap_or_else(|| s.len());
Some(&s[start_b..end_b])
}
}
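A brief sketch of `get_range_of`, which the Node bindings above call as `tk::tokenizer::get_range_of`; the expected values follow from the char-based indexing in the implementation (the accented character counts as a single unit).

```rust
use tokenizers::tokenizer::get_range_of;

fn main() {
    let s = "héllo"; // 5 chars, 6 bytes

    assert_eq!(get_range_of(s, 1..3), Some("él"));
    assert_eq!(get_range_of(s, ..), Some("héllo")); // unbounded range: whole string
    assert_eq!(get_range_of(s, 3..10), None);       // end out of range
}
```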
#[cfg(test)]
mod tests {
use super::*;
@ -462,7 +521,7 @@ mod tests {
n.filter(|c| *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}
@ -472,18 +531,43 @@ mod tests {
n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}
#[test]
fn range_conversion() {
let mut n = NormalizedString::from(" __Hello__ ");
n.filter(|c| !c.is_whitespace()).lowercase();
let hello_n = n.convert_offsets(Range::Original(6..11));
assert_eq!(hello_n, Some(2..7));
assert_eq!(
n.get_range(Range::Normalized(hello_n.clone().unwrap())),
Some("hello")
);
assert_eq!(
n.get_range_original(Range::Normalized(hello_n.unwrap())),
Some("Hello")
);
assert_eq!(n.get_range(Range::Original(6..11)), Some("hello"));
assert_eq!(n.get_range_original(Range::Original(6..11)), Some("Hello"));
}
#[test]
fn original_range() {
let mut n = NormalizedString::from("Hello_______ World!");
n.filter(|c| *c != '_').lowercase();
let world_n = n.get_range(6..11).unwrap();
let world_o = n.get_range_original(6..11).unwrap();
let world_n = n.get_range(Range::Normalized(6..11)).unwrap();
let world_o = n.get_range_original(Range::Normalized(6..11)).unwrap();
assert_eq!(world_n, "world");
assert_eq!(world_o, "World");
let original_range = Range::Original(n.convert_offsets(Range::Normalized(6..11)).unwrap());
assert_eq!(n.get_range(original_range.clone()).unwrap(), "world");
assert_eq!(
n.get_range_original(original_range.clone()).unwrap(),
"World"
);
assert_eq!(original_range.into_full_range(n.len_original()), 13..18);
}
#[test]
@ -505,8 +589,8 @@ mod tests {
assert_eq!(&n.normalized, " Hello ");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some("Hello".into())
n.get_range_original(Range::Normalized(1..n.normalized.len() - 1)),
Some("Hello")
);
}
@ -514,10 +598,13 @@ mod tests {
fn remove_at_beginning() {
let mut n = NormalizedString::from(" Hello");
n.filter(|c| !c.is_whitespace());
assert_eq!(n.get_range_original(1.."Hello".len()), Some("ello".into()));
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" Hello".into())
n.get_range_original(Range::Normalized(1.."Hello".len())),
Some("ello")
);
assert_eq!(
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("Hello")
);
}
@ -525,10 +612,10 @@ mod tests {
fn remove_at_end() {
let mut n = NormalizedString::from("Hello ");
n.filter(|c| !c.is_whitespace());
assert_eq!(n.get_range_original(0..4), Some("Hell".into()));
assert_eq!(n.get_range_original(Range::Normalized(0..4)), Some("Hell"));
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some("Hello ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("Hello")
);
}
@ -539,10 +626,13 @@ mod tests {
assert_eq!(&n.normalized, "Hello");
assert_eq!(
n.get_range_original(0.."Hello".len()),
Some(" Hello ".into())
n.get_range_original(Range::Normalized(0.."Hello".len())),
Some("Hello")
);
assert_eq!(
n.get_range_original(Range::Normalized(1.."Hell".len())),
Some("ell")
);
assert_eq!(n.get_range_original(1.."Hell".len()), Some("ell".into()));
}
#[test]
@ -551,8 +641,8 @@ mod tests {
n.lstrip();
assert_eq!(&n.normalized, "This is an example ");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("This is an example ")
);
}
@ -562,8 +652,8 @@ mod tests {
n.rstrip();
assert_eq!(&n.normalized, " This is an example");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some(" This is an example")
);
}
@ -573,8 +663,8 @@ mod tests {
n.strip();
assert_eq!(&n.normalized, "This is an example");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("This is an example")
);
}
@ -597,7 +687,7 @@ mod tests {
(4, 5)
]
);
assert_eq!(n.get_original_offsets(0..4), Some(0..0));
assert_eq!(n.convert_offsets(Range::Normalized(0..4)), Some(0..0));
}
#[test]
@ -619,6 +709,16 @@ mod tests {
(3, 3)
]
);
assert_eq!(n.get_original_offsets(3.." there".len()), Some(3..3));
assert_eq!(
n.convert_offsets(Range::Normalized(3.." there".len())),
Some(3..3)
);
}
#[test]
fn get_range() {
let s = String::from("Hello my name is John 👋");
assert_eq!(get_range_of(&s, ..), Some(&s[..]));
assert_eq!(get_range_of(&s, 17..), Some("John 👋"));
}
}

View File

@ -0,0 +1,58 @@
//! This comes from the Rust libcore and is duplicated here because it is not exported
//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
//! because the one from libcore seems to cause stack overflows in some cases
pub struct ResultShunt<I, E> {
iter: I,
error: Option<E>,
}
impl<I, T, E> ResultShunt<I, E>
where
I: Iterator<Item = Result<T, E>>,
{
/// Process the given iterator as if it yielded a `T` instead of a
/// `Result<T, _>`. Any errors will stop the inner iterator and
/// the overall result will be an error.
pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
where
F: FnMut(&mut Self) -> U,
{
let mut shunt = ResultShunt::new(iter);
let value = f(shunt.by_ref());
shunt.reconstruct(value)
}
fn new(iter: I) -> Self {
ResultShunt { iter, error: None }
}
/// Consume the adapter and rebuild a `Result` value. This should
/// *always* be called, otherwise any potential error would be
/// lost.
fn reconstruct<U>(self, val: U) -> Result<U, E> {
match self.error {
None => Ok(val),
Some(e) => Err(e),
}
}
}
impl<I, T, E> Iterator for ResultShunt<I, E>
where
I: Iterator<Item = Result<T, E>>,
{
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self.iter.next() {
Some(Ok(v)) => Some(v),
Some(Err(e)) => {
self.error = Some(e);
None
}
None => None,
}
}
}
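// A hedged usage sketch for `ResultShunt::process` (editorial, not part of this
// diff): summing a fallible iterator while surfacing the first error.
let values: Vec<Result<u32, String>> = vec![Ok(1), Ok(2), Err("boom".into()), Ok(4)];
let total = ResultShunt::process(values.into_iter(), |iter| iter.sum::<u32>());
assert_eq!(total, Err("boom".to_string()));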

View File

@ -0,0 +1,3 @@
pub mod iter;
pub mod padding;
pub mod truncation;

View File

@ -0,0 +1,63 @@
use crate::tokenizer::{Encoding, Result};
use rayon::prelude::*;
/// The various possible padding directions.
#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
Left,
Right,
}
impl std::convert::AsRef<str> for PaddingDirection {
fn as_ref(&self) -> &str {
match self {
PaddingDirection::Left => "left",
PaddingDirection::Right => "right",
}
}
}
#[derive(Debug, Clone)]
pub struct PaddingParams {
pub strategy: PaddingStrategy,
pub direction: PaddingDirection,
pub pad_id: u32,
pub pad_type_id: u32,
pub pad_token: String,
}
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
BatchLongest,
Fixed(usize),
}
pub fn pad_encodings(
mut encodings: Vec<Encoding>,
params: &PaddingParams,
) -> Result<Vec<Encoding>> {
if encodings.is_empty() {
return Ok(encodings);
}
let pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
};
encodings.par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,
params.pad_type_id,
&params.pad_token,
params.direction,
)
});
Ok(encodings)
}
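// A rough usage sketch for the padding API above (editorial, not part of this
// diff). `encodings: Vec<Encoding>` and the pad_id/pad_token values are
// placeholders here.
let params = PaddingParams {
    strategy: PaddingStrategy::BatchLongest,
    direction: PaddingDirection::Right,
    pad_id: 0,
    pad_type_id: 0,
    pad_token: String::from("[PAD]"),
};
let padded = pad_encodings(encodings, &params)?;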

View File

@ -1,5 +1,4 @@
use crate::tokenizer::{Encoding, PaddingDirection, Result};
use rayon::prelude::*;
use crate::tokenizer::{Encoding, Result};
#[derive(Debug, Clone)]
pub struct TruncationParams {
@ -8,21 +7,6 @@ pub struct TruncationParams {
pub stride: usize,
}
#[derive(Debug, Clone)]
pub struct PaddingParams {
pub strategy: PaddingStrategy,
pub direction: PaddingDirection,
pub pad_id: u32,
pub pad_type_id: u32,
pub pad_token: String,
}
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
BatchLongest,
Fixed(usize),
}
#[derive(Debug)]
pub enum Error {
SecondSequenceNotProvided,
@ -118,33 +102,3 @@ pub fn truncate_encodings(
Ok((encoding, pair_encoding))
}
pub fn pad_encodings(
mut encodings: Vec<Encoding>,
params: &PaddingParams,
) -> Result<Vec<Encoding>> {
if encodings.is_empty() {
return Ok(encodings);
}
let pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
};
encodings.par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,
params.pad_type_id,
&params.pad_token,
params.direction,
)
});
Ok(encodings)
}

154
tokenizers/tests/offsets.rs Normal file
View File

@ -0,0 +1,154 @@
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};
fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
let mut tokenizer = Tokenizer::new(Box::new(
BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
.build()
.expect("Files not found, run `make test` to download these files"),
));
tokenizer.with_pre_tokenizer(Box::new(
ByteLevel::default().add_prefix_space(add_prefix_space),
));
tokenizer.with_decoder(Box::new(ByteLevel::default()));
tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));
tokenizer
}
#[inline]
fn offset_as_range(offset: (usize, usize)) -> std::ops::Range<usize> {
offset.0..offset.1
}
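// Editorial note (not part of the diff): `Encoding::get_offsets` yields
// `(start, end)` tuples, and this helper just turns them into the std range
// that `get_range_of` expects.
assert_eq!(offset_as_range((3, 7)), 3..7);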
#[test]
fn byte_level_basic() {
// Without trimming offsets
let tokenizer = get_byte_level(true, false);
let input = String::from("Hello there, how are you?");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(
get_range_of(&input, offset_as_range(offsets[0])),
Some("Hello")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[1])),
Some(" there")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
assert_eq!(
get_range_of(&input, offset_as_range(offsets[3])),
Some(" how")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[4])),
Some(" are")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[5])),
Some(" you")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));
// And when trimming offsets:
let tokenizer = get_byte_level(true, true);
let input = String::from("Hello there, how are you?");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(
get_range_of(&input, offset_as_range(offsets[0])),
Some("Hello")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[1])),
Some("there")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
assert_eq!(
get_range_of(&input, offset_as_range(offsets[3])),
Some("how")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[4])),
Some("are")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[5])),
Some("you")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));
}
#[test]
fn byte_level_unicode() {
let tokenizer = get_byte_level(true, false);
let input = String::from("i⭢j");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(get_range_of(&input, offset_as_range(offsets[1])), Some("⭢"));
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some("⭢"));
assert_eq!(get_range_of(&input, offset_as_range(offsets[3])), Some("⭢"));
}
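// Editorial aside (not part of the diff): tokens 1 through 3 above all map back
// to the same character span because the arrow is one char but three UTF-8
// bytes, so byte-level BPE can split it across several tokens.
assert_eq!('⭢'.len_utf8(), 3);
assert_eq!("i⭢j".chars().count(), 3);
assert_eq!("i⭢j".len(), 5); // byte length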
#[test]
fn byte_level_double_sequence() {
let input_a = String::from("My name is Anthony");
let input_b = String::from("What is my name?");
// Without trimming offsets
let tokenizer = get_byte_level(true, false);
let output = tokenizer
.encode(EncodeInput::Dual(input_a.clone(), input_b.clone()), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(
offsets,
&[
(0, 2),
(2, 7),
(7, 10),
(10, 18),
(0, 4),
(4, 7),
(7, 10),
(10, 15),
(15, 16)
]
);
// When trimming offsets
let tokenizer = get_byte_level(true, true);
let output = tokenizer
.encode(EncodeInput::Dual(input_a, input_b), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(
offsets,
&[
(0, 2),
(3, 7),
(8, 10),
(11, 18),
(0, 4),
(5, 7),
(8, 10),
(11, 15),
(15, 16)
]
);
}