Mirror of https://github.com/mii443/tokenizers.git, synced 2025-09-01 23:09:34 +00:00
Merge pull request #197 from huggingface/remove-normalized
Remove NormalizedString from Encoding
.github/workflows/rust.yml (vendored): 17 changes
@@ -60,7 +60,22 @@ jobs:
           args: --manifest-path ./tokenizers/Cargo.toml --all-targets --all-features -- -D warnings

       - name: Run Tests
+        if: matrix.os != 'windows-latest'
+        shell: bash
+        working-directory: ./tokenizers
+        run: make test
+
+      # Skip integration tests for now on Windows
+      - name: Run lib Tests on Windows
+        if: matrix.os == 'windows-latest'
         uses: actions-rs/cargo@v1
         with:
           command: test
-          args: --verbose --manifest-path ./tokenizers/Cargo.toml
+          args: --verbose --manifest-path ./tokenizers/Cargo.toml --lib
+
+      - name: Run doc Tests on Windows
+        if: matrix.os == 'windows-latest'
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --verbose --manifest-path ./tokenizers/Cargo.toml --doc
.gitignore (vendored): 1 change
@@ -7,6 +7,7 @@ target
 Cargo.lock

 /data
+tokenizers/data
 /docs

 __pycache__
bindings/node/lib/bindings/raw-encoding.d.ts (vendored): 11 changes
@@ -44,17 +44,6 @@ export interface RawEncoding {
    */
   getTypeIds(): number[];

-  /**
-   * Returns the original string
-   *
-   * @param [begin] The index from which to start (can be negative).
-   * @param [end] The index (excluded) to which to stop (can be negative).
-   *              Stopping at the end of the string if not provided.
-   * @returns The full original string if no parameter is provided,
-   *          otherwise the original string between `begin` and `end`
-   */
-  getOriginalString(begin?: number, end?: number): string;
-
   /**
    * Pad the current Encoding at the given length
    *
@@ -22,7 +22,6 @@ describe("RawEncoding", () => {
     expect(typeof encoding.getIds).toBe("function");
     expect(typeof encoding.getLength).toBe("function");
     expect(typeof encoding.getOffsets).toBe("function");
-    expect(typeof encoding.getOriginalString).toBe("function");
     expect(typeof encoding.getOverflowing).toBe("function");
     expect(typeof encoding.getSpecialTokensMask).toBe("function");
     expect(typeof encoding.getTokens).toBe("function");
@ -31,109 +30,6 @@ describe("RawEncoding", () => {
|
|||||||
expect(typeof encoding.truncate).toBe("function");
|
expect(typeof encoding.truncate).toBe("function");
|
||||||
});
|
});
|
||||||
|
|
||||||
describe("getOriginalString", () => {
|
|
||||||
it("returns the full original string when no params", () => {
|
|
||||||
const original = encoding.getOriginalString();
|
|
||||||
expect(original).toEqual(originalString);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("accepts `undefined` as first parameter", () => {
|
|
||||||
const original = encoding.getOriginalString(undefined);
|
|
||||||
expect(original).toEqual(originalString);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("accepts `undefined` as second parameter", () => {
|
|
||||||
const original = encoding.getOriginalString(0, undefined);
|
|
||||||
expect(original).toEqual(originalString);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when `begin` is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(1000)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("returns the original string starting at the specified index", () => {
|
|
||||||
const original = encoding.getOriginalString(3);
|
|
||||||
expect(original).toEqual("name is john");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when `end` is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(0, 1000)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("returns the original string between the two specified indexes", () => {
|
|
||||||
const original = encoding.getOriginalString(3, 7);
|
|
||||||
expect(original).toEqual("name");
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("with only a negative `begin`", () => {
|
|
||||||
it("returns the original string counting from the end when in the range", () => {
|
|
||||||
const original = encoding.getOriginalString(-4);
|
|
||||||
expect(original).toEqual("john");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-1000)).toThrow();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("with a positive `begin` and a negative `end`", () => {
|
|
||||||
it("returns the original string when resulting range is valid", () => {
|
|
||||||
const original = encoding.getOriginalString(3, -5);
|
|
||||||
expect(original).toEqual("name is");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `end` index is lower than `begin`", () => {
|
|
||||||
expect(() => encoding.getOriginalString(7, -10)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when `begin` is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(1000, -10)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `end` index is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(7, -1000)).toThrow();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("with a negative `begin` and a positive `end`", () => {
|
|
||||||
it("returns the original string when resulting range is valid", () => {
|
|
||||||
const original = encoding.getOriginalString(-7, 10);
|
|
||||||
expect(original).toEqual("is");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `begin` index is upper than `end`", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-3, 5)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when `end` is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-5, 1000)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `begin` index is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-1000, 10)).toThrow();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("with negatives `begin` and `end`", () => {
|
|
||||||
it("returns the original string when resulting range is valid", () => {
|
|
||||||
const original = encoding.getOriginalString(-7, -5);
|
|
||||||
expect(original).toEqual("is");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `end` index is lower than `begin`", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-5, -10)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `begin` index is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-1000, -10)).toThrow();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("throws an error when resulting `end` index is out of range", () => {
|
|
||||||
expect(() => encoding.getOriginalString(-10, -1000)).toThrow();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("truncate", () => {
|
describe("truncate", () => {
|
||||||
it("accepts `undefined` as second parameter", () => {
|
it("accepts `undefined` as second parameter", () => {
|
||||||
expect(encoding.truncate(10, undefined)).toBeUndefined();
|
expect(encoding.truncate(10, undefined)).toBeUndefined();
|
||||||
@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 /* eslint-disable @typescript-eslint/no-empty-function */

 import { promisify } from "util";
@@ -112,7 +113,7 @@ describe("Tokenizer", () => {
         [2, 6],
         [6, 8],
         [8, 12],
-        [12, 16]
+        [0, 4]
       ]);
       expect(encoding.getOverflowing()).toEqual([]);
       expect(encoding.getSpecialTokensMask()).toEqual([0, 0, 0, 0, 0]);
bindings/node/lib/bindings/utils.d.ts (vendored, new file): 11 changes
@@ -0,0 +1,11 @@
+/**
+ * Returns a subpart of a string according to specified indexes, and respecting unicode characters
+ *
+ * @param text The text for which to return a subpart
+ * @param [begin] The index from which to start (can be negative).
+ * @param [end] The index (excluded) to which to stop (can be negative).
+ *              Stopping at the end of the string if not provided.
+ * @returns The full string if no start/end indexes are provided,
+ *          otherwise the original string between `begin` (included) and `end` (excluded)
+ */
+export function slice(text: string, start?: number, end?: number): string;
bindings/node/lib/bindings/utils.js (new file): 5 changes
@@ -0,0 +1,5 @@
+const native = require("./native");
+
+module.exports = {
+  slice: native.utils_slice
+};
bindings/node/lib/bindings/utils.test.ts (new file): 107 changes
@ -0,0 +1,107 @@
|
|||||||
|
import { slice } from "./utils";
|
||||||
|
|
||||||
|
describe("slice", () => {
|
||||||
|
const text = "My name is John 👋";
|
||||||
|
const sliceText = slice.bind({}, text);
|
||||||
|
|
||||||
|
it("returns the full text when no params", () => {
|
||||||
|
const sliced = sliceText();
|
||||||
|
expect(sliced).toEqual(text);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("accepts `undefined` as second parameter", () => {
|
||||||
|
const original = sliceText(undefined);
|
||||||
|
expect(original).toEqual(text);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("accepts `undefined` as third parameter", () => {
|
||||||
|
const original = sliceText(0, undefined);
|
||||||
|
expect(original).toEqual(text);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when `begin` is out of range", () => {
|
||||||
|
expect(() => sliceText(1000)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns slice starting at the specified index", () => {
|
||||||
|
const original = sliceText(3);
|
||||||
|
expect(original).toEqual("name is John 👋");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when `end` is out of range", () => {
|
||||||
|
expect(() => sliceText(0, 1000)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns the text between the two specified indexes", () => {
|
||||||
|
const original = sliceText(3, 7);
|
||||||
|
expect(original).toEqual("name");
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("with only a negative `begin`", () => {
|
||||||
|
it("returns the original string counting from the end when in the range", () => {
|
||||||
|
const original = sliceText(-1);
|
||||||
|
expect(original).toEqual("👋");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when out of range", () => {
|
||||||
|
expect(() => sliceText(-1000)).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("with a positive `begin` and a negative `end`", () => {
|
||||||
|
it("returns correct slice when resulting range is valid", () => {
|
||||||
|
const original = sliceText(3, -7);
|
||||||
|
expect(original).toEqual("name is");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `end` index is lower than `begin`", () => {
|
||||||
|
expect(() => sliceText(7, -12)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when `begin` is out of range", () => {
|
||||||
|
expect(() => sliceText(1000, -12)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `end` index is out of range", () => {
|
||||||
|
expect(() => sliceText(7, -1000)).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("with a negative `begin` and a positive `end`", () => {
|
||||||
|
it("returns correct slice when resulting range is valid", () => {
|
||||||
|
const original = sliceText(-9, 10);
|
||||||
|
expect(original).toEqual("is");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `begin` index is upper than `end`", () => {
|
||||||
|
expect(() => sliceText(-3, 5)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when `end` is out of range", () => {
|
||||||
|
expect(() => sliceText(-5, 1000)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `begin` index is out of range", () => {
|
||||||
|
expect(() => sliceText(-1000, 10)).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("with negatives `begin` and `end`", () => {
|
||||||
|
it("returns correct slice when resulting range is valid", () => {
|
||||||
|
const original = sliceText(-9, -7);
|
||||||
|
expect(original).toEqual("is");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `end` index is lower than `begin`", () => {
|
||||||
|
expect(() => sliceText(-5, -10)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `begin` index is out of range", () => {
|
||||||
|
expect(() => sliceText(-1000, -10)).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws an error when resulting `end` index is out of range", () => {
|
||||||
|
expect(() => sliceText(-10, -1000)).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
@@ -5,7 +5,6 @@ export class Encoding {
   private _ids?: number[];
   private _length?: number;
   private _offsets?: [number, number][];
-  private _originalString?: string;
   private _overflowing?: Encoding[];
   private _specialTokensMask?: number[];
   private _tokens?: string[];
@@ -103,27 +102,6 @@ export class Encoding {
     return (this._typeIds = this.rawEncoding.getTypeIds());
   }

-  /**
-   * Returns the original string
-   *
-   * @param [begin] The index from which to start (can be negative).
-   * @param [end] The index (excluded) to which to stop (can be negative).
-   *              Stopping at the end of the string if not provided.
-   * @returns The full original string if no parameter is provided,
-   *          otherwise the original string between `begin` and `end`
-   */
-  getOriginalString(begin?: number, end?: number): string {
-    if (begin === undefined && end === undefined) {
-      if (this._originalString !== undefined) {
-        return this._originalString;
-      } else {
-        return (this._originalString = this.rawEncoding.getOriginalString());
-      }
-    }
-
-    return this.rawEncoding.getOriginalString(begin, end);
-  }
-
   /**
    * Pad the current Encoding at the given length
    *
@@ -153,7 +131,6 @@ export class Encoding {
       "_ids",
       "_length",
       "_offsets",
-      "_originalString",
       "_overflowing",
       "_specialTokensMask",
       "_tokens",
@@ -1,5 +1,6 @@
 // export * from "./bindings";
 export * from "./implementations/tokenizers";
 export * from "./bindings/enums";
+export { slice } from "./bindings/utils";
 export { PaddingOptions, TruncationOptions } from "./bindings/tokenizer";
 export { Encoding } from "./implementations/encoding";
bindings/node/native/src/container.rs (new file): 102 changes
@ -0,0 +1,102 @@
|
|||||||
|
/// A Container type
|
||||||
|
///
|
||||||
|
/// Provides an interface to allow transfer of ownership between Node and Rust.
|
||||||
|
/// It either contains a Box with full ownership of the content, or a pointer to the content.
|
||||||
|
///
|
||||||
|
/// The main goal here is to allow Node calling into Rust to initialize some objects. Later
|
||||||
|
/// these objects may need to be used by Rust who will expect to take ownership. Since Node
|
||||||
|
/// does not allow any sort of ownership transfer, it will keep a reference to this object
|
||||||
|
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
|
||||||
|
/// and just keep a pointer in the Node object.
|
||||||
|
pub enum Container<T: ?Sized> {
|
||||||
|
Owned(Box<T>),
|
||||||
|
Pointer(*mut T),
|
||||||
|
Empty,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Container<T>
|
||||||
|
where
|
||||||
|
T: ?Sized,
|
||||||
|
{
|
||||||
|
pub fn from_ref(reference: &Box<T>) -> Self {
|
||||||
|
let content: *const T = &**reference;
|
||||||
|
Container::Pointer(content as *mut _)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_owned(&self) -> bool {
|
||||||
|
if let Container::Owned(_) = &self {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
|
||||||
|
pub fn take(self) -> Option<Box<T>> {
|
||||||
|
match self {
|
||||||
|
Container::Owned(obj) => Some(obj),
|
||||||
|
Container::Pointer(_) => None,
|
||||||
|
Container::Empty => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Replace an empty content by the new provided owned one, otherwise do nothing
|
||||||
|
pub fn to_owned(&mut self, o: Box<T>) {
|
||||||
|
if let Container::Empty = self {
|
||||||
|
unsafe {
|
||||||
|
let new_container = Container::Owned(o);
|
||||||
|
std::ptr::write(self, new_container);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
|
||||||
|
pub fn to_pointer(&mut self) -> Option<Box<T>> {
|
||||||
|
if let Container::Owned(_) = self {
|
||||||
|
unsafe {
|
||||||
|
let old_container = std::ptr::read(self);
|
||||||
|
let ptr = Box::into_raw(old_container.take().unwrap());
|
||||||
|
let new_container = Container::Pointer(ptr);
|
||||||
|
std::ptr::write(self, new_container);
|
||||||
|
|
||||||
|
Some(Box::from_raw(ptr))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute<F, U>(&self, closure: F) -> U
|
||||||
|
where
|
||||||
|
F: FnOnce(Option<&Box<T>>) -> U,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Container::Owned(val) => closure(Some(val)),
|
||||||
|
Container::Pointer(ptr) => unsafe {
|
||||||
|
let val = Box::from_raw(*ptr);
|
||||||
|
let res = closure(Some(&val));
|
||||||
|
// We call this to make sure we don't drop the Box
|
||||||
|
Box::into_raw(val);
|
||||||
|
res
|
||||||
|
},
|
||||||
|
Container::Empty => closure(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute_mut<F, U>(&mut self, closure: F) -> U
|
||||||
|
where
|
||||||
|
F: FnOnce(Option<&mut Box<T>>) -> U,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Container::Owned(val) => closure(Some(val)),
|
||||||
|
Container::Pointer(ptr) => unsafe {
|
||||||
|
let mut val = Box::from_raw(*ptr);
|
||||||
|
let res = closure(Some(&mut val));
|
||||||
|
// We call this to make sure we don't drop the Box
|
||||||
|
Box::into_raw(val);
|
||||||
|
res
|
||||||
|
},
|
||||||
|
Container::Empty => closure(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Decoder
@@ -2,7 +2,7 @@ extern crate tokenizers as tk;

 use tk::tokenizer::PaddingDirection;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Encoding
@ -159,62 +159,6 @@ declare_types! {
|
|||||||
Ok(js_overflowings.upcast())
|
Ok(js_overflowings.upcast())
|
||||||
}
|
}
|
||||||
|
|
||||||
method getOriginalString(mut cx) {
|
|
||||||
// getOriginalString(begin?: number, end?: number)
|
|
||||||
let this = cx.this();
|
|
||||||
|
|
||||||
let len_original = {
|
|
||||||
let guard = cx.lock();
|
|
||||||
let len = this.borrow(&guard).encoding.execute(|encoding| {
|
|
||||||
encoding.unwrap().get_normalized().len_original()
|
|
||||||
});
|
|
||||||
len
|
|
||||||
};
|
|
||||||
|
|
||||||
let get_index = |x: i32| -> usize {
|
|
||||||
if x >= 0 {
|
|
||||||
x as usize
|
|
||||||
} else {
|
|
||||||
(len_original as i32 + x) as usize
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let begin_index = if let Some(begin_arg) = cx.argument_opt(0) {
|
|
||||||
if begin_arg.downcast::<JsUndefined>().is_err() {
|
|
||||||
let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
|
|
||||||
get_index(begin)
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
let end_index = if let Some(end_arg) = cx.argument_opt(1) {
|
|
||||||
if end_arg.downcast::<JsUndefined>().is_err() {
|
|
||||||
let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
|
|
||||||
get_index(end)
|
|
||||||
} else {
|
|
||||||
len_original
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
len_original
|
|
||||||
};
|
|
||||||
|
|
||||||
let original = {
|
|
||||||
let guard = cx.lock();
|
|
||||||
let original = this.borrow(&guard).encoding.execute(|encoding| {
|
|
||||||
encoding.unwrap().get_normalized().get_range_original(begin_index..end_index)
|
|
||||||
});
|
|
||||||
original
|
|
||||||
};
|
|
||||||
if let Some(original) = original {
|
|
||||||
Ok(cx.string(original).upcast())
|
|
||||||
} else {
|
|
||||||
cx.throw_error("Error in offsets")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
method pad(mut cx) {
|
method pad(mut cx) {
|
||||||
// pad(length: number, options?: {
|
// pad(length: number, options?: {
|
||||||
// direction?: 'left' | 'right' = 'right',
|
// direction?: 'left' | 'right' = 'right',
|
||||||
|
@@ -3,6 +3,7 @@
 extern crate neon;
 extern crate tokenizers as tk;

+mod container;
 mod decoders;
 mod encoding;
 mod models;
@@ -31,6 +32,8 @@ register_module!(mut m, {
     pre_tokenizers::register(&mut m, "pre_tokenizers")?;
     // Trainers
     trainers::register(&mut m, "trainers")?;
+    // Utils
+    utils::register(&mut m, "utils")?;

     Ok(())
 });
@@ -1,7 +1,7 @@
 extern crate tokenizers as tk;

+use crate::container::Container;
 use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
-use crate::utils::Container;
 use neon::prelude::*;
 use std::path::Path;
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Normalizer
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// PreTokenizers
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;

 /// Processor
@@ -1,5 +1,6 @@
 extern crate tokenizers as tk;

+use crate::container::Container;
 use crate::decoders::JsDecoder;
 use crate::models::JsModel;
 use crate::normalizers::JsNormalizer;
@@ -7,7 +8,6 @@ use crate::pre_tokenizers::JsPreTokenizer;
 use crate::processors::JsPostProcessor;
 use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
 use crate::trainers::JsTrainer;
-use crate::utils::Container;
 use neon::prelude::*;

 use tk::tokenizer::{
@@ -64,17 +64,43 @@ declare_types! {
         let mut with_added_tokens = true;
         if let Some(args) = cx.argument_opt(0) {
             if args.downcast::<JsUndefined>().is_err() {
-                with_added_tokens = args.downcast::<JsBoolean>().or_throw(&mut cx)?.value() as bool;
+                with_added_tokens = args.downcast::<JsBoolean>()
+                    .or_throw(&mut cx)?
+                    .value() as bool;
             }
         }

         let mut this = cx.this();
         let guard = cx.lock();
-        let size = this.borrow_mut(&guard).tokenizer.get_vocab_size(with_added_tokens);
+        let size = this.borrow_mut(&guard)
+            .tokenizer
+            .get_vocab_size(with_added_tokens);

         Ok(cx.number(size as f64).upcast())
     }

+    method normalize(mut cx) {
+        // normalize(sentence: String) -> String
+        let sentence = cx.argument::<JsString>(0)?.value();
+
+        let this = cx.this();
+        let guard = cx.lock();
+
+        let result = {
+            this.borrow(&guard)
+                .tokenizer
+                .normalize(&sentence)
+                .map(|s| s.get().to_owned())
+        };
+        let normalized = result
+            .map_err(|e| {
+                cx.throw_error::<_, ()>(format!("{}", e))
+                    .unwrap_err()
+            })?;
+
+        Ok(cx.string(normalized).upcast())
+    }
+
     method encode(mut cx) {
         // encode(
         //     sentence: String,
@@ -1,6 +1,6 @@
 extern crate tokenizers as tk;

-use crate::utils::Container;
+use crate::container::Container;
 use neon::prelude::*;
 use std::collections::HashSet;
@ -1,102 +1,51 @@
|
|||||||
/// A Container type
|
extern crate tokenizers as tk;
|
||||||
///
|
|
||||||
/// Provides an interface to allow transfer of ownership between Node and Rust.
|
|
||||||
/// It either contains a Box with full ownership of the content, or a pointer to the content.
|
|
||||||
///
|
|
||||||
/// The main goal here is to allow Node calling into Rust to initialize some objects. Later
|
|
||||||
/// these objects may need to be used by Rust who will expect to take ownership. Since Node
|
|
||||||
/// does not allow any sort of ownership transfer, it will keep a reference to this object
|
|
||||||
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
|
|
||||||
/// and just keep a pointer in the Node object.
|
|
||||||
pub enum Container<T: ?Sized> {
|
|
||||||
Owned(Box<T>),
|
|
||||||
Pointer(*mut T),
|
|
||||||
Empty,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Container<T>
|
use neon::prelude::*;
|
||||||
where
|
|
||||||
T: ?Sized,
|
|
||||||
{
|
|
||||||
pub fn from_ref(reference: &Box<T>) -> Self {
|
|
||||||
let content: *const T = &**reference;
|
|
||||||
Container::Pointer(content as *mut _)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_owned(&self) -> bool {
|
/// slice(s: string, start?: number, end?: number)
|
||||||
if let Container::Owned(_) = &self {
|
fn slice(mut cx: FunctionContext) -> JsResult<JsString> {
|
||||||
true
|
let s = cx.argument::<JsString>(0)?.value();
|
||||||
|
let len = s.chars().count();
|
||||||
|
|
||||||
|
let get_index = |x: i32| -> usize {
|
||||||
|
if x >= 0 {
|
||||||
|
x as usize
|
||||||
} else {
|
} else {
|
||||||
false
|
(len as i32 + x) as usize
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
|
let begin_index = if let Some(begin_arg) = cx.argument_opt(1) {
|
||||||
pub fn take(self) -> Option<Box<T>> {
|
if begin_arg.downcast::<JsUndefined>().is_err() {
|
||||||
match self {
|
let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
|
||||||
Container::Owned(obj) => Some(obj),
|
get_index(begin)
|
||||||
Container::Pointer(_) => None,
|
} else {
|
||||||
Container::Empty => None,
|
0
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Replace an empty content by the new provided owned one, otherwise do nothing
|
|
||||||
pub fn to_owned(&mut self, o: Box<T>) {
|
|
||||||
if let Container::Empty = self {
|
|
||||||
unsafe {
|
|
||||||
let new_container = Container::Owned(o);
|
|
||||||
std::ptr::write(self, new_container);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
|
|
||||||
pub fn to_pointer(&mut self) -> Option<Box<T>> {
|
|
||||||
if let Container::Owned(_) = self {
|
|
||||||
unsafe {
|
|
||||||
let old_container = std::ptr::read(self);
|
|
||||||
let ptr = Box::into_raw(old_container.take().unwrap());
|
|
||||||
let new_container = Container::Pointer(ptr);
|
|
||||||
std::ptr::write(self, new_container);
|
|
||||||
|
|
||||||
Some(Box::from_raw(ptr))
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let end_index = if let Some(end_arg) = cx.argument_opt(2) {
|
||||||
|
if end_arg.downcast::<JsUndefined>().is_err() {
|
||||||
|
let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
|
||||||
|
get_index(end)
|
||||||
|
} else {
|
||||||
|
len
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
len
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(slice) = tk::tokenizer::get_range_of(&s, begin_index..end_index) {
|
||||||
|
Ok(cx.string(slice))
|
||||||
|
} else {
|
||||||
|
cx.throw_error("Error in offsets")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute<F, U>(&self, closure: F) -> U
|
/// Register everything here
|
||||||
where
|
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
||||||
F: FnOnce(Option<&Box<T>>) -> U,
|
m.export_function(&format!("{}_slice", prefix), slice)?;
|
||||||
{
|
Ok(())
|
||||||
match self {
|
|
||||||
Container::Owned(val) => closure(Some(val)),
|
|
||||||
Container::Pointer(ptr) => unsafe {
|
|
||||||
let val = Box::from_raw(*ptr);
|
|
||||||
let res = closure(Some(&val));
|
|
||||||
// We call this to make sure we don't drop the Box
|
|
||||||
Box::into_raw(val);
|
|
||||||
res
|
|
||||||
},
|
|
||||||
Container::Empty => closure(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn execute_mut<F, U>(&mut self, closure: F) -> U
|
|
||||||
where
|
|
||||||
F: FnOnce(Option<&mut Box<T>>) -> U,
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Container::Owned(val) => closure(Some(val)),
|
|
||||||
Container::Pointer(ptr) => unsafe {
|
|
||||||
let mut val = Box::from_raw(*ptr);
|
|
||||||
let res = closure(Some(&mut val));
|
|
||||||
// We call this to make sure we don't drop the Box
|
|
||||||
Box::into_raw(val);
|
|
||||||
res
|
|
||||||
},
|
|
||||||
Container::Empty => closure(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -6,18 +6,28 @@ a high number of files as it avoids having too many progress bars on screen.
 - `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
   avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
   whitespaces are part of the actual token.
-  It has been added to `ByteLevelBPETokenizer` and but it is off by default (`trim_offsets=False`).
-- `encode` and `encode_batch` no take a new optional argument, specifying whether we should add the
-  special tokens. This stays activated by default.
+  It has been added to `ByteLevelBPETokenizer` but it is off by default (`trim_offsets=False`).
+  ([#188](https://github.com/huggingface/tokenizers/pull/188))
+- `encode` and `encode_batch` now take a new optional argument, specifying whether we should add the
+  special tokens. This is activated by default. ([#193](https://github.com/huggingface/tokenizers/pull/193))
+- `original_str` and `normalized_str` have been removed from the `Encoding` returned by `encode` and
+  `encode_batch`. This brings a reduction of 70% of the memory footprint.
+  ([#197](https://github.com/huggingface/tokenizers/pull/197))

 ## Fixes:
-- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
+- Fix some issues with the offsets being wrong with the `ByteLevel` BPE ([#193](https://github.com/huggingface/tokenizers/pull/193)):
   - when `add_prefix_space=True`
   - when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))

 ## How to migrate:
 - Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant. If you are
   using `ByteLevelBPETokenizer`, this option is disabled by default (`trim_offsets=False`).
+- Access to the `original_str` on the `Encoding` has been removed. The original string is the input
+  of `encode` so it didn't make sense to keep it here.
+- No need to call `original_str.offsets(offsets[N])` to convert offsets to the original string. They
+  are now relative to the original string by default.
+- Access to the `normalized_str` on the `Encoding` has been removed. It can be retrieved by calling
+  `normalize(sequence)` on the `Tokenizer`.

 # v0.6.0
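To make these migration notes concrete, here is a minimal Python sketch of the new pattern. The tokenizer setup (a `ByteLevelBPETokenizer` built from local vocab/merges files) and the `tokens`/`offsets` attributes used below are assumptions for illustration; only `normalize` and the new offset behaviour come from this changelog.

```python
from tokenizers import ByteLevelBPETokenizer

# Assumed setup: any tokenizer works here, this one is only an example.
tokenizer = ByteLevelBPETokenizer("vocab.json", "merges.txt")

text = "My name is John"
encoding = tokenizer.encode(text)

# Offsets now point into the original input, so slice it directly instead of
# going through the removed `encoding.original_str`.
for token, (start, end) in zip(encoding.tokens, encoding.offsets):
    print(token, "->", text[start:end])

# The normalized string is no longer stored on the Encoding; ask the tokenizer.
normalized = tokenizer.normalize(text)
```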
@ -1,115 +1,11 @@
|
|||||||
extern crate tokenizers as tk;
|
extern crate tokenizers as tk;
|
||||||
|
|
||||||
use crate::error::PyError;
|
use crate::error::PyError;
|
||||||
use pyo3::exceptions;
|
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::*;
|
use pyo3::types::*;
|
||||||
use pyo3::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
|
use pyo3::{PyObjectProtocol, PySequenceProtocol};
|
||||||
use tk::tokenizer::PaddingDirection;
|
use tk::tokenizer::PaddingDirection;
|
||||||
|
|
||||||
fn get_range(item: PyObject, max_len: usize) -> PyResult<std::ops::Range<usize>> {
|
|
||||||
let gil = Python::acquire_gil();
|
|
||||||
let py = gil.python();
|
|
||||||
|
|
||||||
let slice = if let Ok(index) = item.extract::<isize>(py) {
|
|
||||||
if index >= max_len as isize || index < -(max_len as isize) {
|
|
||||||
Err(exceptions::IndexError::py_err("Index out of bounds"))
|
|
||||||
} else {
|
|
||||||
Ok(if index == -1 {
|
|
||||||
PySlice::new(py, index, max_len as isize, 1)
|
|
||||||
} else {
|
|
||||||
PySlice::new(py, index, index + 1, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else if let Ok(slice) = item.cast_as::<PySlice>(py) {
|
|
||||||
Ok(slice)
|
|
||||||
} else if let Ok(offset) = item.cast_as::<PyTuple>(py) {
|
|
||||||
if offset.len() == 2 {
|
|
||||||
let start = offset.get_item(0).extract::<isize>()?;
|
|
||||||
let end = offset.get_item(1).extract::<isize>()?;
|
|
||||||
Ok(PySlice::new(py, start, end, 1))
|
|
||||||
} else {
|
|
||||||
Err(exceptions::TypeError::py_err("Expected Tuple[int, int]"))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(exceptions::TypeError::py_err(
|
|
||||||
"Expected number or slice or Tuple[int, int]",
|
|
||||||
))
|
|
||||||
}?;
|
|
||||||
|
|
||||||
// Find out range from the slice
|
|
||||||
let len: std::os::raw::c_long = (max_len as i32) as _;
|
|
||||||
let PySliceIndices { start, stop, .. } = slice.indices(len)?;
|
|
||||||
|
|
||||||
Ok(start as usize..stop as usize)
|
|
||||||
}
|
|
||||||
|
|
||||||
enum IndexableStringType {
|
|
||||||
Original,
|
|
||||||
Normalized,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyclass(dict)]
|
|
||||||
pub struct IndexableString {
|
|
||||||
s: tk::tokenizer::NormalizedString,
|
|
||||||
t: IndexableStringType,
|
|
||||||
}
|
|
||||||
#[pymethods]
|
|
||||||
impl IndexableString {
|
|
||||||
fn offsets(&self, item: PyObject) -> PyResult<Option<(usize, usize)>> {
|
|
||||||
let range = get_range(item, self.s.len())?;
|
|
||||||
|
|
||||||
match self.t {
|
|
||||||
IndexableStringType::Original => Ok(self
|
|
||||||
.s
|
|
||||||
.get_original_offsets(range)
|
|
||||||
.map(|range| (range.start, range.end))),
|
|
||||||
IndexableStringType::Normalized => Ok(Some((range.start, range.end))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyproto]
|
|
||||||
impl PyObjectProtocol for IndexableString {
|
|
||||||
fn __repr__(&self) -> PyResult<String> {
|
|
||||||
Ok(match self.t {
|
|
||||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
|
||||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __str__(&self) -> PyResult<String> {
|
|
||||||
Ok(match self.t {
|
|
||||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
|
||||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyproto]
|
|
||||||
impl PyMappingProtocol for IndexableString {
|
|
||||||
fn __getitem__(&self, item: PyObject) -> PyResult<String> {
|
|
||||||
// Find out the range
|
|
||||||
let range = get_range(item, self.s.len())?;
|
|
||||||
|
|
||||||
// Get the range from the relevant string
|
|
||||||
let s = match self.t {
|
|
||||||
IndexableStringType::Original => self.s.get_range_original(range),
|
|
||||||
IndexableStringType::Normalized => self.s.get_range(range),
|
|
||||||
};
|
|
||||||
|
|
||||||
s.map(|s| s.to_owned())
|
|
||||||
.ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __len__(self) -> PyResult<usize> {
|
|
||||||
Ok(match self.t {
|
|
||||||
IndexableStringType::Original => self.s.len_original(),
|
|
||||||
IndexableStringType::Normalized => self.s.len(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyclass(dict)]
|
#[pyclass(dict)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct Encoding {
|
pub struct Encoding {
|
||||||
@ -127,7 +23,7 @@ impl PyObjectProtocol for Encoding {
|
|||||||
fn __repr__(&self) -> PyResult<String> {
|
fn __repr__(&self) -> PyResult<String> {
|
||||||
Ok(format!(
|
Ok(format!(
|
||||||
"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
|
"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
|
||||||
attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
|
attention_mask, special_tokens_mask, overflowing])",
|
||||||
self.encoding.get_ids().len()
|
self.encoding.get_ids().len()
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@ -142,50 +38,6 @@ impl PySequenceProtocol for Encoding {
|
|||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Encoding {
|
impl Encoding {
|
||||||
#[getter]
|
|
||||||
fn get_normalized_str(&self) -> IndexableString {
|
|
||||||
IndexableString {
|
|
||||||
s: self.encoding.get_normalized().clone(),
|
|
||||||
t: IndexableStringType::Normalized,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[getter]
|
|
||||||
fn get_original_str(&self) -> IndexableString {
|
|
||||||
IndexableString {
|
|
||||||
s: self.encoding.get_normalized().clone(),
|
|
||||||
t: IndexableStringType::Original,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[args(kwargs = "**")]
|
|
||||||
fn get_range(
|
|
||||||
&self,
|
|
||||||
range: (usize, usize),
|
|
||||||
kwargs: Option<&PyDict>,
|
|
||||||
) -> PyResult<Option<String>> {
|
|
||||||
let mut original = false;
|
|
||||||
if let Some(kwargs) = kwargs {
|
|
||||||
if let Some(koriginal) = kwargs.get_item("original") {
|
|
||||||
original = koriginal.extract()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if original {
|
|
||||||
Ok(self
|
|
||||||
.encoding
|
|
||||||
.get_normalized()
|
|
||||||
.get_range_original(range.0..range.1)
|
|
||||||
.map(|s| s.to_owned()))
|
|
||||||
} else {
|
|
||||||
Ok(self
|
|
||||||
.encoding
|
|
||||||
.get_normalized()
|
|
||||||
.get_range(range.0..range.1)
|
|
||||||
.map(|s| s.to_owned()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_ids(&self) -> Vec<u32> {
|
fn get_ids(&self) -> Vec<u32> {
|
||||||
self.encoding.get_ids().to_vec()
|
self.encoding.get_ids().to_vec()
|
||||||
|
@@ -159,6 +159,15 @@ impl Tokenizer {
         self.tokenizer.with_padding(None);
     }

+    fn normalize(&self, sentence: &str) -> PyResult<String> {
+        ToPyResult(
+            self.tokenizer
+                .normalize(sentence)
+                .map(|s| s.get().to_owned()),
+        )
+        .into()
+    }
+
     #[args(add_special_tokens = true)]
     fn encode(
         &self,
@@ -16,32 +16,9 @@ from typing import Optional, Union, List, Tuple

 Offsets = Tuple[int, int]

-class IndexableString:
-    """
-    Works almost like a `str`, but allows indexing on offsets
-    provided on an `Encoding`
-    """
-
-    def offsets(self, offsets: Tuple[int, int]) -> Optional[Tuple[int, int]]:
-        """ Convert the Encoding's offsets to the current string.
-
-        `Encoding` provides a list of offsets that are actually offsets to the Normalized
-        version of text. Calling this method with the offsets provided by `Encoding` will make
-        sure that said offsets can be used to index the `str` directly.
-        """
-        pass
-
 class Encoding:
     """ An Encoding as returned by the Tokenizer """

-    @property
-    def normalized_str(self) -> IndexableString:
-        """ The normalized string """
-        pass
-    @property
-    def original_str(self) -> IndexableString:
-        """ The original string """
-        pass
     @property
     def ids(self) -> List[int]:
         """ The tokenized ids """
@@ -244,6 +221,17 @@ class Tokenizer:
     def no_padding(self):
         """ Disable padding """
         pass
+    def normalize(self, sequence: str) -> str:
+        """ Normalize the given sequence
+
+        Args:
+            sequence: str:
+                The sequence to normalize
+
+        Returns:
+            The normalized string
+        """
+        pass
     def encode(
         self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
     ) -> Encoding:
@@ -125,6 +125,18 @@ class BaseTokenizer:
         """
         return self._tokenizer.add_special_tokens(special_tokens)

+    def normalize(self, sequence: str) -> str:
+        """ Normalize the given sequence
+
+        Args:
+            sequence: str:
+                The sequence to normalize
+
+        Returns:
+            The normalized string
+        """
+        return self._tokenizer.normalize(sequence)
+
     def encode(
         self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
     ) -> Encoding:
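A short usage sketch of the new `normalize` method added above; the tokenizer instance is assumed to exist and the result depends entirely on the normalizers it was built with.

```python
def preview_normalization(tokenizer, sequence: str) -> str:
    # `tokenizer` is any BaseTokenizer subclass created elsewhere (for example
    # a BertWordPieceTokenizer); this simply exposes the new call.
    return tokenizer.normalize(sequence)

# Illustrative only; the output depends on the tokenizer's normalizers:
# print(preview_normalization(tokenizer, "Héllò Y'ALL"))
```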
@@ -6,9 +6,16 @@ a high number of files as it avoids having too many progress bars on screen.
 - Improve BPE and WordPiece builders.
 - `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
   avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
-  whitespaces are part of the actual token.
+  whitespaces are part of the actual token. ([#188](https://github.com/huggingface/tokenizers/pull/188))
 - `encode` and `encode_batch` now take a new argument, specifying whether we should add the
-  special tokens.
+  special tokens. ([#193](https://github.com/huggingface/tokenizers/pull/193))
+- The `NormalizedString` has been removed from the `Encoding`. It is now possible to retrieve it
+  by calling `normalize` on the `Tokenizer`. This brings a reduction of 70% of the memory footprint.
+  ([#197](https://github.com/huggingface/tokenizers/pull/197))
+- The `NormalizedString` API has been improved. It is now possible to retrieve part of both strings
+  using both "normalized" or "original" offsets. ([#197](https://github.com/huggingface/tokenizers/pull/197))
+- The offsets provided on `Encoding` are now relative to the original string, and not the normalized
+  one anymore. ([#197](https://github.com/huggingface/tokenizers/pull/197))

 ## Fixes:
 - Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
@@ -1,5 +1,12 @@
+DATA_DIR = data
 BENCHMARK_DIR = benches
-BENCHMARK_RESOURCES = $(BENCHMARK_DIR)/gpt2-vocab.json $(BENCHMARK_DIR)/gpt2-merges.txt $(BENCHMARK_DIR)/big.txt
+TESTS_DIR = tests
+
+dir_guard=@mkdir -p $(@D)
+
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt
+TESTS_RESOURCES = $(SHARED_RESOURCES)

 .PHONY : build
 build :
@@ -20,7 +27,7 @@ lint :
	cargo clippy --all-targets --all-features -- -D warnings

 .PHONY : test
-test :
+test : $(TESTS_RESOURCES)
	cargo test

 .PHONY : doc
@@ -38,8 +45,10 @@ all-checks : lint test doc
 bench : $(BENCHMARK_RESOURCES)
	cargo bench -- --verbose

-$(BENCHMARK_DIR)/gpt2-% :
+$(DATA_DIR)/gpt2-% :
+	$(dir_guard)
	wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-$* -O $@

-$(BENCHMARK_DIR)/big.txt :
+$(DATA_DIR)/big.txt :
+	$(dir_guard)
	wget https://norvig.com/big.txt -O $@
tokenizers/benches/.gitignore (vendored): 2 changes
@@ -1,2 +0,0 @@
-*.txt
-*.json
@@ -68,13 +68,13 @@ fn iter_bench_encode_batch(
 }

 fn bench_gpt2(c: &mut Criterion) {
-    let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
+    let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
         .build()
         .unwrap();
     let tokenizer = create_gpt2_tokenizer(bpe);
     let mut lines: Vec<EncodeInput> = vec![];
     let mut batches: Vec<Vec<EncodeInput>> = vec![vec![]];
-    for line in BufReader::new(File::open(Path::new("benches/big.txt")).unwrap())
+    for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap())
         .lines()
         .map(line_to_input)
     {
@@ -93,7 +93,7 @@ fn bench_gpt2(c: &mut Criterion) {
         b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches))
     });

-    let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
+    let bpe = BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
        .cache_capacity(0)
        .build()
        .unwrap();
|
@ -239,7 +239,7 @@ impl PostProcessor for ByteLevel {
|
|||||||
None => encoding,
|
None => encoding,
|
||||||
Some(mut pair) => {
|
Some(mut pair) => {
|
||||||
process_offsets(&mut pair);
|
process_offsets(&mut pair);
|
||||||
encoding.merge_with(pair);
|
encoding.merge_with(pair, false);
|
||||||
encoding
|
encoding
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -251,7 +251,9 @@ impl PostProcessor for ByteLevel {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::ByteLevel;
|
use super::ByteLevel;
|
||||||
use crate::tokenizer::{Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer};
|
use crate::tokenizer::{
|
||||||
|
Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizer, Range,
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pre_tokenization() {
|
fn pre_tokenization() {
|
||||||
@ -391,13 +393,12 @@ mod tests {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
assert_eq!(input.get(), "iâŃ¢j");
|
assert_eq!(input.get(), "iâŃ¢j");
|
||||||
assert_eq!(input.get_range_original(1..4), Some("⭢".into()));
|
assert_eq!(input.get_range_original(Range::Normalized(1..4)), Some("⭢"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn processor_trims_offsets() {
|
fn processor_trims_offsets() {
|
||||||
let start = Encoding::new(
|
let start = Encoding::new(
|
||||||
NormalizedString::from(""),
|
|
||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
vec![
|
vec![
|
||||||
@ -412,7 +413,6 @@ mod tests {
|
|||||||
vec![],
|
vec![],
|
||||||
);
|
);
|
||||||
let expected = Encoding::new(
|
let expected = Encoding::new(
|
||||||
NormalizedString::from(""),
|
|
||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
vec![
|
vec![
|
||||||
@ -434,7 +434,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let mut pair_expected = expected.clone();
|
let mut pair_expected = expected.clone();
|
||||||
pair_expected.merge_with(expected);
|
pair_expected.merge_with(expected, false);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
pair_expected,
|
pair_expected,
|
||||||
bytelevel
|
bytelevel
|
||||||
|
@@ -43,7 +43,6 @@ impl PostProcessor for BertProcessing {
         let attention_mask = vec![1; ids.len()];

         let mut new_encoding = Encoding::new(
-            encoding.get_normalized().clone(),
             ids,
             type_ids,
             tokens,
@@ -68,7 +67,6 @@ impl PostProcessor for BertProcessing {
         let attention_mask = vec![1; ids.len()];

         Encoding::new(
-            encoding.get_normalized().clone(),
             ids,
             type_ids,
             tokens,
@@ -91,7 +89,6 @@ impl PostProcessor for BertProcessing {
         let pair_attention_mask = vec![1; pair_ids.len()];

         let new_pair_encoding = Encoding::new(
-            encoding.get_normalized().clone(),
             pair_ids,
             pair_type_ids,
             pair_tokens,
@@ -112,7 +109,6 @@ impl PostProcessor for BertProcessing {
         let pair_attention_mask = vec![1; pair_ids.len()];

         Encoding::new(
-            encoding.get_normalized().clone(),
             pair_ids,
             pair_type_ids,
             pair_tokens,
@@ -125,7 +121,7 @@ impl PostProcessor for BertProcessing {
                 .collect(),
             );

-            new_encoding.merge_with(new_pair_encoding);
+            new_encoding.merge_with(new_pair_encoding, false);
         }

         Ok(new_encoding)
@@ -43,7 +43,6 @@ impl PostProcessor for RobertaProcessing {
         let attention_mask = vec![1; ids.len()];

         let mut new_encoding = Encoding::new(
-            encoding.get_normalized().clone(),
             ids,
             type_ids,
             tokens,
@@ -68,7 +67,6 @@ impl PostProcessor for RobertaProcessing {
         let pair_attention_mask = vec![1; pair_ids.len()];

         let new_pair_encoding = Encoding::new(
-            encoding.get_normalized().clone(),
             pair_ids,
             pair_type_ids,
             pair_tokens,
@@ -78,7 +76,7 @@ impl PostProcessor for RobertaProcessing {
                 encoding.take_overflowing(),
             );

-            new_encoding.merge_with(new_pair_encoding);
+            new_encoding.merge_with(new_pair_encoding, false);
         }

         Ok(new_encoding)
@@ -1,26 +1,9 @@
-use crate::tokenizer::NormalizedString;
+use crate::utils::padding::PaddingDirection;
 use rayon::prelude::*;

-/// The various possible padding directions.
-#[derive(Debug, Clone, Copy)]
-pub enum PaddingDirection {
-    Left,
-    Right,
-}
-
-impl std::convert::AsRef<str> for PaddingDirection {
-    fn as_ref(&self) -> &str {
-        match self {
-            PaddingDirection::Left => "left",
-            PaddingDirection::Right => "right",
-        }
-    }
-}
-
 /// Represents the output of a `Tokenizer`.
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
-    normalized: NormalizedString,
     ids: Vec<u32>,
     type_ids: Vec<u32>,
     tokens: Vec<String>,
@@ -32,7 +15,6 @@ pub struct Encoding {
 impl Encoding {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
-        normalized: NormalizedString,
         ids: Vec<u32>,
         type_ids: Vec<u32>,
         tokens: Vec<String>,
@@ -42,7 +24,6 @@ impl Encoding {
         overflowing: Vec<Encoding>,
     ) -> Self {
         Encoding {
-            normalized,
             ids,
             type_ids,
             tokens,
@@ -53,10 +34,6 @@ impl Encoding {
         }
     }

-    pub fn get_normalized(&self) -> &NormalizedString {
-        &self.normalized
-    }
-
     pub fn get_tokens(&self) -> &[String] {
         &self.tokens[..]
     }
@@ -124,7 +101,6 @@ impl Encoding {
         }

         let o = Encoding {
-            normalized: self.normalized.clone(),
             ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
             type_ids: get_current_part(
                 &prev_encoding.type_ids,
@@ -173,7 +149,7 @@ impl Encoding {
     }

     /// Merge ourself with the given `Encoding`. Happens in place.
-    pub fn merge_with(&mut self, pair: Encoding) {
+    pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
         // Handle merging the overflowing parts too: Combine them all
         // In most of the cases, we expect `pair.overflowing.len() == 0`
         let mut overflowings = vec![];
@@ -182,33 +158,33 @@ impl Encoding {
         for self_o in &self.overflowing {
             // 1. The pair itself
             let mut n_encoding = self_o.clone();
-            n_encoding.merge_with(pair.clone());
+            n_encoding.merge_with(pair.clone(), growing_offsets);
             overflowings.push(n_encoding);

             // 2. Its overflowings (this should rarely happen...)
             for other_o in &pair.overflowing {
                 let mut n_encoding = self_o.clone();
-                n_encoding.merge_with(other_o.clone());
+                n_encoding.merge_with(other_o.clone(), growing_offsets);
                 overflowings.push(n_encoding);
             }
         }
         // 2. Ourself with all the other overflowings (this should rarely happen too...)
         for other_o in &pair.overflowing {
             let mut n_encoding = self.clone();
-            n_encoding.merge_with(other_o.clone());
+            n_encoding.merge_with(other_o.clone(), growing_offsets);
             overflowings.push(n_encoding);
         }

         // Finish by merging ourself with the other encoding
-        self.normalized.merge_with(&pair.normalized);
         self.ids.extend(pair.ids);
         self.type_ids.extend(pair.type_ids);
         self.tokens.extend(pair.tokens);

-        let starting_offset = self
-            .offsets
-            .iter()
-            .fold(0, |max, (_, end)| if *end > max { *end } else { max });
+        let starting_offset = if growing_offsets {
+            self.offsets.last().map_or(0, |o| o.1)
+        } else {
+            0
+        };
         self.offsets.extend(
             pair.offsets
                 .into_iter()
@@ -304,7 +280,6 @@ mod tests {
     #[test]
     fn merge_encodings() {
         let mut a = Encoding {
-            normalized: NormalizedString::from("Hello "),
             ids: vec![1],
             type_ids: vec![0],
             tokens: vec![String::from("Hello ")],
@@ -314,7 +289,6 @@ mod tests {
             overflowing: vec![],
         };
         let b = Encoding {
-            normalized: NormalizedString::from("World!"),
             ids: vec![2],
             type_ids: vec![1],
             tokens: vec![String::from("World!")],
@@ -323,12 +297,11 @@ mod tests {
             attention_mask: vec![1],
             overflowing: vec![],
         };
-        a.merge_with(b);
+        a.merge_with(b, true);

         assert_eq!(
             a,
             Encoding {
-                normalized: NormalizedString::from("Hello World!"),
                 ids: vec![1, 2],
                 type_ids: vec![0, 1],
                 tokens: vec![String::from("Hello "), String::from("World!")],
@@ -343,7 +316,6 @@ mod tests {
     #[test]
     fn truncate() {
         let mut a = Encoding {
-            normalized: NormalizedString::from("Hello World!"),
             ids: vec![1, 2, 3],
             type_ids: vec![0, 0, 0],
             tokens: vec![
@@ -361,7 +333,6 @@ mod tests {
         assert_eq!(
             a,
             Encoding {
-                normalized: NormalizedString::from("Hello World!"),
                 ids: vec![1, 2],
                 type_ids: vec![0, 0],
                 tokens: vec![String::from("Hello"), String::from("World")],
@@ -369,7 +340,6 @@ mod tests {
                 special_tokens_mask: vec![0, 0],
                 attention_mask: vec![1, 1],
                 overflowing: vec![Encoding {
-                    normalized: NormalizedString::from("Hello World!"),
                     ids: vec![3],
                     type_ids: vec![0],
                     tokens: vec![String::from("!")],
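The new `growing_offsets` flag only changes how the pair's offsets are appended. A minimal, self-contained sketch of that bookkeeping, with hypothetical offset values, and assuming (as in the previous behaviour) that the pair's offsets get shifted by the computed starting offset:

```rust
fn main() {
    // Offsets of the two encodings being merged (hypothetical values).
    let self_offsets: Vec<(usize, usize)> = vec![(0, 5), (6, 12)];
    let pair_offsets: Vec<(usize, usize)> = vec![(0, 6)];

    // growing_offsets == true: continue after the last offset of `self`.
    // growing_offsets == false: keep the pair's offsets untouched (sentence pairs).
    let growing_offsets = true;
    let starting_offset = if growing_offsets {
        self_offsets.last().map_or(0, |o| o.1)
    } else {
        0
    };

    // Assumed: the pair's offsets are shifted by `starting_offset` when extending.
    let merged: Vec<(usize, usize)> = self_offsets
        .iter()
        .copied()
        .chain(
            pair_offsets
                .iter()
                .map(|&(start, end)| (start + starting_offset, end + starting_offset)),
        )
        .collect();

    assert_eq!(merged, vec![(0, 5), (6, 12), (12, 18)]);
}
```

The `false` variant is what the post-processors above pass when concatenating a sentence pair, so the second sequence keeps offsets relative to its own string.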
@@ -9,10 +9,9 @@
 //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
 //! ...).

-pub use crate::utils::{
-    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
-    TruncationStrategy,
-};
+use crate::utils::iter::ResultShunt;
+pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
+pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
@@ -31,6 +30,11 @@ pub use normalizer::*;
 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
 pub type Offsets = (usize, usize);

+/// Takes care of pre-processing strings.
+pub trait Normalizer {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
+}
+
 /// The `PreTokenizer` is in charge of doing the pre-segmentation step. It splits the given string
 /// in multiple substrings, keeping track of the offsets of said substrings from the
 /// `NormalizedString`. In some occasions, the `PreTokenizer` might need to modify the given
@@ -71,7 +75,7 @@ impl dyn PostProcessor {
         match pair_encoding {
             None => Ok(encoding),
             Some(pair) => {
-                encoding.merge_with(pair);
+                encoding.merge_with(pair, false);
                 Ok(encoding)
             }
         }
@@ -290,25 +294,44 @@ impl Tokenizer {
         }
     }

-    pub fn num_added_tokens(&self, is_pair: bool) -> usize {
-        self.post_processor
-            .as_ref()
-            .map_or(0, |p| p.as_ref().added_tokens(is_pair))
+    /// Normalize the given sentence and return the corresponding normalized string
+    pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
+        let mut normalized = self
+            .split_on_added_tokens(sentence)
+            .into_iter()
+            .map(|(sentence, id)| -> Result<NormalizedString> {
+                if id.is_some() {
+                    Ok(NormalizedString::from(&sentence))
+                } else {
+                    let mut normalized = self.do_normalize(&sentence)?;
+                    let _ = self.pre_tokenize(&mut normalized)?;
+
+                    Ok(normalized)
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let others = normalized.split_off(1);
+        let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
+        for n in others {
+            normalized.merge_with(&n);
+        }
+
+        Ok(normalized)
     }

     /// Encode the given sentence
     pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
-        let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
+        let generate_output =
+            move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
                 // First we need to split into as many sequences as needed to avoid splitting
                 // on our added tokens
-                let mut encodings = self
-                    .split_on_added_tokens(&sentence)
-                    .into_iter()
-                    .map(|(sentence, id)| -> Result<Encoding> {
+                let results = self.split_on_added_tokens(&sentence).into_iter().map(
+                    |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                         // If this is one of our added tokens, lets return an encoding directly
                         if let Some(id) = id {
-                            return Ok(Encoding::new(
-                                NormalizedString::from(&sentence),
+                            return Ok((
+                                Encoding::new(
                                     vec![id],
                                     vec![type_id],
                                     vec![sentence.to_owned()],
@@ -316,11 +339,13 @@ impl Tokenizer {
                                     vec![0],
                                     vec![1],
                                     vec![],
+                                ),
+                                NormalizedString::from(&sentence),
                             ));
                         }

                         // 1. Normalization
-                        let mut normalized = self.normalize(&sentence)?;
+                        let mut normalized = self.do_normalize(&sentence)?;

                         // 2. Pre tokenization
                         let pre_tokenized = self.pre_tokenize(&mut normalized)?;
@@ -343,8 +368,8 @@ impl Tokenizer {
                             },
                         );

-                        Ok(Encoding::new(
-                            normalized,
+                        Ok((
+                            Encoding::new(
                                 ids,
                                 vec![type_id; length],
                                 tokens,
@@ -352,22 +377,33 @@ impl Tokenizer {
                                 vec![0; length],
                                 vec![1; length],
                                 vec![],
+                            ),
+                            normalized,
                         ))
-                    })
-                    .collect::<Result<Vec<Encoding>>>()?;
+                    },
+                );
+
+                let (mut encodings, mut normalized) =
+                    ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;

                 if encodings.is_empty() {
-                    return Ok(Encoding::default());
+                    return Ok((Encoding::default(), NormalizedString::from("")));
                 }

                 let others = encodings.split_off(1);
                 let mut first: Encoding = encodings.into_iter().next().unwrap();

                 for encoding in others {
-                    first.merge_with(encoding);
+                    first.merge_with(encoding, true);
                 }

-                Ok(first)
+                let others = normalized.split_off(1);
+                let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
+                for n in others {
+                    normalized.merge_with(&n);
+                }
+
+                Ok((first, normalized))
             };

         let (sentence, pair) = match input {
@@ -375,14 +411,37 @@ impl Tokenizer {
             EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
         };

-        let encoding = generate_output(sentence, 0)?;
-        let pair_encoding = match pair {
-            Some(pair) => Some(generate_output(pair, 1)?),
-            None => None,
+        let (encoding, normalized) = generate_output(sentence, 0)?;
+        let (pair_encoding, pair_normalized) = match pair {
+            Some(pair) => {
+                let (e, n) = generate_output(pair, 1)?;
+                (Some(e), Some(n))
+            }
+            None => (None, None),
         };

         // 4. Post processing
-        self.post_process(encoding, pair_encoding, add_special_tokens)
+        let mut output = self.post_process(encoding, pair_encoding, add_special_tokens)?;
+
+        // 5. Convert offsets back to original string
+        let mut current_offset = (0, 0);
+        let mut n_source = &normalized;
+        output
+            .get_offsets_mut()
+            .iter_mut()
+            .for_each(|(start, end)| {
+                if (*start, *end) < current_offset {
+                    n_source = &pair_normalized.as_ref().unwrap_or(&normalized);
+                }
+                current_offset = (*start, *end);
+                let (s, e) = n_source
+                    .convert_offsets(Range::Normalized(*start..*end))
+                    .map_or((*start, *end), |range| (range.start, range.end));
+                *start = s;
+                *end = e;
+            });
+
+        Ok(output)
     }

     /// Encode all the sentences in parallel, using multiple threads
@@ -471,7 +530,7 @@ impl Tokenizer {
             match file.read_line(&mut buf)? {
                 0 => break,
                 b => {
                     let mut normalized = self.normalize(&buf)?;
-                    let mut normalized = self.normalize(&buf)?;
+                    let mut normalized = self.do_normalize(&buf)?;
                     let pre_tokenized = self.pre_tokenize(&mut normalized)?;
                     trainer.process_tokens(
                         &mut words,
@@ -522,7 +581,7 @@ impl Tokenizer {
     }

     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sequence: &str) -> Result<NormalizedString> {
+    fn do_normalize(&self, sequence: &str) -> Result<NormalizedString> {
         let mut normalized = NormalizedString::from(sequence);

         if let Some(normalizer) = &self.normalizer {
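A short usage sketch of what the new step 5 buys, modelled on the integration test added at the end of this commit (the GPT-2 vocab files and the ByteLevel setup are the ones assumed by that test): after `encode`, the returned offsets index the original input string and can slice it directly.

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};

fn main() {
    // Assumes the GPT-2 vocab files used by the integration tests are available locally.
    let mut tokenizer = Tokenizer::new(Box::new(
        BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
            .build()
            .expect("vocab files not found"),
    ));
    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default().add_prefix_space(true)));

    let input = String::from("Hello there, how are you?");
    let output = tokenizer
        .encode(EncodeInput::Single(input.clone()), false)
        .unwrap();

    // The returned offsets now index the original string, so they can slice it directly.
    let offsets = output.get_offsets();
    assert_eq!(get_range_of(&input, offsets[0].0..offsets[0].1), Some("Hello"));
    assert_eq!(get_range_of(&input, offsets[1].0..offsets[1].1), Some(" there"));
}
```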
@@ -1,21 +1,63 @@
-use super::Result;
 use std::cmp::Ordering;
+use std::ops::{Bound, RangeBounds};
 use unicode_normalization_alignments::UnicodeNormalization;

-/// Takes care of pre-processing strings.
-pub trait Normalizer {
-    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
+/// Represents a Range usable by the NormalizedString to index its content.
+/// A Range can use indices relative to either the `Original` or the `Normalized` string
+#[derive(Debug, Clone, Copy)]
+pub enum Range<T: RangeBounds<usize>> {
+    Original(T),
+    Normalized(T),
 }

-/// A normalized string takes care of keeping both versions of a `String`, and
-/// provides necessary alignments to retrieve ranges of both strings.
+impl<T> Range<T>
+where
+    T: RangeBounds<usize>,
+{
+    /// Unwrap the underlying range
+    fn unwrap(self) -> T {
+        match self {
+            Range::Original(r) => r,
+            Range::Normalized(r) => r,
+        }
+    }
+
+    /// Converts the current Range to a `std::ops::Range<usize>`. This requires the `max_len`
+    /// of the represented string (in chars, not bytes) in order to cover the case where the
+    /// original provided range was unbounded
+    fn into_full_range(self, max_len: usize) -> std::ops::Range<usize> {
+        let range = self.unwrap();
+
+        let start = match range.start_bound() {
+            Bound::Unbounded => 0,
+            Bound::Included(i) => *i,
+            Bound::Excluded(i) => *i + 1,
+        };
+        let end = match range.end_bound() {
+            Bound::Unbounded => max_len,
+            Bound::Included(i) => *i + 1,
+            Bound::Excluded(i) => *i,
+        };
+
+        start..end
+    }
+}
+
+/// A `NormalizedString` takes care of processing an "original" string to modify it and obtain a
+/// "normalized" string. It keeps both version of the string, alignments information between both
+/// and provides an interface to retrieve ranges of each string, using offsets from any of them.
+///
+/// It is possible to retrieve a part of the original string, by indexing it with offsets from the
+/// normalized one, and the other way around too. It is also possible to convert offsets from one
+/// referential to the other one easily.
 #[derive(Default, Debug, Clone)]
 pub struct NormalizedString {
+    /// The original version of the string, before any modification
     original: String,
+    /// The normalized version of the string, after all modifications
     normalized: String,
-    /// Mapping from normalized string to original one
-    /// (pos, changes) where pos is the position in the modified string, and changes an isize
-    /// representing the number of insertions or deletions
+    /// Mapping from normalized string to original one: (start, end) for each character of the
+    /// normalized string
    alignments: Vec<(usize, usize)>,
 }

@@ -26,6 +68,7 @@ impl std::cmp::PartialEq for NormalizedString {
 }

 impl NormalizedString {
+    /// Create a NormalizedString from the given str
     pub fn from(s: &str) -> Self {
         NormalizedString {
             original: s.to_owned(),
@@ -44,14 +87,35 @@ impl NormalizedString {
         &self.original
     }

-    /// Return the range of the original string corresponding to the received range on the
-    /// normalized string. Returns None if out of bounds
-    pub fn get_original_offsets(
+    /// Convert the given offsets range from one referential to the other one:
+    /// `Original => Normalized` or `Normalized => Original`
+    pub fn convert_offsets<T: RangeBounds<usize>>(
         &self,
-        range: std::ops::Range<usize>,
+        range: Range<T>,
     ) -> Option<std::ops::Range<usize>> {
+        match range {
+            Range::Original(_) => {
+                let (mut start, mut end) = (0, 0);
+                let r = range.into_full_range(self.alignments.last().map_or(0, |(_, e)| *e));
+                println!("{:?}\t{:?}", r, self.alignments);
                 self.alignments
-            .get(range)
+                    .iter()
+                    .enumerate()
+                    .take_while(|(_, alignment)| r.end >= alignment.1)
+                    .for_each(|(i, alignment)| {
+                        println!("{:?}", alignment);
+                        if alignment.0 <= r.start {
+                            start = i;
+                        }
+                        if alignment.1 <= r.end {
+                            end = i + 1;
+                        }
+                    });
+                Some(start..end)
+            }
+            Range::Normalized(_) => self
+                .alignments
+                .get(range.into_full_range(self.alignments.len()))
                 .map(|alignments| {
                     if alignments.is_empty() {
                         None
@@ -61,43 +125,30 @@ impl NormalizedString {
                         Some(start..end)
                     }
                 })
-            .flatten()
-    }
-
-    fn get_range_of(&self, s: &str, range: std::ops::Range<usize>) -> Option<String> {
-        let len = s.chars().count();
-        if range.start >= len || range.end > len {
-            None
-        } else {
-            Some(
-                s.chars()
-                    .enumerate()
-                    .skip(range.start)
-                    .map(|(i, c)| {
-                        if i >= range.start && i < range.end {
-                            Some(c)
-                        } else {
-                            None
-                        }
-                    })
-                    .fuse()
-                    .filter(|c| c.is_some())
-                    .map(|c| c.unwrap())
-                    .collect::<String>(),
-            )
+                .flatten(),
         }
     }

     /// Return a range of the normalized string (indexing on char not bytes)
-    pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<String> {
-        self.get_range_of(&self.normalized, range)
+    pub fn get_range<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
+        match range {
+            Range::Original(_) => self
+                .convert_offsets(range)
+                .map(|r| get_range_of(&self.normalized, r))
+                .flatten(),
+            Range::Normalized(r) => get_range_of(&self.normalized, r),
+        }
     }

-    /// Return a range of the original string, using a range from the normalized string
-    pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<String> {
-        self.get_original_offsets(range)
-            .map(|range| self.get_range_of(&self.original, range))
-            .flatten()
+    /// Return a range of the original string (indexing on char not bytes)
+    pub fn get_range_original<T: RangeBounds<usize>>(&self, range: Range<T>) -> Option<&str> {
+        match range {
+            Range::Original(r) => get_range_of(&self.original, r),
+            Range::Normalized(_) => self
+                .convert_offsets(range)
+                .map(|r| get_range_of(&self.original, r))
+                .flatten(),
+        }
     }

     /// Applies transformations to the current normalized version, updating the current
@@ -115,19 +166,10 @@ impl NormalizedString {
     /// them has a `change` of `1`, but more doesn't make any sense.
     /// We treat any value above `1` as `1`.
     pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I, initial_offset: usize) {
-        let mut offset = 0;
-        let mut remaining_offset = initial_offset;
+        let mut offset = -(initial_offset as isize);
         let (ch, alignments): (Vec<_>, Vec<_>) = dest
             .enumerate()
             .map(|(index, (c, changes))| {
-                let changes = if remaining_offset != 0 {
-                    let c = changes - remaining_offset as isize;
-                    remaining_offset = 0;
-                    c
-                } else {
-                    changes
-                };
-
                 let uof = if offset < 0 {
                     -offset as usize
                 } else {
@@ -149,24 +191,10 @@ impl NormalizedString {
                     }
                     // No changes required here
                     Ordering::Equal => self.alignments.get(idx).copied(),
-                    // Some characters where removed, so we merge our range with the one from the
-                    // removed characters as the new alignment
+                    // Some characters where removed, nothing to change in alignments
                     Ordering::Less => {
-                        let uch = -changes as usize;
                         offset += changes;
-                        self.alignments.get(idx..=idx + uch).map(|alignments| {
-                            let min = alignments
-                                .iter()
-                                .map(|(start, end)| usize::min(*start, *end))
-                                .min()
-                                .unwrap();
-                            let max = alignments
-                                .iter()
-                                .map(|(start, end)| usize::max(*start, *end))
-                                .max()
-                                .unwrap();
-                            (min, max)
-                        })
+                        self.alignments.get(idx).copied()
                     }
                 };

@@ -421,6 +449,37 @@ impl NormalizedString {
     }
 }

+/// Returns a range of the given string slice, by indexing chars instead of bytes
+pub fn get_range_of<T: RangeBounds<usize>>(s: &str, range: T) -> Option<&str> {
+    let len = s.chars().count();
+    let start = match range.start_bound() {
+        Bound::Unbounded => 0,
+        Bound::Included(i) => *i,
+        Bound::Excluded(i) => *i + 1,
+    };
+    let end = match range.end_bound() {
+        Bound::Unbounded => len,
+        Bound::Included(i) => *i + 1,
+        Bound::Excluded(i) => *i,
+    };
+
+    if start >= len || end > len || start >= end {
+        None
+    } else {
+        let start_b = s
+            .char_indices()
+            .map(|(i, _)| i)
+            .nth(start as usize)
+            .unwrap_or(0);
+        let end_b = s
+            .char_indices()
+            .map(|(i, _)| i)
+            .nth(end as usize)
+            .unwrap_or_else(|| s.len());
+        Some(&s[start_b..end_b])
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -462,7 +521,7 @@ mod tests {
         n.filter(|c| *c != 'n');
         assert_eq!(
             &n.alignments,
-            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
         );
     }

@@ -472,18 +531,43 @@ mod tests {
         n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
         assert_eq!(
             &n.alignments,
-            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (6, 7)]
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
         );
     }

+    #[test]
+    fn range_conversion() {
+        let mut n = NormalizedString::from(" __Hello__ ");
+        n.filter(|c| !c.is_whitespace()).lowercase();
+        let hello_n = n.convert_offsets(Range::Original(6..11));
+        assert_eq!(hello_n, Some(2..7));
+        assert_eq!(
+            n.get_range(Range::Normalized(hello_n.clone().unwrap())),
+            Some("hello")
+        );
+        assert_eq!(
+            n.get_range_original(Range::Normalized(hello_n.unwrap())),
+            Some("Hello")
+        );
+        assert_eq!(n.get_range(Range::Original(6..11)), Some("hello"));
+        assert_eq!(n.get_range_original(Range::Original(6..11)), Some("Hello"));
+    }
+
     #[test]
     fn original_range() {
         let mut n = NormalizedString::from("Hello_______ World!");
         n.filter(|c| *c != '_').lowercase();
-        let world_n = n.get_range(6..11).unwrap();
-        let world_o = n.get_range_original(6..11).unwrap();
+        let world_n = n.get_range(Range::Normalized(6..11)).unwrap();
+        let world_o = n.get_range_original(Range::Normalized(6..11)).unwrap();
         assert_eq!(world_n, "world");
         assert_eq!(world_o, "World");
+        let original_range = Range::Original(n.convert_offsets(Range::Normalized(6..11)).unwrap());
+        assert_eq!(n.get_range(original_range.clone()).unwrap(), "world");
+        assert_eq!(
+            n.get_range_original(original_range.clone()).unwrap(),
+            "World"
+        );
+        assert_eq!(original_range.into_full_range(n.len_original()), 13..18);
     }

     #[test]
@@ -505,8 +589,8 @@ mod tests {

         assert_eq!(&n.normalized, " Hello ");
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some("Hello".into())
+            n.get_range_original(Range::Normalized(1..n.normalized.len() - 1)),
+            Some("Hello")
         );
     }

@@ -514,10 +598,13 @@ mod tests {
     fn remove_at_beginning() {
         let mut n = NormalizedString::from(" Hello");
         n.filter(|c| !c.is_whitespace());
-        assert_eq!(n.get_range_original(1.."Hello".len()), Some("ello".into()));
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some(" Hello".into())
+            n.get_range_original(Range::Normalized(1.."Hello".len())),
+            Some("ello")
+        );
+        assert_eq!(
+            n.get_range_original(Range::Normalized(0..n.normalized.len())),
+            Some("Hello")
         );
     }

@@ -525,10 +612,10 @@ mod tests {
     fn remove_at_end() {
         let mut n = NormalizedString::from("Hello ");
         n.filter(|c| !c.is_whitespace());
-        assert_eq!(n.get_range_original(0..4), Some("Hell".into()));
+        assert_eq!(n.get_range_original(Range::Normalized(0..4)), Some("Hell"));
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some("Hello ".into())
+            n.get_range_original(Range::Normalized(0..n.normalized.len())),
+            Some("Hello")
        );
     }

@@ -539,10 +626,13 @@ mod tests {
         assert_eq!(&n.normalized, "Hello");

         assert_eq!(
-            n.get_range_original(0.."Hello".len()),
-            Some(" Hello ".into())
+            n.get_range_original(Range::Normalized(0.."Hello".len())),
+            Some("Hello")
+        );
+        assert_eq!(
+            n.get_range_original(Range::Normalized(1.."Hell".len())),
+            Some("ell")
         );
-        assert_eq!(n.get_range_original(1.."Hell".len()), Some("ell".into()));
     }

     #[test]
@@ -551,8 +641,8 @@ mod tests {
         n.lstrip();
         assert_eq!(&n.normalized, "This is an example ");
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some(" This is an example ".into())
+            n.get_range_original(Range::Normalized(0..n.normalized.len())),
+            Some("This is an example ")
         );
     }

@@ -562,8 +652,8 @@ mod tests {
         n.rstrip();
         assert_eq!(&n.normalized, " This is an example");
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some(" This is an example ".into())
+            n.get_range_original(Range::Normalized(0..n.normalized.len())),
+            Some(" This is an example")
        );
    }

@@ -573,8 +663,8 @@ mod tests {
         n.strip();
         assert_eq!(&n.normalized, "This is an example");
         assert_eq!(
-            n.get_range_original(0..n.normalized.len()),
-            Some(" This is an example ".into())
+            n.get_range_original(Range::Normalized(0..n.normalized.len())),
+            Some("This is an example")
         );
     }

@@ -597,7 +687,7 @@ mod tests {
                 (4, 5)
             ]
         );
-        assert_eq!(n.get_original_offsets(0..4), Some(0..0));
+        assert_eq!(n.convert_offsets(Range::Normalized(0..4)), Some(0..0));
     }

     #[test]
@@ -619,6 +709,16 @@ mod tests {
                 (3, 3)
             ]
         );
-        assert_eq!(n.get_original_offsets(3.." there".len()), Some(3..3));
+        assert_eq!(
+            n.convert_offsets(Range::Normalized(3.." there".len())),
+            Some(3..3)
+        );
+    }
+
+    #[test]
+    fn get_range() {
+        let s = String::from("Hello my name is John 👋");
+        assert_eq!(get_range_of(&s, ..), Some(&s[..]));
+        assert_eq!(get_range_of(&s, 17..), Some("John 👋"));
     }
 }
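To make the two referentials concrete, here is a small self-contained sketch in the spirit of the `range_conversion` test above (import paths are assumed from the `pub use normalizer::*;` re-export in the tokenizer module):

```rust
use tokenizers::tokenizer::{NormalizedString, Range};

fn main() {
    let mut n = NormalizedString::from("__HELLO__");
    // Drop the underscores and lowercase the rest: normalized == "hello".
    n.filter(|c| *c != '_').lowercase();

    // "HELLO" spans 2..7 in the original string and 0..5 in the normalized one.
    assert_eq!(n.convert_offsets(Range::Original(2..7)), Some(0..5));
    assert_eq!(n.get_range(Range::Original(2..7)), Some("hello"));
    assert_eq!(n.get_range_original(Range::Normalized(0..5)), Some("HELLO"));
}
```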
tokenizers/src/utils/iter.rs (new file, 58 lines)
@@ -0,0 +1,58 @@
//! This comes from the Rust libcore and is duplicated here because it is not exported
//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
//! because the one from the libcore seems to cause overflowing stacks in some cases

pub struct ResultShunt<I, E> {
    iter: I,
    error: Option<E>,
}

impl<I, T, E> ResultShunt<I, E>
where
    I: Iterator<Item = Result<T, E>>,
{
    /// Process the given iterator as if it yielded a `T` instead of a
    /// `Result<T, _>`. Any errors will stop the inner iterator and
    /// the overall result will be an error.
    pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
    where
        F: FnMut(&mut Self) -> U,
    {
        let mut shunt = ResultShunt::new(iter);
        let value = f(shunt.by_ref());
        shunt.reconstruct(value)
    }

    fn new(iter: I) -> Self {
        ResultShunt { iter, error: None }
    }

    /// Consume the adapter and rebuild a `Result` value. This should
    /// *always* be called, otherwise any potential error would be
    /// lost.
    fn reconstruct<U>(self, val: U) -> Result<U, E> {
        match self.error {
            None => Ok(val),
            Some(e) => Err(e),
        }
    }
}

impl<I, T, E> Iterator for ResultShunt<I, E>
where
    I: Iterator<Item = Result<T, E>>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self.iter.next() {
            Some(Ok(v)) => Some(v),
            Some(Err(e)) => {
                self.error = Some(e);
                None
            }
            None => None,
        }
    }
}
tokenizers/src/utils/mod.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
pub mod iter;
pub mod padding;
pub mod truncation;
tokenizers/src/utils/padding.rs (new file, 63 lines)
@@ -0,0 +1,63 @@
use crate::tokenizer::{Encoding, Result};
use rayon::prelude::*;

/// The various possible padding directions.
#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
    Left,
    Right,
}

impl std::convert::AsRef<str> for PaddingDirection {
    fn as_ref(&self) -> &str {
        match self {
            PaddingDirection::Left => "left",
            PaddingDirection::Right => "right",
        }
    }
}

#[derive(Debug, Clone)]
pub struct PaddingParams {
    pub strategy: PaddingStrategy,
    pub direction: PaddingDirection,
    pub pad_id: u32,
    pub pad_type_id: u32,
    pub pad_token: String,
}

#[derive(Debug, Clone)]
pub enum PaddingStrategy {
    BatchLongest,
    Fixed(usize),
}

pub fn pad_encodings(
    mut encodings: Vec<Encoding>,
    params: &PaddingParams,
) -> Result<Vec<Encoding>> {
    if encodings.is_empty() {
        return Ok(encodings);
    }

    let pad_length = match params.strategy {
        PaddingStrategy::Fixed(size) => size,
        PaddingStrategy::BatchLongest => encodings
            .par_iter()
            .map(|e| e.get_ids().len())
            .max()
            .unwrap(),
    };

    encodings.par_iter_mut().for_each(|encoding| {
        encoding.pad(
            pad_length,
            params.pad_id,
            params.pad_type_id,
            &params.pad_token,
            params.direction,
        )
    });

    Ok(encodings)
}
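For reference, a minimal sketch of driving `pad_encodings` through the re-exported types shown earlier (an empty batch is a no-op here; a real batch would come from encoding several inputs first):

```rust
use tokenizers::tokenizer::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};

fn main() {
    // Pad every encoding in a batch to the longest member, on the right,
    // with a hypothetical pad token.
    let params = PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Right,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: String::from("[PAD]"),
    };

    // With an empty batch, pad_encodings returns it unchanged.
    let padded = pad_encodings(vec![], &params).unwrap();
    assert!(padded.is_empty());
}
```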
|
@ -1,5 +1,4 @@
|
|||||||
use crate::tokenizer::{Encoding, PaddingDirection, Result};
|
use crate::tokenizer::{Encoding, Result};
|
||||||
use rayon::prelude::*;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct TruncationParams {
|
pub struct TruncationParams {
|
||||||
@ -8,21 +7,6 @@ pub struct TruncationParams {
|
|||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PaddingParams {
|
|
||||||
pub strategy: PaddingStrategy,
|
|
||||||
pub direction: PaddingDirection,
|
|
||||||
pub pad_id: u32,
|
|
||||||
pub pad_type_id: u32,
|
|
||||||
pub pad_token: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum PaddingStrategy {
|
|
||||||
BatchLongest,
|
|
||||||
Fixed(usize),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
SecondSequenceNotProvided,
|
SecondSequenceNotProvided,
|
||||||
@ -118,33 +102,3 @@ pub fn truncate_encodings(
|
|||||||
|
|
||||||
Ok((encoding, pair_encoding))
|
Ok((encoding, pair_encoding))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pad_encodings(
|
|
||||||
mut encodings: Vec<Encoding>,
|
|
||||||
params: &PaddingParams,
|
|
||||||
) -> Result<Vec<Encoding>> {
|
|
||||||
if encodings.is_empty() {
|
|
||||||
return Ok(encodings);
|
|
||||||
}
|
|
||||||
|
|
||||||
let pad_length = match params.strategy {
|
|
||||||
PaddingStrategy::Fixed(size) => size,
|
|
||||||
PaddingStrategy::BatchLongest => encodings
|
|
||||||
.par_iter()
|
|
||||||
.map(|e| e.get_ids().len())
|
|
||||||
.max()
|
|
||||||
.unwrap(),
|
|
||||||
};
|
|
||||||
|
|
||||||
encodings.par_iter_mut().for_each(|encoding| {
|
|
||||||
encoding.pad(
|
|
||||||
pad_length,
|
|
||||||
params.pad_id,
|
|
||||||
params.pad_type_id,
|
|
||||||
¶ms.pad_token,
|
|
||||||
params.direction,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(encodings)
|
|
||||||
}
|
|
tokenizers/tests/offsets.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};

fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
    let mut tokenizer = Tokenizer::new(Box::new(
        BPE::from_files("data/gpt2-vocab.json", "data/gpt2-merges.txt")
            .build()
            .expect("Files not found, run `make test` to download these files"),
    ));
    tokenizer.with_pre_tokenizer(Box::new(
        ByteLevel::default().add_prefix_space(add_prefix_space),
    ));
    tokenizer.with_decoder(Box::new(ByteLevel::default()));
    tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));

    tokenizer
}

#[inline]
fn offset_as_range(offset: (usize, usize)) -> std::ops::Range<usize> {
    offset.0..offset.1
}

#[test]
fn byte_level_basic() {
    // Without trimming offsets
    let tokenizer = get_byte_level(true, false);

    let input = String::from("Hello there, how are you?");
    let output = tokenizer
        .encode(EncodeInput::Single(input.clone()), false)
        .unwrap();

    let offsets = output.get_offsets();
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[0])),
        Some("Hello")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[1])),
        Some(" there")
    );
    assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[3])),
        Some(" how")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[4])),
        Some(" are")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[5])),
        Some(" you")
    );
    assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));

    // And when trimming offsets:
    let tokenizer = get_byte_level(true, true);

    let input = String::from("Hello there, how are you?");
    let output = tokenizer
        .encode(EncodeInput::Single(input.clone()), false)
        .unwrap();

    let offsets = output.get_offsets();
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[0])),
        Some("Hello")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[1])),
        Some("there")
    );
    assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some(","));
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[3])),
        Some("how")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[4])),
        Some("are")
    );
    assert_eq!(
        get_range_of(&input, offset_as_range(offsets[5])),
        Some("you")
    );
    assert_eq!(get_range_of(&input, offset_as_range(offsets[6])), Some("?"));
}

#[test]
fn byte_level_unicode() {
    let tokenizer = get_byte_level(true, false);

    let input = String::from("i⭢j");
    let output = tokenizer
        .encode(EncodeInput::Single(input.clone()), false)
        .unwrap();

    let offsets = output.get_offsets();
    assert_eq!(get_range_of(&input, offset_as_range(offsets[1])), Some("⭢"));
    assert_eq!(get_range_of(&input, offset_as_range(offsets[2])), Some("⭢"));
    assert_eq!(get_range_of(&input, offset_as_range(offsets[3])), Some("⭢"));
}

#[test]
fn byte_level_double_sequence() {
    let input_a = String::from("My name is Anthony");
    let input_b = String::from("What is my name?");

    // Without trimming offsets
    let tokenizer = get_byte_level(true, false);
    let output = tokenizer
        .encode(EncodeInput::Dual(input_a.clone(), input_b.clone()), false)
        .unwrap();

    let offsets = output.get_offsets();
    assert_eq!(
        offsets,
        &[
            (0, 2),
            (2, 7),
            (7, 10),
            (10, 18),
            (0, 4),
            (4, 7),
            (7, 10),
            (10, 15),
            (15, 16)
        ]
    );

    // When trimming offsets
    let tokenizer = get_byte_level(true, true);
    let output = tokenizer
        .encode(EncodeInput::Dual(input_a, input_b), false)
        .unwrap();
    let offsets = output.get_offsets();
    assert_eq!(
        offsets,
        &[
            (0, 2),
            (3, 7),
            (8, 10),
            (11, 18),
            (0, 4),
            (5, 7),
            (8, 10),
            (11, 15),
            (15, 16)
        ]
    );
}