Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Merge pull request #197 from huggingface/remove-normalized
Remove NormalizedString from Encoding
.github/workflows/rust.yml (vendored): 17 lines changed
@@ -60,7 +60,22 @@ jobs:
          args: --manifest-path ./tokenizers/Cargo.toml --all-targets --all-features -- -D warnings

      - name: Run Tests
        if: matrix.os != 'windows-latest'
        shell: bash
        working-directory: ./tokenizers
        run: make test

      # Skip integration tests for now on Windows
      - name: Run lib Tests on Windows
        if: matrix.os == 'windows-latest'
        uses: actions-rs/cargo@v1
        with:
          command: test
-         args: --verbose --manifest-path ./tokenizers/Cargo.toml
+         args: --verbose --manifest-path ./tokenizers/Cargo.toml --lib

      - name: Run doc Tests on Windows
        if: matrix.os == 'windows-latest'
        uses: actions-rs/cargo@v1
        with:
          command: test
          args: --verbose --manifest-path ./tokenizers/Cargo.toml --doc
.gitignore (vendored): 3 lines changed
@@ -7,6 +7,7 @@ target
Cargo.lock

/data
+tokenizers/data
/docs

__pycache__
@@ -16,4 +17,4 @@ pip-wheel-metadata
/bindings/python/build
/bindings/python/dist

.vscode
.vscode
bindings/node/lib/bindings/raw-encoding.d.ts (vendored): 11 lines changed
@@ -44,17 +44,6 @@ export interface RawEncoding {
   */
  getTypeIds(): number[];

- /**
-  * Returns the original string
-  *
-  * @param [begin] The index from which to start (can be negative).
-  * @param [end] The index (excluded) to which to stop (can be negative).
-  * Stopping at the end of the string if not provided.
-  * @returns The full original string if no parameter is provided,
-  * otherwise the original string between `begin` and `end`
-  */
- getOriginalString(begin?: number, end?: number): string;
-
  /**
   * Pad the current Encoding at the given length
   *
@@ -22,7 +22,6 @@ describe("RawEncoding", () => {
    expect(typeof encoding.getIds).toBe("function");
    expect(typeof encoding.getLength).toBe("function");
    expect(typeof encoding.getOffsets).toBe("function");
-   expect(typeof encoding.getOriginalString).toBe("function");
    expect(typeof encoding.getOverflowing).toBe("function");
    expect(typeof encoding.getSpecialTokensMask).toBe("function");
    expect(typeof encoding.getTokens).toBe("function");
@@ -31,109 +30,6 @@ describe("RawEncoding", () => {
    expect(typeof encoding.truncate).toBe("function");
  });

  describe("getOriginalString", () => {
    it("returns the full original string when no params", () => {
      const original = encoding.getOriginalString();
      expect(original).toEqual(originalString);
    });

    it("accepts `undefined` as first parameter", () => {
      const original = encoding.getOriginalString(undefined);
      expect(original).toEqual(originalString);
    });

    it("accepts `undefined` as second parameter", () => {
      const original = encoding.getOriginalString(0, undefined);
      expect(original).toEqual(originalString);
    });

    it("throws an error when `begin` is out of range", () => {
      expect(() => encoding.getOriginalString(1000)).toThrow();
    });

    it("returns the original string starting at the specified index", () => {
      const original = encoding.getOriginalString(3);
      expect(original).toEqual("name is john");
    });

    it("throws an error when `end` is out of range", () => {
      expect(() => encoding.getOriginalString(0, 1000)).toThrow();
    });

    it("returns the original string between the two specified indexes", () => {
      const original = encoding.getOriginalString(3, 7);
      expect(original).toEqual("name");
    });

    describe("with only a negative `begin`", () => {
      it("returns the original string counting from the end when in the range", () => {
        const original = encoding.getOriginalString(-4);
        expect(original).toEqual("john");
      });

      it("throws an error when out of range", () => {
        expect(() => encoding.getOriginalString(-1000)).toThrow();
      });
    });

    describe("with a positive `begin` and a negative `end`", () => {
      it("returns the original string when resulting range is valid", () => {
        const original = encoding.getOriginalString(3, -5);
        expect(original).toEqual("name is");
      });

      it("throws an error when resulting `end` index is lower than `begin`", () => {
        expect(() => encoding.getOriginalString(7, -10)).toThrow();
      });

      it("throws an error when `begin` is out of range", () => {
        expect(() => encoding.getOriginalString(1000, -10)).toThrow();
      });

      it("throws an error when resulting `end` index is out of range", () => {
        expect(() => encoding.getOriginalString(7, -1000)).toThrow();
      });
    });

    describe("with a negative `begin` and a positive `end`", () => {
      it("returns the original string when resulting range is valid", () => {
        const original = encoding.getOriginalString(-7, 10);
        expect(original).toEqual("is");
      });

      it("throws an error when resulting `begin` index is upper than `end`", () => {
        expect(() => encoding.getOriginalString(-3, 5)).toThrow();
      });

      it("throws an error when `end` is out of range", () => {
        expect(() => encoding.getOriginalString(-5, 1000)).toThrow();
      });

      it("throws an error when resulting `begin` index is out of range", () => {
        expect(() => encoding.getOriginalString(-1000, 10)).toThrow();
      });
    });

    describe("with negatives `begin` and `end`", () => {
      it("returns the original string when resulting range is valid", () => {
        const original = encoding.getOriginalString(-7, -5);
        expect(original).toEqual("is");
      });

      it("throws an error when resulting `end` index is lower than `begin`", () => {
        expect(() => encoding.getOriginalString(-5, -10)).toThrow();
      });

      it("throws an error when resulting `begin` index is out of range", () => {
        expect(() => encoding.getOriginalString(-1000, -10)).toThrow();
      });

      it("throws an error when resulting `end` index is out of range", () => {
        expect(() => encoding.getOriginalString(-10, -1000)).toThrow();
      });
    });
  });

  describe("truncate", () => {
    it("accepts `undefined` as second parameter", () => {
      expect(encoding.truncate(10, undefined)).toBeUndefined();
@@ -1,3 +1,4 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
+/* eslint-disable @typescript-eslint/no-empty-function */

 import { promisify } from "util";
@@ -112,7 +113,7 @@ describe("Tokenizer", () => {
        [2, 6],
        [6, 8],
        [8, 12],
-       [12, 16]
+       [0, 4]
      ]);
      expect(encoding.getOverflowing()).toEqual([]);
      expect(encoding.getSpecialTokensMask()).toEqual([0, 0, 0, 0, 0]);
bindings/node/lib/bindings/utils.d.ts (vendored, new file): 11 lines
@@ -0,0 +1,11 @@
/**
 * Returns a subpart of a string according to specified indexes, and respecting unicode characters
 *
 * @param text The text for which to return a subpart
 * @param [begin] The index from which to start (can be negative).
 * @param [end] The index (excluded) to which to stop (can be negative).
 * Stopping at the end of the string if not provided.
 * @returns The full string if no start/end indexes are provided,
 * otherwise the original string between `begin` (included) and `end` (excluded)
 */
export function slice(text: string, start?: number, end?: number): string;
bindings/node/lib/bindings/utils.js (new file): 5 lines
@@ -0,0 +1,5 @@
const native = require("./native");

module.exports = {
  slice: native.utils_slice
};
bindings/node/lib/bindings/utils.test.ts (new file): 107 lines
@@ -0,0 +1,107 @@
import { slice } from "./utils";

describe("slice", () => {
  const text = "My name is John 👋";
  const sliceText = slice.bind({}, text);

  it("returns the full text when no params", () => {
    const sliced = sliceText();
    expect(sliced).toEqual(text);
  });

  it("accepts `undefined` as second parameter", () => {
    const original = sliceText(undefined);
    expect(original).toEqual(text);
  });

  it("accepts `undefined` as third parameter", () => {
    const original = sliceText(0, undefined);
    expect(original).toEqual(text);
  });

  it("throws an error when `begin` is out of range", () => {
    expect(() => sliceText(1000)).toThrow();
  });

  it("returns slice starting at the specified index", () => {
    const original = sliceText(3);
    expect(original).toEqual("name is John 👋");
  });

  it("throws an error when `end` is out of range", () => {
    expect(() => sliceText(0, 1000)).toThrow();
  });

  it("returns the text between the two specified indexes", () => {
    const original = sliceText(3, 7);
    expect(original).toEqual("name");
  });

  describe("with only a negative `begin`", () => {
    it("returns the original string counting from the end when in the range", () => {
      const original = sliceText(-1);
      expect(original).toEqual("👋");
    });

    it("throws an error when out of range", () => {
      expect(() => sliceText(-1000)).toThrow();
    });
  });

  describe("with a positive `begin` and a negative `end`", () => {
    it("returns correct slice when resulting range is valid", () => {
      const original = sliceText(3, -7);
      expect(original).toEqual("name is");
    });

    it("throws an error when resulting `end` index is lower than `begin`", () => {
      expect(() => sliceText(7, -12)).toThrow();
    });

    it("throws an error when `begin` is out of range", () => {
      expect(() => sliceText(1000, -12)).toThrow();
    });

    it("throws an error when resulting `end` index is out of range", () => {
      expect(() => sliceText(7, -1000)).toThrow();
    });
  });

  describe("with a negative `begin` and a positive `end`", () => {
    it("returns correct slice when resulting range is valid", () => {
      const original = sliceText(-9, 10);
      expect(original).toEqual("is");
    });

    it("throws an error when resulting `begin` index is upper than `end`", () => {
      expect(() => sliceText(-3, 5)).toThrow();
    });

    it("throws an error when `end` is out of range", () => {
      expect(() => sliceText(-5, 1000)).toThrow();
    });

    it("throws an error when resulting `begin` index is out of range", () => {
      expect(() => sliceText(-1000, 10)).toThrow();
    });
  });

  describe("with negatives `begin` and `end`", () => {
    it("returns correct slice when resulting range is valid", () => {
      const original = sliceText(-9, -7);
      expect(original).toEqual("is");
    });

    it("throws an error when resulting `end` index is lower than `begin`", () => {
      expect(() => sliceText(-5, -10)).toThrow();
    });

    it("throws an error when resulting `begin` index is out of range", () => {
      expect(() => sliceText(-1000, -10)).toThrow();
    });

    it("throws an error when resulting `end` index is out of range", () => {
      expect(() => sliceText(-10, -1000)).toThrow();
    });
  });
});
@@ -5,7 +5,6 @@ export class Encoding {
  private _ids?: number[];
  private _length?: number;
  private _offsets?: [number, number][];
- private _originalString?: string;
  private _overflowing?: Encoding[];
  private _specialTokensMask?: number[];
  private _tokens?: string[];
@@ -103,27 +102,6 @@ export class Encoding {
    return (this._typeIds = this.rawEncoding.getTypeIds());
  }

- /**
-  * Returns the original string
-  *
-  * @param [begin] The index from which to start (can be negative).
-  * @param [end] The index (excluded) to which to stop (can be negative).
-  * Stopping at the end of the string if not provided.
-  * @returns The full original string if no parameter is provided,
-  * otherwise the original string between `begin` and `end`
-  */
- getOriginalString(begin?: number, end?: number): string {
-   if (begin === undefined && end === undefined) {
-     if (this._originalString !== undefined) {
-       return this._originalString;
-     } else {
-       return (this._originalString = this.rawEncoding.getOriginalString());
-     }
-   }
-
-   return this.rawEncoding.getOriginalString(begin, end);
- }

  /**
   * Pad the current Encoding at the given length
   *
@@ -153,7 +131,6 @@ export class Encoding {
      "_ids",
      "_length",
      "_offsets",
-     "_originalString",
      "_overflowing",
      "_specialTokensMask",
      "_tokens",
@@ -1,5 +1,6 @@
 // export * from "./bindings";
 export * from "./implementations/tokenizers";
 export * from "./bindings/enums";
+export { slice } from "./bindings/utils";
 export { PaddingOptions, TruncationOptions } from "./bindings/tokenizer";
 export { Encoding } from "./implementations/encoding";
bindings/node/native/src/container.rs (new file): 102 lines
@@ -0,0 +1,102 @@
/// A Container type
///
/// Provides an interface to allow transfer of ownership between Node and Rust.
/// It either contains a Box with full ownership of the content, or a pointer to the content.
///
/// The main goal here is to allow Node calling into Rust to initialize some objects. Later
/// these objects may need to be used by Rust who will expect to take ownership. Since Node
/// does not allow any sort of ownership transfer, it will keep a reference to this object
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
/// and just keep a pointer in the Node object.
pub enum Container<T: ?Sized> {
    Owned(Box<T>),
    Pointer(*mut T),
    Empty,
}

impl<T> Container<T>
where
    T: ?Sized,
{
    pub fn from_ref(reference: &Box<T>) -> Self {
        let content: *const T = &**reference;
        Container::Pointer(content as *mut _)
    }

    pub fn is_owned(&self) -> bool {
        if let Container::Owned(_) = &self {
            true
        } else {
            false
        }
    }

    /// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
    pub fn take(self) -> Option<Box<T>> {
        match self {
            Container::Owned(obj) => Some(obj),
            Container::Pointer(_) => None,
            Container::Empty => None,
        }
    }

    /// Replace an empty content by the new provided owned one, otherwise do nothing
    pub fn to_owned(&mut self, o: Box<T>) {
        if let Container::Empty = self {
            unsafe {
                let new_container = Container::Owned(o);
                std::ptr::write(self, new_container);
            }
        }
    }

    /// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
    pub fn to_pointer(&mut self) -> Option<Box<T>> {
        if let Container::Owned(_) = self {
            unsafe {
                let old_container = std::ptr::read(self);
                let ptr = Box::into_raw(old_container.take().unwrap());
                let new_container = Container::Pointer(ptr);
                std::ptr::write(self, new_container);

                Some(Box::from_raw(ptr))
            }
        } else {
            None
        }
    }

    pub fn execute<F, U>(&self, closure: F) -> U
    where
        F: FnOnce(Option<&Box<T>>) -> U,
    {
        match self {
            Container::Owned(val) => closure(Some(val)),
            Container::Pointer(ptr) => unsafe {
                let val = Box::from_raw(*ptr);
                let res = closure(Some(&val));
                // We call this to make sure we don't drop the Box
                Box::into_raw(val);
                res
            },
            Container::Empty => closure(None),
        }
    }

    pub fn execute_mut<F, U>(&mut self, closure: F) -> U
    where
        F: FnOnce(Option<&mut Box<T>>) -> U,
    {
        match self {
            Container::Owned(val) => closure(Some(val)),
            Container::Pointer(ptr) => unsafe {
                let mut val = Box::from_raw(*ptr);
                let res = closure(Some(&mut val));
                // We call this to make sure we don't drop the Box
                Box::into_raw(val);
                res
            },
            Container::Empty => closure(None),
        }
    }
}
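The ownership handshake above is easiest to see end to end. A hypothetical sketch follows (the `Model` placeholder type and the flow are illustrative, not part of this commit): Node starts out owning the value, Rust adopts it through `to_pointer`, and later calls from Node keep working through the stored pointer.

```rust
// Illustrative only: `Model` is a stand-in type.
struct Model;

fn ownership_handshake() {
    // The Node wrapper initially owns the value it created.
    let mut container: Container<Model> = Container::Owned(Box::new(Model));
    assert!(container.is_owned());

    // When the Rust side needs ownership (e.g. a Tokenizer adopting a
    // Model), the wrapper hands out the Box and downgrades itself to a
    // raw pointer. The adopter must keep the Box alive for as long as
    // the pointer may be used.
    let adopted: Option<Box<Model>> = container.to_pointer();
    assert!(adopted.is_some());

    // Later calls from Node still reach the value through the pointer;
    // `execute` temporarily rebuilds the Box and leaks it again so the
    // value is never dropped here.
    container.execute(|model| assert!(model.is_some()));
}
```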
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Decoder
@@ -2,7 +2,7 @@ extern crate tokenizers as tk;

 use tk::tokenizer::PaddingDirection;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Encoding
@@ -159,62 +159,6 @@ declare_types! {
            Ok(js_overflowings.upcast())
        }

-       method getOriginalString(mut cx) {
-           // getOriginalString(begin?: number, end?: number)
-           let this = cx.this();
-
-           let len_original = {
-               let guard = cx.lock();
-               let len = this.borrow(&guard).encoding.execute(|encoding| {
-                   encoding.unwrap().get_normalized().len_original()
-               });
-               len
-           };
-
-           let get_index = |x: i32| -> usize {
-               if x >= 0 {
-                   x as usize
-               } else {
-                   (len_original as i32 + x) as usize
-               }
-           };
-
-           let begin_index = if let Some(begin_arg) = cx.argument_opt(0) {
-               if begin_arg.downcast::<JsUndefined>().is_err() {
-                   let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
-                   get_index(begin)
-               } else {
-                   0
-               }
-           } else {
-               0
-           };
-
-           let end_index = if let Some(end_arg) = cx.argument_opt(1) {
-               if end_arg.downcast::<JsUndefined>().is_err() {
-                   let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
-                   get_index(end)
-               } else {
-                   len_original
-               }
-           } else {
-               len_original
-           };
-
-           let original = {
-               let guard = cx.lock();
-               let original = this.borrow(&guard).encoding.execute(|encoding| {
-                   encoding.unwrap().get_normalized().get_range_original(begin_index..end_index)
-               });
-               original
-           };
-           if let Some(original) = original {
-               Ok(cx.string(original).upcast())
-           } else {
-               cx.throw_error("Error in offsets")
-           }
-       }

        method pad(mut cx) {
            // pad(length: number, options?: {
            //     direction?: 'left' | 'right' = 'right',
@@ -3,6 +3,7 @@
 extern crate neon;
 extern crate tokenizers as tk;

+mod container;
 mod decoders;
 mod encoding;
 mod models;
@@ -31,6 +32,8 @@ register_module!(mut m, {
     pre_tokenizers::register(&mut m, "pre_tokenizers")?;
     // Trainers
     trainers::register(&mut m, "trainers")?;
+    // Utils
+    utils::register(&mut m, "utils")?;

     Ok(())
 });
@@ -1,7 +1,7 @@
 extern crate tokenizers as tk;

+use crate::container::Container;
 use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
-use crate::utils::Container;
 use neon::prelude::*;
 use std::path::Path;

@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Normalizer
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// PreTokenizers
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Processor
@@ -1,5 +1,6 @@
 extern crate tokenizers as tk;

+use crate::container::Container;
 use crate::decoders::JsDecoder;
 use crate::models::JsModel;
 use crate::normalizers::JsNormalizer;
@@ -7,7 +8,6 @@ use crate::pre_tokenizers::JsPreTokenizer;
 use crate::processors::JsPostProcessor;
 use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
 use crate::trainers::JsTrainer;
-use crate::utils::Container;
 use neon::prelude::*;

 use tk::tokenizer::{
@@ -64,17 +64,43 @@ declare_types! {
        let mut with_added_tokens = true;
        if let Some(args) = cx.argument_opt(0) {
            if args.downcast::<JsUndefined>().is_err() {
-               with_added_tokens = args.downcast::<JsBoolean>().or_throw(&mut cx)?.value() as bool;
+               with_added_tokens = args.downcast::<JsBoolean>()
+                   .or_throw(&mut cx)?
+                   .value() as bool;
            }
        }

        let mut this = cx.this();
        let guard = cx.lock();
-       let size = this.borrow_mut(&guard).tokenizer.get_vocab_size(with_added_tokens);
+       let size = this.borrow_mut(&guard)
+           .tokenizer
+           .get_vocab_size(with_added_tokens);

        Ok(cx.number(size as f64).upcast())
    }

+   method normalize(mut cx) {
+       // normalize(sentence: String) -> String
+       let sentence = cx.argument::<JsString>(0)?.value();
+
+       let this = cx.this();
+       let guard = cx.lock();
+
+       let result = {
+           this.borrow(&guard)
+               .tokenizer
+               .normalize(&sentence)
+               .map(|s| s.get().to_owned())
+       };
+       let normalized = result
+           .map_err(|e| {
+               cx.throw_error::<_, ()>(format!("{}", e))
+                   .unwrap_err()
+           })?;
+
+       Ok(cx.string(normalized).upcast())
+   }
+
    method encode(mut cx) {
        // encode(
        //     sentence: String,
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;
 use std::collections::HashSet;

@@ -1,102 +1,51 @@
-/// A Container type
-///
-/// Provides an interface to allow transfer of ownership between Node and Rust.
-/// It either contains a Box with full ownership of the content, or a pointer to the content.
-///
-/// The main goal here is to allow Node calling into Rust to initialize some objects. Later
-/// these objects may need to be used by Rust who will expect to take ownership. Since Node
-/// does not allow any sort of ownership transfer, it will keep a reference to this object
-/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
-/// and just keep a pointer in the Node object.
-pub enum Container<T: ?Sized> {
-    Owned(Box<T>),
-    Pointer(*mut T),
-    Empty,
+extern crate tokenizers as tk;
+
+use neon::prelude::*;
+
+/// slice(s: string, start?: number, end?: number)
+fn slice(mut cx: FunctionContext) -> JsResult<JsString> {
+    let s = cx.argument::<JsString>(0)?.value();
+    let len = s.chars().count();
+
+    let get_index = |x: i32| -> usize {
+        if x >= 0 {
+            x as usize
+        } else {
+            (len as i32 + x) as usize
+        }
+    };
+
+    let begin_index = if let Some(begin_arg) = cx.argument_opt(1) {
+        if begin_arg.downcast::<JsUndefined>().is_err() {
+            let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
+            get_index(begin)
+        } else {
+            0
+        }
+    } else {
+        0
+    };
+
+    let end_index = if let Some(end_arg) = cx.argument_opt(2) {
+        if end_arg.downcast::<JsUndefined>().is_err() {
+            let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
+            get_index(end)
+        } else {
+            len
+        }
+    } else {
+        len
+    };
+
+    if let Some(slice) = tk::tokenizer::get_range_of(&s, begin_index..end_index) {
+        Ok(cx.string(slice))
+    } else {
+        cx.throw_error("Error in offsets")
+    }
+}

-impl<T> Container<T>
-where
-    T: ?Sized,
-{
-    pub fn from_ref(reference: &Box<T>) -> Self {
-        let content: *const T = &**reference;
-        Container::Pointer(content as *mut _)
-    }
-
-    pub fn is_owned(&self) -> bool {
-        if let Container::Owned(_) = &self {
-            true
-        } else {
-            false
-        }
-    }
-
-    /// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
-    pub fn take(self) -> Option<Box<T>> {
-        match self {
-            Container::Owned(obj) => Some(obj),
-            Container::Pointer(_) => None,
-            Container::Empty => None,
-        }
-    }
-
-    /// Replace an empty content by the new provided owned one, otherwise do nothing
-    pub fn to_owned(&mut self, o: Box<T>) {
-        if let Container::Empty = self {
-            unsafe {
-                let new_container = Container::Owned(o);
-                std::ptr::write(self, new_container);
-            }
-        }
-    }
-
-    /// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
-    pub fn to_pointer(&mut self) -> Option<Box<T>> {
-        if let Container::Owned(_) = self {
-            unsafe {
-                let old_container = std::ptr::read(self);
-                let ptr = Box::into_raw(old_container.take().unwrap());
-                let new_container = Container::Pointer(ptr);
-                std::ptr::write(self, new_container);
-
-                Some(Box::from_raw(ptr))
-            }
-        } else {
-            None
-        }
-    }
-
-    pub fn execute<F, U>(&self, closure: F) -> U
-    where
-        F: FnOnce(Option<&Box<T>>) -> U,
-    {
-        match self {
-            Container::Owned(val) => closure(Some(val)),
-            Container::Pointer(ptr) => unsafe {
-                let val = Box::from_raw(*ptr);
-                let res = closure(Some(&val));
-                // We call this to make sure we don't drop the Box
-                Box::into_raw(val);
-                res
-            },
-            Container::Empty => closure(None),
-        }
-    }
-
-    pub fn execute_mut<F, U>(&mut self, closure: F) -> U
-    where
-        F: FnOnce(Option<&mut Box<T>>) -> U,
-    {
-        match self {
-            Container::Owned(val) => closure(Some(val)),
-            Container::Pointer(ptr) => unsafe {
-                let mut val = Box::from_raw(*ptr);
-                let res = closure(Some(&mut val));
-                // We call this to make sure we don't drop the Box
-                Box::into_raw(val);
-                res
-            },
-            Container::Empty => closure(None),
-        }
-    }
+/// Register everything here
+pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
+    m.export_function(&format!("{}_slice", prefix), slice)?;
+    Ok(())
+}
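The index arithmetic in `slice` is worth isolating: negative indexes count characters (not bytes) from the end, like JavaScript's `String.prototype.slice`. A standalone sketch of that normalization step:

```rust
// Standalone sketch of the `get_index` closure used by `slice` above.
fn resolve_index(len: usize, x: i32) -> usize {
    if x >= 0 {
        x as usize
    } else {
        (len as i32 + x) as usize
    }
}

fn main() {
    let text = "My name is John 👋";
    // chars().count() counts unicode scalar values, not bytes: 17 here.
    let len = text.chars().count();
    assert_eq!(resolve_index(len, 3), 3);
    assert_eq!(resolve_index(len, -1), 16); // the wave emoji
}
```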
@@ -6,18 +6,28 @@ a high number of files as it avoids having too many progress bars on screen.
 - `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
   avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
   whitespaces are part of the actual token.
-  It has been added to `ByteLevelBPETokenizer` and but it is off by default (`trim_offsets=False`).
-- `encode` and `encode_batch` no take a new optional argument, specifying whether we should add the
-  special tokens. This stays activated by default.
+  It has been added to `ByteLevelBPETokenizer` but it is off by default (`trim_offsets=False`).
+  ([#188](https://github.com/huggingface/tokenizers/pull/188))
+- `encode` and `encode_batch` now take a new optional argument, specifying whether we should add the
+  special tokens. This is activated by default. ([#193](https://github.com/huggingface/tokenizers/pull/193))
+- `original_str` and `normalized_str` have been removed from the `Encoding` returned by `encode` and
+  `encode_batch`. This brings a reduction of 70% of the memory footprint.
+  ([#197](https://github.com/huggingface/tokenizers/pull/197))

 ## Fixes:
-- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
+- Fix some issues with the offsets being wrong with the `ByteLevel` BPE ([#193](https://github.com/huggingface/tokenizers/pull/193)):
   - when `add_prefix_space=True`
   - when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))

 ## How to migrate:
 - Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant. If you are
   using `ByteLevelBPETokenizer`, this option is disabled by default (`trim_offsets=False`).
 - Access to the `original_str` on the `Encoding` has been removed. The original string is the input
   of `encode` so it didn't make sense to keep it here.
 - No need to call `original_str.offsets(offsets[N])` to convert offsets to the original string. They
   are now relative to the original string by default.
 - Access to the `normalized_str` on the `Encoding` has been removed. It can be retrieved by calling
   `normalize(sequence)` on the `Tokenizer`.
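In the Rust core the replacement looks roughly like the sketch below (a minimal sketch, assuming a configured `Tokenizer`; the bindings expose the same call as `normalize`):

```rust
use tokenizers::tokenizer::{Result, Tokenizer};

// Migration sketch: the normalized text now comes from the Tokenizer
// itself instead of from `encoding.normalized_str`.
fn normalized_text(tokenizer: &Tokenizer, sequence: &str) -> Result<String> {
    let normalized = tokenizer.normalize(sequence)?; // -> NormalizedString
    Ok(normalized.get().to_owned()) // get() yields the normalized &str
}
```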
# v0.6.0
@@ -1,115 +1,11 @@
 extern crate tokenizers as tk;

 use crate::error::PyError;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use pyo3::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
+use pyo3::{PyObjectProtocol, PySequenceProtocol};
 use tk::tokenizer::PaddingDirection;

-fn get_range(item: PyObject, max_len: usize) -> PyResult<std::ops::Range<usize>> {
-    let gil = Python::acquire_gil();
-    let py = gil.python();
-
-    let slice = if let Ok(index) = item.extract::<isize>(py) {
-        if index >= max_len as isize || index < -(max_len as isize) {
-            Err(exceptions::IndexError::py_err("Index out of bounds"))
-        } else {
-            Ok(if index == -1 {
-                PySlice::new(py, index, max_len as isize, 1)
-            } else {
-                PySlice::new(py, index, index + 1, 1)
-            })
-        }
-    } else if let Ok(slice) = item.cast_as::<PySlice>(py) {
-        Ok(slice)
-    } else if let Ok(offset) = item.cast_as::<PyTuple>(py) {
-        if offset.len() == 2 {
-            let start = offset.get_item(0).extract::<isize>()?;
-            let end = offset.get_item(1).extract::<isize>()?;
-            Ok(PySlice::new(py, start, end, 1))
-        } else {
-            Err(exceptions::TypeError::py_err("Expected Tuple[int, int]"))
-        }
-    } else {
-        Err(exceptions::TypeError::py_err(
-            "Expected number or slice or Tuple[int, int]",
-        ))
-    }?;
-
-    // Find out range from the slice
-    let len: std::os::raw::c_long = (max_len as i32) as _;
-    let PySliceIndices { start, stop, .. } = slice.indices(len)?;
-
-    Ok(start as usize..stop as usize)
-}
-
-enum IndexableStringType {
-    Original,
-    Normalized,
-}
-
-#[pyclass(dict)]
-pub struct IndexableString {
-    s: tk::tokenizer::NormalizedString,
-    t: IndexableStringType,
-}
-#[pymethods]
-impl IndexableString {
-    fn offsets(&self, item: PyObject) -> PyResult<Option<(usize, usize)>> {
-        let range = get_range(item, self.s.len())?;
-
-        match self.t {
-            IndexableStringType::Original => Ok(self
-                .s
-                .get_original_offsets(range)
-                .map(|range| (range.start, range.end))),
-            IndexableStringType::Normalized => Ok(Some((range.start, range.end))),
-        }
-    }
-}
-
-#[pyproto]
-impl PyObjectProtocol for IndexableString {
-    fn __repr__(&self) -> PyResult<String> {
-        Ok(match self.t {
-            IndexableStringType::Original => self.s.get_original().to_owned(),
-            IndexableStringType::Normalized => self.s.get().to_owned(),
-        })
-    }
-
-    fn __str__(&self) -> PyResult<String> {
-        Ok(match self.t {
-            IndexableStringType::Original => self.s.get_original().to_owned(),
-            IndexableStringType::Normalized => self.s.get().to_owned(),
-        })
-    }
-}
-
-#[pyproto]
-impl PyMappingProtocol for IndexableString {
-    fn __getitem__(&self, item: PyObject) -> PyResult<String> {
-        // Find out the range
-        let range = get_range(item, self.s.len())?;
-
-        // Get the range from the relevant string
-        let s = match self.t {
-            IndexableStringType::Original => self.s.get_range_original(range),
-            IndexableStringType::Normalized => self.s.get_range(range),
-        };
-
-        s.map(|s| s.to_owned())
-            .ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
-    }
-
-    fn __len__(self) -> PyResult<usize> {
-        Ok(match self.t {
-            IndexableStringType::Original => self.s.len_original(),
-            IndexableStringType::Normalized => self.s.len(),
-        })
-    }
-}

 #[pyclass(dict)]
 #[repr(transparent)]
 pub struct Encoding {
@@ -127,7 +23,7 @@ impl PyObjectProtocol for Encoding {
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!(
            "Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
-            attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
+            attention_mask, special_tokens_mask, overflowing])",
            self.encoding.get_ids().len()
        ))
    }
@@ -142,50 +38,6 @@ impl PySequenceProtocol for Encoding {

 #[pymethods]
 impl Encoding {
-    #[getter]
-    fn get_normalized_str(&self) -> IndexableString {
-        IndexableString {
-            s: self.encoding.get_normalized().clone(),
-            t: IndexableStringType::Normalized,
-        }
-    }
-
-    #[getter]
-    fn get_original_str(&self) -> IndexableString {
-        IndexableString {
-            s: self.encoding.get_normalized().clone(),
-            t: IndexableStringType::Original,
-        }
-    }
-
-    #[args(kwargs = "**")]
-    fn get_range(
-        &self,
-        range: (usize, usize),
-        kwargs: Option<&PyDict>,
-    ) -> PyResult<Option<String>> {
-        let mut original = false;
-        if let Some(kwargs) = kwargs {
-            if let Some(koriginal) = kwargs.get_item("original") {
-                original = koriginal.extract()?;
-            }
-        }
-
-        if original {
-            Ok(self
-                .encoding
-                .get_normalized()
-                .get_range_original(range.0..range.1)
-                .map(|s| s.to_owned()))
-        } else {
-            Ok(self
-                .encoding
-                .get_normalized()
-                .get_range(range.0..range.1)
-                .map(|s| s.to_owned()))
-        }
-    }

    #[getter]
    fn get_ids(&self) -> Vec<u32> {
        self.encoding.get_ids().to_vec()
@@ -159,6 +159,15 @@ impl Tokenizer {
        self.tokenizer.with_padding(None);
    }

+   fn normalize(&self, sentence: &str) -> PyResult<String> {
+       ToPyResult(
+           self.tokenizer
+               .normalize(sentence)
+               .map(|s| s.get().to_owned()),
+       )
+       .into()
+   }
+
    #[args(add_special_tokens = true)]
    fn encode(
        &self,
@@ -16,32 +16,9 @@ from typing import Optional, Union, List, Tuple

 Offsets = Tuple[int, int]

-class IndexableString:
-    """
-    Works almost like a `str`, but allows indexing on offsets
-    provided on an `Encoding`
-    """
-
-    def offsets(self, offsets: Tuple[int, int]) -> Optional[Tuple[int, int]]:
-        """ Convert the Encoding's offsets to the current string.
-
-        `Encoding` provides a list of offsets that are actually offsets to the Normalized
-        version of text. Calling this method with the offsets provided by `Encoding` will make
-        sure that said offsets can be used to index the `str` directly.
-        """
-        pass
-
 class Encoding:
     """ An Encoding as returned by the Tokenizer """

-    @property
-    def normalized_str(self) -> IndexableString:
-        """ The normalized string """
-        pass
-    @property
-    def original_str(self) -> IndexableString:
-        """ The original string """
-        pass
     @property
     def ids(self) -> List[int]:
         """ The tokenized ids """
@@ -244,6 +221,17 @@ class Tokenizer:
     def no_padding(self):
         """ Disable padding """
         pass
+    def normalize(self, sequence: str) -> str:
+        """ Normalize the given sequence
+
+        Args:
+            sequence: str:
+                The sequence to normalize
+
+        Returns:
+            The normalized string
+        """
+        pass
     def encode(
         self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
     ) -> Encoding:
@@ -125,6 +125,18 @@ class BaseTokenizer:
        """
        return self._tokenizer.add_special_tokens(special_tokens)

+   def normalize(self, sequence: str) -> str:
+       """ Normalize the given sequence
+
+       Args:
+           sequence: str:
+               The sequence to normalize
+
+       Returns:
+           The normalized string
+       """
+       return self._tokenizer.normalize(sequence)
+
    def encode(
        self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
    ) -> Encoding:
@@ -6,9 +6,16 @@ a high number of files as it avoids having too many progress bars on screen.
 - Improve BPE and WordPiece builders.
 - `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
   avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
-  whitespaces are part of the actual token.
+  whitespaces are part of the actual token. ([#188](https://github.com/huggingface/tokenizers/pull/188))
 - `encode` and `encode_batch` now take a new argument, specifying whether we should add the
-  special tokens.
+  special tokens. ([#193](https://github.com/huggingface/tokenizers/pull/193))
+- The `NormalizedString` has been removed from the `Encoding`. It is now possible to retrieve it
+  by calling `normalize` on the `Tokenizer`. This brings a reduction of 70% of the memory footprint.
+  ([#197](https://github.com/huggingface/tokenizers/pull/197))
+- The `NormalizedString` API has been improved. It is now possible to retrieve parts of both strings
+  using either "normalized" or "original" offsets. ([#197](https://github.com/huggingface/tokenizers/pull/197))
+- The offsets provided on `Encoding` are now relative to the original string, and not the normalized
+  one anymore. ([#197](https://github.com/huggingface/tokenizers/pull/197))

 ## Fixes:
 - Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
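The improved `NormalizedString` API can be sketched as follows; the two calls are taken from the tests and code in this diff, and the exact surface may differ:

```rust
use tokenizers::tokenizer::{NormalizedString, Range};

// Sketch only: address the original string through *normalized*
// coordinates using the new `Range` wrapper.
fn demo(s: &NormalizedString) {
    // Slice of the original text covered by normalized offsets 1..4.
    let original_slice: Option<&str> = s.get_range_original(Range::Normalized(1..4));
    // The same offsets translated into original-string coordinates.
    let original_range = s.convert_offsets(Range::Normalized(1..4));
    let _ = (original_slice, original_range);
}
```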
@@ -1,5 +1,12 @@
 DATA_DIR = data
 BENCHMARK_DIR = benches
-BENCHMARK_RESOURCES = $(BENCHMARK_DIR)/gpt2-vocab.json $(BENCHMARK_DIR)/gpt2-merges.txt $(BENCHMARK_DIR)/big.txt
+TESTS_DIR = tests

 dir_guard=@mkdir -p $(@D)

+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt
+TESTS_RESOURCES = $(SHARED_RESOURCES)

 .PHONY : build
 build :
@@ -20,7 +27,7 @@ lint :
	cargo clippy --all-targets --all-features -- -D warnings

 .PHONY : test
-test :
+test : $(TESTS_RESOURCES)
	cargo test

 .PHONY : doc
@@ -38,8 +45,10 @@ all-checks : lint test doc
 bench : $(BENCHMARK_RESOURCES)
	cargo bench -- --verbose

-$(BENCHMARK_DIR)/gpt2-% :
+$(DATA_DIR)/gpt2-% :
	$(dir_guard)
	wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-$* -O $@

-$(BENCHMARK_DIR)/big.txt :
+$(DATA_DIR)/big.txt :
	$(dir_guard)
	wget https://norvig.com/big.txt -O $@
tokenizers/benches/.gitignore (vendored, deleted): 2 lines
@@ -1,2 +0,0 @@
-*.txt
-*.json
@@ -68,13 +68,13 @@ fn iter_bench_encode_batch(
 }

 fn bench_gpt2(c: &mut Criterion) {
-    let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
+    let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
        .build()
        .unwrap();
    let tokenizer = create_gpt2_tokenizer(bpe);
    let mut lines: Vec<EncodeInput> = vec![];
    let mut batches: Vec<Vec<EncodeInput>> = vec![vec![]];
-   for line in BufReader::new(File::open(Path::new("benches/big.txt")).unwrap())
+   for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap())
        .lines()
        .map(line_to_input)
    {
@@ -93,7 +93,7 @@ fn bench_gpt2(c: &mut Criterion) {
        b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches))
    });

-   let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
+   let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
        .cache_capacity(0)
        .build()
        .unwrap();
@@ -239,7 +239,7 @@ impl PostProcessor for ByteLevel {
            None => encoding,
            Some(mut pair) => {
                process_offsets(&mut pair);
-               encoding.merge_with(pair);
+               encoding.merge_with(pair, false);
                encoding
            }
        };
@@ -251,7 +251,9 @@ impl PostProcessor for ByteLevel {
 #[cfg(test)]
 mod tests {
    use super::ByteLevel;
-   use crate::tokenizer::{Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer};
+   use crate::tokenizer::{
+       Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer, Range,
+   };

    #[test]
    fn pre_tokenization() {
@@ -391,13 +393,12 @@ mod tests {
            ]
        );
        assert_eq!(input.get(), "iâŃ¢j");
-       assert_eq!(input.get_range_original(1..4), Some("⭢".into()));
+       assert_eq!(input.get_range_original(Range::Normalized(1..4)), Some("⭢"));
    }

    #[test]
    fn processor_trims_offsets() {
        let start = Encoding::new(
-           NormalizedString::from(""),
            vec![],
            vec![],
            vec![
@@ -412,7 +413,6 @@ mod tests {
            vec![],
        );
        let expected = Encoding::new(
-           NormalizedString::from(""),
            vec![],
            vec![],
            vec![
@@ -434,7 +434,7 @@ mod tests {
        );

        let mut pair_expected = expected.clone();
-       pair_expected.merge_with(expected);
+       pair_expected.merge_with(expected, false);
        assert_eq!(
            pair_expected,
            bytelevel
@@ -43,7 +43,6 @@ impl PostProcessor for BertProcessing {
        let attention_mask = vec![1; ids.len()];

        let mut new_encoding = Encoding::new(
-           encoding.get_normalized().clone(),
            ids,
            type_ids,
            tokens,
@@ -68,7 +67,6 @@ impl PostProcessor for BertProcessing {
        let attention_mask = vec![1; ids.len()];

        Encoding::new(
-           encoding.get_normalized().clone(),
            ids,
            type_ids,
            tokens,
@@ -91,7 +89,6 @@ impl PostProcessor for BertProcessing {
        let pair_attention_mask = vec![1; pair_ids.len()];

        let new_pair_encoding = Encoding::new(
-           encoding.get_normalized().clone(),
            pair_ids,
            pair_type_ids,
            pair_tokens,
@@ -112,7 +109,6 @@ impl PostProcessor for BertProcessing {
        let pair_attention_mask = vec![1; pair_ids.len()];

        Encoding::new(
-           encoding.get_normalized().clone(),
            pair_ids,
            pair_type_ids,
            pair_tokens,
@@ -125,7 +121,7 @@ impl PostProcessor for BertProcessing {
            .collect(),
        );

-       new_encoding.merge_with(new_pair_encoding);
+       new_encoding.merge_with(new_pair_encoding, false);
    }

    Ok(new_encoding)
@@ -43,7 +43,6 @@ impl PostProcessor for RobertaProcessing {
        let attention_mask = vec![1; ids.len()];

        let mut new_encoding = Encoding::new(
-           encoding.get_normalized().clone(),
            ids,
            type_ids,
            tokens,
@@ -68,7 +67,6 @@ impl PostProcessor for RobertaProcessing {
        let pair_attention_mask = vec![1; pair_ids.len()];

        let new_pair_encoding = Encoding::new(
-           encoding.get_normalized().clone(),
            pair_ids,
            pair_type_ids,
            pair_tokens,
@@ -78,7 +76,7 @@ impl PostProcessor for RobertaProcessing {
            encoding.take_overflowing(),
        );

-       new_encoding.merge_with(new_pair_encoding);
+       new_encoding.merge_with(new_pair_encoding, false);
    }

    Ok(new_encoding)
@@ -1,26 +1,9 @@
-use crate::tokenizer::NormalizedString;
+use crate::utils::padding::PaddingDirection;
 use rayon::prelude::*;

-/// The various possible padding directions.
-#[derive(Debug, Clone, Copy)]
-pub enum PaddingDirection {
-    Left,
-    Right,
-}
-
-impl std::convert::AsRef<str> for PaddingDirection {
-    fn as_ref(&self) -> &str {
-        match self {
-            PaddingDirection::Left => "left",
-            PaddingDirection::Right => "right",
-        }
-    }
-}
-
 /// Represents the output of a `Tokenizer`.
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
-    normalized: NormalizedString,
     ids: Vec<u32>,
     type_ids: Vec<u32>,
     tokens: Vec<String>,
@@ -32,7 +15,6 @@ pub struct Encoding {
 impl Encoding {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
-       normalized: NormalizedString,
        ids: Vec<u32>,
        type_ids: Vec<u32>,
        tokens: Vec<String>,
@@ -42,7 +24,6 @@ impl Encoding {
        overflowing: Vec<Encoding>,
    ) -> Self {
        Encoding {
-           normalized,
            ids,
            type_ids,
            tokens,
@@ -53,10 +34,6 @@ impl Encoding {
        }
    }

-   pub fn get_normalized(&self) -> &NormalizedString {
-       &self.normalized
-   }
-
    pub fn get_tokens(&self) -> &[String] {
        &self.tokens[..]
    }
@@ -124,7 +101,6 @@ impl Encoding {
        }

        let o = Encoding {
-           normalized: self.normalized.clone(),
            ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
            type_ids: get_current_part(
                &prev_encoding.type_ids,
@@ -173,7 +149,7 @@ impl Encoding {
    }

    /// Merge ourself with the given `Encoding`. Happens in place.
-   pub fn merge_with(&mut self, pair: Encoding) {
+   pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
        // Handle merging the overflowing parts too: Combine them all
        // In most of the cases, we expect `pair.overflowing.len() == 0`
        let mut overflowings = vec![];
@@ -182,33 +158,33 @@ impl Encoding {
        for self_o in &self.overflowing {
            // 1. The pair itself
            let mut n_encoding = self_o.clone();
-           n_encoding.merge_with(pair.clone());
+           n_encoding.merge_with(pair.clone(), growing_offsets);
            overflowings.push(n_encoding);

            // 2. Its overflowings (this should rarely happen...)
            for other_o in &pair.overflowing {
                let mut n_encoding = self_o.clone();
-               n_encoding.merge_with(other_o.clone());
+               n_encoding.merge_with(other_o.clone(), growing_offsets);
                overflowings.push(n_encoding);
            }
        }
        // 2. Ourself with all the other overflowings (this should rarely happen too...)
        for other_o in &pair.overflowing {
            let mut n_encoding = self.clone();
-           n_encoding.merge_with(other_o.clone());
+           n_encoding.merge_with(other_o.clone(), growing_offsets);
            overflowings.push(n_encoding);
        }

        // Finish by merging ourself with the other encoding
-       self.normalized.merge_with(&pair.normalized);
        self.ids.extend(pair.ids);
        self.type_ids.extend(pair.type_ids);
        self.tokens.extend(pair.tokens);

-       let starting_offset = self
-           .offsets
-           .iter()
-           .fold(0, |max, (_, end)| if *end > max { *end } else { max });
+       let starting_offset = if growing_offsets {
+           self.offsets.last().map_or(0, |o| o.1)
+       } else {
+           0
+       };
        self.offsets.extend(
            pair.offsets
                .into_iter()
@@ -304,7 +280,6 @@ mod tests {
    #[test]
    fn merge_encodings() {
        let mut a = Encoding {
-           normalized: NormalizedString::from("Hello "),
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
@@ -314,7 +289,6 @@ mod tests {
            overflowing: vec![],
        };
        let b = Encoding {
-           normalized: NormalizedString::from("World!"),
            ids: vec![2],
            type_ids: vec![1],
            tokens: vec![String::from("World!")],
@@ -323,12 +297,11 @@ mod tests {
            attention_mask: vec![1],
            overflowing: vec![],
        };
-       a.merge_with(b);
+       a.merge_with(b, true);

        assert_eq!(
            a,
            Encoding {
-               normalized: NormalizedString::from("Hello World!"),
                ids: vec![1, 2],
                type_ids: vec![0, 1],
                tokens: vec![String::from("Hello "), String::from("World!")],
@@ -343,7 +316,6 @@ mod tests {
    #[test]
    fn truncate() {
        let mut a = Encoding {
-           normalized: NormalizedString::from("Hello World!"),
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
@@ -361,7 +333,6 @@ mod tests {
        assert_eq!(
            a,
            Encoding {
-               normalized: NormalizedString::from("Hello World!"),
                ids: vec![1, 2],
                type_ids: vec![0, 0],
                tokens: vec![String::from("Hello"), String::from("World")],
@@ -369,7 +340,6 @@ mod tests {
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                overflowing: vec![Encoding {
-                   normalized: NormalizedString::from("Hello World!"),
                    ids: vec![3],
                    type_ids: vec![0],
                    tokens: vec![String::from("!")],
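The new `growing_offsets` flag decides whether the pair's offsets are shifted when two encodings are merged. A minimal sketch, assuming `Encoding::new` is public with the post-#197 argument order used above:

```rust
use tokenizers::tokenizer::Encoding;

fn merge_demo() {
    let mut a = Encoding::new(
        vec![1], vec![0], vec!["Hello ".into()], vec![(0, 6)],
        vec![0], vec![1], vec![],
    );
    let b = Encoding::new(
        vec![2], vec![1], vec!["World!".into()], vec![(0, 6)],
        vec![0], vec![1], vec![],
    );

    // growing_offsets = true: b's offsets are shifted past a's last end,
    // producing one continuously indexed sequence.
    a.merge_with(b, true);
    assert_eq!(a.get_offsets(), &[(0, 6), (6, 12)]);
}
```

Post-processors pass `false` instead, since in a pair each sequence's offsets must stay relative to its own original string.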
@@ -9,10 +9,9 @@
 //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
 //!   ...).

-pub use crate::utils::{
-    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
-    TruncationStrategy,
-};
+use crate::utils::iter::ResultShunt;
+pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
+pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
@@ -31,6 +30,11 @@ pub use normalizer::*;
 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
 pub type Offsets = (usize, usize);

+/// Takes care of pre-processing strings.
+pub trait Normalizer {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
+}
+
 /// The `PreTokenizer` is in charge of doing the pre-segmentation step. It splits the given string
 /// in multiple substrings, keeping track of the offsets of said substrings from the
 /// `NormalizedString`. In some occasions, the `PreTokenizer` might need to modify the given
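As a quick illustration of the `Normalizer` trait above, a lowercasing normalizer might look like the sketch below (`NormalizedString::lowercase` is assumed from the crate's transformation helpers):

```rust
// Sketch of a custom Normalizer implementation.
struct Lowercase;

impl Normalizer for Lowercase {
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
        // NormalizedString tracks the original/normalized alignment itself.
        normalized.lowercase();
        Ok(())
    }
}
```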
@ -71,7 +75,7 @@ impl dyn PostProcessor {
|
||||
match pair_encoding {
|
||||
None => Ok(encoding),
|
||||
Some(pair) => {
|
||||
encoding.merge_with(pair);
|
||||
encoding.merge_with(pair, false);
|
||||
Ok(encoding)
|
||||
}
|
||||
}
|
||||
@ -290,99 +294,154 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_added_tokens(&self, is_pair: bool) -> usize {
|
||||
self.post_processor
|
||||
.as_ref()
|
||||
.map_or(0, |p| p.as_ref().added_tokens(is_pair))
|
||||
/// Normalize the given sentence and return the corresponding normalized string
|
||||
pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
|
||||
let mut normalized = self
|
||||
.split_on_added_tokens(sentence)
|
||||
.into_iter()
|
||||
.map(|(sentence, id)| -> Result<NormalizedString> {
|
||||
if id.is_some() {
|
||||
Ok(NormalizedString::from(&sentence))
|
||||
} else {
|
||||
let mut normalized = self.do_normalize(&sentence)?;
|
||||
let _ = self.pre_tokenize(&mut normalized)?;
|
||||
|
||||
Ok(normalized)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let others = normalized.split_off(1);
|
||||
let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
|
||||
for n in others {
|
||||
normalized.merge_with(&n);
|
||||
}
|
||||
|
||||
Ok(normalized)
|
||||
}

/// Encode the given sentence
pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let mut encodings = self
.split_on_added_tokens(&sentence)
.into_iter()
.map(|(sentence, id)| -> Result<Encoding> {
// If this is one of our added tokens, let's return an encoding directly
if let Some(id) = id {
return Ok(Encoding::new(
NormalizedString::from(&sentence),
vec![id],
vec![type_id],
vec![sentence.to_owned()],
vec![(0, sentence.len())],
vec![0],
vec![1],
vec![],
));
}
let generate_output =
move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let results = self.split_on_added_tokens(&sentence).into_iter().map(
|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
// If this is one of our added tokens, let's return an encoding directly
if let Some(id) = id {
return Ok((
Encoding::new(
vec![id],
vec![type_id],
vec![sentence.to_owned()],
vec![(0, sentence.len())],
vec![0],
vec![1],
vec![],
),
NormalizedString::from(&sentence),
));
}

// 1. Normalization
let mut normalized = self.normalize(&sentence)?;
// 1. Normalization
let mut normalized = self.do_normalize(&sentence)?;

// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&mut normalized)?;

// 3. Model
let output = self.model.tokenize(pre_tokenized)?;
let length = output.len();
// 3. Model
let output = self.model.tokenize(pre_tokenized)?;
let length = output.len();

let (ids, tokens, offsets) = output.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);
let (ids, tokens, offsets) = output.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);

Ok(Encoding::new(
normalized,
ids,
vec![type_id; length],
tokens,
offsets,
vec![0; length],
vec![1; length],
vec![],
))
})
.collect::<Result<Vec<Encoding>>>()?;
Ok((
Encoding::new(
ids,
vec![type_id; length],
tokens,
offsets,
vec![0; length],
vec![1; length],
vec![],
),
normalized,
))
},
);

if encodings.is_empty() {
return Ok(Encoding::default());
}
let (mut encodings, mut normalized) =
ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;

let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().next().unwrap();
if encodings.is_empty() {
return Ok((Encoding::default(), NormalizedString::from("")));
}

for encoding in others {
first.merge_with(encoding);
}
let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().next().unwrap();

Ok(first)
};
for encoding in others {
first.merge_with(encoding, true);
}

let others = normalized.split_off(1);
let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
for n in others {
normalized.merge_with(&n);
}

Ok((first, normalized))
};

let (sentence, pair) = match input {
EncodeInput::Single(s1) => (s1, None),
EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
};

let encoding = generate_output(sentence, 0)?;
let pair_encoding = match pair {
Some(pair) => Some(generate_output(pair, 1)?),
None => None,
let (encoding, normalized) = generate_output(sentence, 0)?;
let (pair_encoding, pair_normalized) = match pair {
Some(pair) => {
let (e, n) = generate_output(pair, 1)?;
(Some(e), Some(n))
}
None => (None, None),
};

// 4. Post processing
self.post_process(encoding, pair_encoding, add_special_tokens)
let mut output = self.post_process(encoding, pair_encoding, add_special_tokens)?;

// 5. Convert offsets back to original string
let mut current_offset = (0, 0);
let mut n_source = &normalized;
output
.get_offsets_mut()
.iter_mut()
.for_each(|(start, end)| {
if (*start, *end) < current_offset {
n_source = &pair_normalized.as_ref().unwrap_or(&normalized);
}
current_offset = (*start, *end);
let (s, e) = n_source
.convert_offsets(Range::Normalized(*start..*end))
.map_or((*start, *end), |range| (range.start, range.end));
*start = s;
*end = e;
});

Ok(output)
}
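
Since the offsets returned by `encode` now point back into the original input (step 5 above), callers can recover token spans directly from the raw string. A hedged sketch, assuming an already-built `tokenizer`:

use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};

fn print_token_spans(tokenizer: &Tokenizer) -> tokenizers::tokenizer::Result<()> {
    let input = String::from("Hello there!");
    let encoding = tokenizer.encode(EncodeInput::Single(input.clone()), false)?;
    // Offsets index chars of the original input, so the char-based helper
    // `get_range_of` resolves them even after normalization shifted things.
    for (token, (start, end)) in encoding.get_tokens().iter().zip(encoding.get_offsets()) {
        println!("{} => {:?}", token, get_range_of(&input, *start..*end));
    }
    Ok(())
}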

/// Encode all the sentences in parallel, using multiple threads
@ -471,7 +530,7 @@ impl Tokenizer {
match file.read_line(&mut buf)? {
0 => break,
b => {
let mut normalized = self.normalize(&buf)?;
let mut normalized = self.do_normalize(&buf)?;
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
trainer.process_tokens(
&mut words,
@ -522,7 +581,7 @@ impl Tokenizer {
}

/// Normalization logic, go through all normalizers
fn normalize(&self, sequence: &str) -> Result<NormalizedString> {
fn do_normalize(&self, sequence: &str) -> Result<NormalizedString> {
let mut normalized = NormalizedString::from(sequence);

if let Some(normalizer) = &self.normalizer {

@ -1,21 +1,63 @@
use super::Result;
use std::cmp::Ordering;
use std::ops::{Bound, RangeBounds};
use unicode_normalization_alignments::UnicodeNormalization;

/// Takes care of pre-processing strings.
pub trait Normalizer {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
/// Represents a Range usable by the NormalizedString to index its content.
/// A Range can use indices relative to either the `Original` or the `Normalized` string
#[derive(Debug, Clone, Copy)]
pub enum Range<T: RangeBounds<usize>> {
Original(T),
Normalized(T),
}

/// A normalized string takes care of keeping both versions of a `String`, and
/// provides necessary alignments to retrieve ranges of both strings.
impl<T> Range<T>
where
T: RangeBounds<usize>,
{
/// Unwrap the underlying range
fn unwrap(self) -> T {
match self {
Range::Original(r) => r,
Range::Normalized(r) => r,
}
}

/// Converts the current Range to a `std::ops::Range<usize>`. This requires the `max_len`
/// of the represented string (in chars, not bytes) in order to cover the case where the
/// original provided range was unbounded
fn into_full_range(self, max_len: usize) -> std::ops::Range<usize> {
let range = self.unwrap();

let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(i) => *i,
Bound::Excluded(i) => *i + 1,
};
let end = match range.end_bound() {
Bound::Unbounded => max_len,
Bound::Included(i) => *i + 1,
Bound::Excluded(i) => *i,
};

start..end
}
}
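
To make the bound resolution above concrete, here is a small illustration (it would only compile inside this module, since `into_full_range` is private):

// Unbounded ends fall back to 0 and `max_len`; inclusive ends become exclusive.
assert_eq!(Range::Original(..).into_full_range(10), 0..10);
assert_eq!(Range::Normalized(2..=5).into_full_range(10), 2..6);
assert_eq!(Range::Original(3..).into_full_range(10), 3..10);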

/// A `NormalizedString` takes care of processing an "original" string to modify it and obtain a
/// "normalized" string. It keeps both versions of the string, alignment information between both
/// and provides an interface to retrieve ranges of each string, using offsets from any of them.
///
/// It is possible to retrieve a part of the original string, by indexing it with offsets from the
/// normalized one, and the other way around too. It is also possible to convert offsets from one
/// referential to the other one easily.
#[derive(Default, Debug, Clone)]
pub struct NormalizedString {
/// The original version of the string, before any modification
original: String,
/// The normalized version of the string, after all modifications
normalized: String,
/// Mapping from normalized string to original one
/// (pos, changes) where pos is the position in the modified string, and changes an isize
/// representing the number of insertions or deletions
/// Mapping from normalized string to original one: (start, end) for each character of the
/// normalized string
alignments: Vec<(usize, usize)>,
}

@ -26,6 +68,7 @@ impl std::cmp::PartialEq for NormalizedString {
}

impl NormalizedString {
/// Create a NormalizedString from the given str
pub fn from(s: &str) -> Self {
NormalizedString {
original: s.to_owned(),
@ -44,60 +87,68 @@ impl NormalizedString {
&self.original
}

/// Return the range of the original string corresponding to the received range on the
/// normalized string. Returns None if out of bounds
pub fn get_original_offsets(
/// Convert the given offsets range from one referential to the other one:
/// `Original => Normalized` or `Normalized => Original`
pub fn convert_offsets<T: RangeBounds<usize>>(
&self,
range: std::ops::Range<usize>,
range: Range<T>,
) -> Option<std::ops::Range<usize>> {
self.alignments
.get(range)
.map(|alignments| {
if alignments.is_empty() {
None
} else {
let start = alignments[0].0;
let end = alignments[alignments.len() - 1].1;
Some(start..end)
}
})
.flatten()
}

fn get_range_of(&self, s: &str, range: std::ops::Range<usize>) -> Option<String> {
let len = s.chars().count();
if range.start >= len || range.end > len {
None
} else {
Some(
s.chars()
match range {
Range::Original(_) => {
let (mut start, mut end) = (0, 0);
let r = range.into_full_range(self.alignments.last().map_or(0, |(_, e)| *e));
println!("{:?}\t{:?}", r, self.alignments);
self.alignments
.iter()
.enumerate()
.skip(range.start)
.map(|(i, c)| {
if i >= range.start && i < range.end {
Some(c)
} else {
None
.take_while(|(_, alignment)| r.end >= alignment.1)
.for_each(|(i, alignment)| {
println!("{:?}", alignment);
if alignment.0 <= r.start {
start = i;
}
})
.fuse()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<String>(),
)
if alignment.1 <= r.end {
end = i + 1;
}
});
Some(start..end)
}
Range::Normalized(_) => self
.alignments
.get(range.into_full_range(self.alignments.len()))
.map(|alignments| {
if alignments.is_empty() {
None
} else {
let start = alignments[0].0;
let end = alignments[alignments.len() - 1].1;
Some(start..end)
}
})
.flatten(),
}
}
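
A hedged illustration of the two conversion directions (the values assume this exact input; a length-preserving transformation keeps both referentials aligned one-to-one):

let mut n = NormalizedString::from("HELLO  WORLD");
n.lowercase();
// `lowercase` maps chars one-to-one here, so the offsets survive unchanged:
assert_eq!(n.convert_offsets(Range::Original(7..12)), Some(7..12));
assert_eq!(n.get_range(Range::Original(7..12)), Some("world"));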

/// Return a range of the normalized string (indexing on char not bytes)
pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<String> {
self.get_range_of(&self.normalized, range)
pub fn get_range<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
match range {
Range::Original(_) => self
.convert_offsets(range)
.map(|r| get_range_of(&self.normalized, r))
.flatten(),
Range::Normalized(r) => get_range_of(&self.normalized, r),
}
}

/// Return a range of the original string, using a range from the normalized string
pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<String> {
self.get_original_offsets(range)
.map(|range| self.get_range_of(&self.original, range))
.flatten()
/// Return a range of the original string (indexing on char not bytes)
pub fn get_range_original<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
match range {
Range::Original(r) => get_range_of(&self.original, r),
Range::Normalized(_) => self
.convert_offsets(range)
.map(|r| get_range_of(&self.original, r))
.flatten(),
}
}
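
Both accessors index chars rather than bytes, so multi-byte text slices cleanly. A small sketch (hedged: the expected values assume this exact input and the stripping behavior shown in the tests below):

let mut n = NormalizedString::from("  Açaí  ");
n.strip();
// The normalized text is the stripped word; the original keeps its padding,
// yet the same normalized range resolves in either referential.
assert_eq!(n.get_range(Range::Normalized(0..4)), Some("Açaí"));
assert_eq!(n.get_range_original(Range::Normalized(0..4)), Some("Açaí"));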

/// Applies transformations to the current normalized version, updating the current
@ -115,19 +166,10 @@ impl NormalizedString {
/// them has a `change` of `1`, but more doesn't make any sense.
/// We treat any value above `1` as `1`.
pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I, initial_offset: usize) {
let mut offset = 0;
let mut remaining_offset = initial_offset;
let mut offset = -(initial_offset as isize);
let (ch, alignments): (Vec<_>, Vec<_>) = dest
.enumerate()
.map(|(index, (c, changes))| {
let changes = if remaining_offset != 0 {
let c = changes - remaining_offset as isize;
remaining_offset = 0;
c
} else {
changes
};

let uof = if offset < 0 {
-offset as usize
} else {
@ -149,24 +191,10 @@ impl NormalizedString {
}
// No changes required here
Ordering::Equal => self.alignments.get(idx).copied(),
// Some characters were removed, so we merge our range with the one from the
// removed characters as the new alignment
// Some characters were removed, nothing to change in alignments
Ordering::Less => {
let uch = -changes as usize;
offset += changes;
self.alignments.get(idx..=idx + uch).map(|alignments| {
let min = alignments
.iter()
.map(|(start, end)| usize::min(*start, *end))
.min()
.unwrap();
let max = alignments
.iter()
.map(|(start, end)| usize::max(*start, *end))
.max()
.unwrap();
(min, max)
})
self.alignments.get(idx).copied()
}
};

@ -421,6 +449,37 @@ impl NormalizedString {
}
}
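
The `(char, isize)` items fed to `transform` carry an alignment hint: a negative `change` marks how many extra chars of the original this output char consumed. A minimal sketch of that contract (hedged: `normalized` is a private field, so the assertion only compiles module-internally):

let mut n = NormalizedString::from("ae");
// One output char standing for two input chars: change = -1.
n.transform(std::iter::once(('ä', -1)), 0);
assert_eq!(&n.normalized, "ä");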

/// Returns a range of the given string slice, by indexing chars instead of bytes
pub fn get_range_of<T: RangeBounds<usize>>(s: &str, range: T) -> Option<&str> {
let len = s.chars().count();
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(i) => *i,
Bound::Excluded(i) => *i + 1,
};
let end = match range.end_bound() {
Bound::Unbounded => len,
Bound::Included(i) => *i + 1,
Bound::Excluded(i) => *i,
};

if start >= len || end > len || start >= end {
None
} else {
let start_b = s
.char_indices()
.map(|(i, _)| i)
.nth(start as usize)
.unwrap_or(0);
let end_b = s
.char_indices()
.map(|(i, _)| i)
.nth(end as usize)
.unwrap_or_else(|| s.len());
Some(&s[start_b..end_b])
}
}
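
Because the indexing is char-based, multi-byte characters do not shift the result, and empty or out-of-bounds ranges yield `None` instead of panicking:

// 'é' is two bytes but one char, so 1..3 still means the second and third chars.
assert_eq!(get_range_of("héllo", 1..3), Some("él"));
assert_eq!(get_range_of("héllo", 4..4), None);
assert_eq!(get_range_of("héllo", 10..12), None);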

#[cfg(test)]
mod tests {
use super::*;
@ -462,7 +521,7 @@ mod tests {
n.filter(|c| *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}

@ -472,18 +531,43 @@ mod tests {
n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}

#[test]
fn range_conversion() {
let mut n = NormalizedString::from("    __Hello__   ");
n.filter(|c| !c.is_whitespace()).lowercase();
let hello_n = n.convert_offsets(Range::Original(6..11));
assert_eq!(hello_n, Some(2..7));
assert_eq!(
n.get_range(Range::Normalized(hello_n.clone().unwrap())),
Some("hello")
);
assert_eq!(
n.get_range_original(Range::Normalized(hello_n.unwrap())),
Some("Hello")
);
assert_eq!(n.get_range(Range::Original(6..11)), Some("hello"));
assert_eq!(n.get_range_original(Range::Original(6..11)), Some("Hello"));
}

#[test]
fn original_range() {
let mut n = NormalizedString::from("Hello_______ World!");
n.filter(|c| *c != '_').lowercase();
let world_n = n.get_range(6..11).unwrap();
let world_o = n.get_range_original(6..11).unwrap();
let world_n = n.get_range(Range::Normalized(6..11)).unwrap();
let world_o = n.get_range_original(Range::Normalized(6..11)).unwrap();
assert_eq!(world_n, "world");
assert_eq!(world_o, "World");
let original_range = Range::Original(n.convert_offsets(Range::Normalized(6..11)).unwrap());
assert_eq!(n.get_range(original_range.clone()).unwrap(), "world");
assert_eq!(
n.get_range_original(original_range.clone()).unwrap(),
"World"
);
assert_eq!(original_range.into_full_range(n.len_original()), 13..18);
}

#[test]
@ -505,8 +589,8 @@ mod tests {

assert_eq!(&n.normalized, " Hello ");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some("Hello".into())
n.get_range_original(Range::Normalized(1..n.normalized.len() - 1)),
Some("Hello")
);
}

@ -514,10 +598,13 @@ mod tests {
fn remove_at_beginning() {
let mut n = NormalizedString::from(" Hello");
n.filter(|c| !c.is_whitespace());
assert_eq!(n.get_range_original(1.."Hello".len()), Some("ello".into()));
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" Hello".into())
n.get_range_original(Range::Normalized(1.."Hello".len())),
Some("ello")
);
assert_eq!(
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("Hello")
);
}

@ -525,10 +612,10 @@ mod tests {
fn remove_at_end() {
let mut n = NormalizedString::from("Hello ");
n.filter(|c| !c.is_whitespace());
assert_eq!(n.get_range_original(0..4), Some("Hell".into()));
assert_eq!(n.get_range_original(Range::Normalized(0..4)), Some("Hell"));
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some("Hello ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("Hello")
);
}

@ -539,10 +626,13 @@ mod tests {
assert_eq!(&n.normalized, "Hello");

assert_eq!(
n.get_range_original(0.."Hello".len()),
Some(" Hello ".into())
n.get_range_original(Range::Normalized(0.."Hello".len())),
Some("Hello")
);
assert_eq!(
n.get_range_original(Range::Normalized(1.."Hell".len())),
Some("ell")
);
assert_eq!(n.get_range_original(1.."Hell".len()), Some("ell".into()));
}

#[test]
@ -551,8 +641,8 @@ mod tests {
n.lstrip();
assert_eq!(&n.normalized, "This is an example ");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("This is an example ")
);
}

@ -562,8 +652,8 @@ mod tests {
n.rstrip();
assert_eq!(&n.normalized, " This is an example");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some(" This is an example")
);
}

@ -573,8 +663,8 @@ mod tests {
n.strip();
assert_eq!(&n.normalized, "This is an example");
assert_eq!(
n.get_range_original(0..n.normalized.len()),
Some(" This is an example ".into())
n.get_range_original(Range::Normalized(0..n.normalized.len())),
Some("This is an example")
);
}

@ -597,7 +687,7 @@ mod tests {
(4, 5)
]
);
assert_eq!(n.get_original_offsets(0..4), Some(0..0));
assert_eq!(n.convert_offsets(Range::Normalized(0..4)), Some(0..0));
}

#[test]
@ -619,6 +709,16 @@ mod tests {
(3, 3)
]
);
assert_eq!(n.get_original_offsets(3.." there".len()), Some(3..3));
assert_eq!(
n.convert_offsets(Range::Normalized(3.." there".len())),
Some(3..3)
);
}

#[test]
fn get_range() {
let s = String::from("Hello my name is John 👋");
assert_eq!(get_range_of(&s, ..), Some(&s[..]));
assert_eq!(get_range_of(&s, 17..), Some("John 👋"));
}
58
tokenizers/src/utils/iter.rs
Normal file
@ -0,0 +1,58 @@
//! This comes from the Rust libcore and is duplicated here because it is not exported
//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
//! because the one from the libcore seems to cause overflowing stacks in some cases

pub struct ResultShunt<I, E> {
iter: I,
error: Option<E>,
}

impl<I, T, E> ResultShunt<I, E>
where
I: Iterator<Item = Result<T, E>>,
{
/// Process the given iterator as if it yielded a `T` instead of a
/// `Result<T, _>`. Any errors will stop the inner iterator and
/// the overall result will be an error.
pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
where
F: FnMut(&mut Self) -> U,
{
let mut shunt = ResultShunt::new(iter);
let value = f(shunt.by_ref());
shunt.reconstruct(value)
}

fn new(iter: I) -> Self {
ResultShunt { iter, error: None }
}

/// Consume the adapter and rebuild a `Result` value. This should
/// *always* be called, otherwise any potential error would be
/// lost.
fn reconstruct<U>(self, val: U) -> Result<U, E> {
match self.error {
None => Ok(val),
Some(e) => Err(e),
}
}
}

impl<I, T, E> Iterator for ResultShunt<I, E>
where
I: Iterator<Item = Result<T, E>>,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
match self.iter.next() {
Some(Ok(v)) => Some(v),
Some(Err(e)) => {
self.error = Some(e);
None
}
None => None,
}
}
}
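
A quick usage sketch of the adapter: the closure consumes plain `T`s while the first `Err` short-circuits the whole computation, which is how `encode` unzips its `(Encoding, NormalizedString)` pairs above:

let items: Vec<Result<u32, String>> = vec![Ok(1), Ok(2), Err("boom".into())];
let sum = ResultShunt::process(items.into_iter(), |iter| iter.sum::<u32>());
assert_eq!(sum, Err("boom".to_string()));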
3
tokenizers/src/utils/mod.rs
Normal file
@ -0,0 +1,3 @@
pub mod iter;
pub mod padding;
pub mod truncation;
63
tokenizers/src/utils/padding.rs
Normal file
@ -0,0 +1,63 @@
use crate::tokenizer::{Encoding, Result};
use rayon::prelude::*;

/// The various possible padding directions.
#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
Left,
Right,
}

impl std::convert::AsRef<str> for PaddingDirection {
fn as_ref(&self) -> &str {
match self {
PaddingDirection::Left => "left",
PaddingDirection::Right => "right",
}
}
}

#[derive(Debug, Clone)]
pub struct PaddingParams {
pub strategy: PaddingStrategy,
pub direction: PaddingDirection,
pub pad_id: u32,
pub pad_type_id: u32,
pub pad_token: String,
}

#[derive(Debug, Clone)]
pub enum PaddingStrategy {
BatchLongest,
Fixed(usize),
}

pub fn pad_encodings(
mut encodings: Vec<Encoding>,
params: &PaddingParams,
) -> Result<Vec<Encoding>> {
if encodings.is_empty() {
return Ok(encodings);
}

let pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
};

encodings.par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,
params.pad_type_id,
&params.pad_token,
params.direction,
)
});

Ok(encodings)
}
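
A hedged usage sketch: padding a batch to its longest member on the right with a hypothetical `[PAD]` token (`encodings` stands for a `Vec<Encoding>` produced by earlier `encode` calls):

let params = PaddingParams {
    strategy: PaddingStrategy::BatchLongest,
    direction: PaddingDirection::Right,
    pad_id: 0,
    pad_type_id: 0,
    pad_token: String::from("[PAD]"),
};
let padded = pad_encodings(encodings, &params).unwrap();
// Every encoding in the batch now has the same length.
assert!(padded.windows(2).all(|w| w[0].get_ids().len() == w[1].get_ids().len()));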
@ -1,5 +1,4 @@
use crate::tokenizer::{Encoding, PaddingDirection, Result};
use rayon::prelude::*;
use crate::tokenizer::{Encoding, Result};

#[derive(Debug, Clone)]
pub struct TruncationParams {
@ -8,21 +7,6 @@ pub struct TruncationParams {
pub stride: usize,
}

#[derive(Debug, Clone)]
pub struct PaddingParams {
pub strategy: PaddingStrategy,
pub direction: PaddingDirection,
pub pad_id: u32,
pub pad_type_id: u32,
pub pad_token: String,
}

#[derive(Debug, Clone)]
pub enum PaddingStrategy {
BatchLongest,
Fixed(usize),
}

#[derive(Debug)]
pub enum Error {
SecondSequenceNotProvided,
@ -118,33 +102,3 @@ pub fn truncate_encodings(

Ok((encoding, pair_encoding))
}

pub fn pad_encodings(
mut encodings: Vec<Encoding>,
params: &PaddingParams,
) -> Result<Vec<Encoding>> {
if encodings.is_empty() {
return Ok(encodings);
}

let pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
};

encodings.par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,
params.pad_type_id,
&params.pad_token,
params.direction,
)
});

Ok(encodings)
}
154
tokenizers/tests/offsets.rs
Normal file
@ -0,0 +1,154 @@
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};

fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
let mut tokenizer = Tokenizer::new(Box::new(
BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
.build()
.expect("Files not found, run `make test` to download these files"),
));
tokenizer.with_pre_tokenizer(Box::new(
ByteLevel::default().add_prefix_space(add_prefix_space),
));
tokenizer.with_decoder(Box::new(ByteLevel::default()));
tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));

tokenizer
}

#[inline]
fn offset_as_range(offset: (usize, usize)) -> std::ops::Range<usize> {
offset.0..offset.1
}

#[test]
fn byte_level_basic() {
// Without trimming offsets
let tokenizer = get_byte_level(true, false);

let input = String::from("Hello there, how are you?");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();

let offsets = output.get_offsets();
assert_eq!(
get_range_of(&input, offset_as_range(offsets[0])),
Some("Hello")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[1])),
Some(" there")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
assert_eq!(
get_range_of(&input, offset_as_range(offsets[3])),
Some(" how")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[4])),
Some(" are")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[5])),
Some(" you")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));

// And when trimming offsets:
let tokenizer = get_byte_level(true, true);

let input = String::from("Hello there, how are you?");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();

let offsets = output.get_offsets();
assert_eq!(
get_range_of(&input, offset_as_range(offsets[0])),
Some("Hello")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[1])),
Some("there")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
assert_eq!(
get_range_of(&input, offset_as_range(offsets[3])),
Some("how")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[4])),
Some("are")
);
assert_eq!(
get_range_of(&input, offset_as_range(offsets[5])),
Some("you")
);
assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));
}

#[test]
fn byte_level_unicode() {
let tokenizer = get_byte_level(true, false);

let input = String::from("i⭢j");
let output = tokenizer
.encode(EncodeInput::Single(input.clone()), false)
.unwrap();

let offsets = output.get_offsets();
assert_eq!(get_range_of(&input, offset_as_range(offsets[1])), Some("⭢"));
assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some("⭢"));
assert_eq!(get_range_of(&input, offset_as_range(offsets[3])), Some("⭢"));
}

#[test]
fn byte_level_double_sequence() {
let input_a = String::from("My name is Anthony");
let input_b = String::from("What is my name?");

// Without trimming offsets
let tokenizer = get_byte_level(true, false);
let output = tokenizer
.encode(EncodeInput::Dual(input_a.clone(), input_b.clone()), false)
.unwrap();

let offsets = output.get_offsets();
assert_eq!(
offsets,
&[
(0, 2),
(2, 7),
(7, 10),
(10, 18),
(0, 4),
(4, 7),
(7, 10),
(10, 15),
(15, 16)
]
);

// When trimming offsets
let tokenizer = get_byte_level(true, true);
let output = tokenizer
.encode(EncodeInput::Dual(input_a, input_b), false)
.unwrap();
let offsets = output.get_offsets();
assert_eq!(
offsets,
&[
(0, 2),
(3, 7),
(8, 10),
(11, 18),
(0, 4),
(5, 7),
(8, 10),
(11, 15),
(15, 16)
]
);
}