Files
tokenizers/bindings/node/lib/implementations/encoding.ts
Nicolas Patry 152880ab3e Adding truncation_side within TruncationParams. (#860)
* Add truncation to enable_truncation

* Fix typo

* Adding truncation_side within `TruncationParams`.

* Node serialization of this direction param.

* Update the test.

* Fixing warnings/lint.

* Adding stuff (can't local debug :( )

* Slow loop... ;(

* Stub.py.

Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
2021-12-28 12:37:06 +01:00

280 lines
7.2 KiB
TypeScript
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { PaddingOptions, RawEncoding } from "../bindings/raw-encoding";
import { mergeEncodings } from "../bindings/utils";
/**
 * Wrapper around a `RawEncoding` produced by the native bindings.
 *
 * Each array/number accessor fetches its value from the raw encoding on first
 * access and caches it on the instance. Methods that mutate the raw encoding
 * (`pad`, `truncate`, `setSequenceId`) invalidate the affected caches so that
 * subsequent reads are fetched again.
 */
export class Encoding {
  private _attentionMask?: number[];
  private _ids?: number[];
  private _length?: number;
  private _offsets?: [number, number][];
  private _overflowing?: Encoding[];
  private _specialTokensMask?: number[];
  private _tokens?: string[];
  private _typeIds?: number[];
  private _wordIndexes?: (number | undefined)[];
  private _sequenceIndexes?: (number | undefined)[];

  constructor(private _rawEncoding: RawEncoding) {}

  /**
   * Merge a list of Encoding into one final Encoding
   * @param encodings The list of encodings to merge
   * @param [growingOffsets=false] Whether the offsets should accumulate while merging
   */
  static merge(encodings: Encoding[], growingOffsets?: boolean): Encoding {
    const mergedRaw = mergeEncodings(
      encodings.map((e) => e.rawEncoding),
      growingOffsets
    );
    return new Encoding(mergedRaw);
  }

  /**
   * Number of sequences
   */
  get nSequences(): number {
    return this._rawEncoding.getNSequences();
  }

  /**
   * Set the sequence id on the underlying raw encoding
   * @param seqId The sequence id to forward to the raw encoding
   */
  setSequenceId(seqId: number): void {
    this._rawEncoding.setSequenceId(seqId);
    // The raw encoding's sequence assignment was just changed, so a previously
    // cached `sequenceIndexes` value may no longer match it.
    this._sequenceIndexes = undefined;
  }

  /**
   * Attention mask
   */
  get attentionMask(): number[] {
    if (this._attentionMask) {
      return this._attentionMask;
    }
    return (this._attentionMask = this._rawEncoding.getAttentionMask());
  }

  /**
   * Tokenized ids
   */
  get ids(): number[] {
    if (this._ids) {
      return this._ids;
    }
    return (this._ids = this._rawEncoding.getIds());
  }

  /**
   * Number of tokens
   */
  get length(): number {
    // Compared against `undefined` (not truthiness) so a cached length of 0 is honoured.
    if (this._length !== undefined) {
      return this._length;
    }
    return (this._length = this._rawEncoding.getLength());
  }

  /**
   * Offsets
   */
  get offsets(): [number, number][] {
    if (this._offsets) {
      return this._offsets;
    }
    return (this._offsets = this._rawEncoding.getOffsets());
  }

  /**
   * Overflowing encodings, after truncation
   */
  get overflowing(): Encoding[] {
    if (this._overflowing) {
      return this._overflowing;
    }
    return (this._overflowing = this._rawEncoding
      .getOverflowing()
      .map((e) => new Encoding(e)));
  }

  /**
   * __⚠ DANGER ZONE: do not touch unless you know what you're doing ⚠__
   * Access to the `rawEncoding` returned by the internal Rust code.
   * @private
   * @ignore
   * @since 0.6.0
   */
  get rawEncoding(): Readonly<RawEncoding> {
    return this._rawEncoding;
  }

  /**
   * Special tokens mask
   */
  get specialTokensMask(): number[] {
    if (this._specialTokensMask) {
      return this._specialTokensMask;
    }
    return (this._specialTokensMask = this._rawEncoding.getSpecialTokensMask());
  }

  /**
   * Tokenized string
   */
  get tokens(): string[] {
    if (this._tokens) {
      return this._tokens;
    }
    return (this._tokens = this._rawEncoding.getTokens());
  }

  /**
   * Type ids
   */
  get typeIds(): number[] {
    if (this._typeIds) {
      return this._typeIds;
    }
    return (this._typeIds = this._rawEncoding.getTypeIds());
  }

  /**
   * The tokenized words indexes
   */
  get wordIndexes(): (number | undefined)[] {
    if (this._wordIndexes) {
      return this._wordIndexes;
    }
    return (this._wordIndexes = this._rawEncoding.getWordIds());
  }

  /**
   * The sequence ids reported by the raw encoding (`getSequenceIds`),
   * with `undefined` entries where no sequence id is reported
   */
  get sequenceIndexes(): (number | undefined)[] {
    if (this._sequenceIndexes) {
      return this._sequenceIndexes;
    }
    return (this._sequenceIndexes = this._rawEncoding.getSequenceIds());
  }

  /**
   * Get the encoded tokens corresponding to the word at the given index in one of the input
   * sequences, with the form [startToken, endToken+1]
   * @param word The position of a word in one of the input sequences
   * @param seqId The index of the input sequence that contains said word
   * @since 0.7.0
   */
  wordToTokens(word: number, seqId?: number): [number, number] | undefined {
    return this._rawEncoding.wordToTokens(word, seqId);
  }

  /**
   * Get the offsets of the word at the given index in the input sequence
   * @param word The index of the word in the input sequence
   * @param seqId The index of the input sequence that contains said word
   * @since 0.7.0
   */
  wordToChars(word: number, seqId?: number): [number, number] | undefined {
    return this._rawEncoding.wordToChars(word, seqId);
  }

  /**
   * Get the index of the sequence that contains the given token
   * @param token The index of the token in the encoded sequence
   */
  tokenToSequence(token: number): number | undefined {
    return this._rawEncoding.tokenToSequence(token);
  }

  /**
   * Get the offsets of the token at the given index
   *
   * The returned offsets are related to the input sequence that contains the
   * token. In order to determine in which input sequence it belongs, you
   * must call `tokenToSequence`.
   *
   * @param token The index of the token in the encoded sequence
   * @since 0.7.0
   */
  tokenToChars(token: number): [number, number] | undefined {
    return this._rawEncoding.tokenToChars(token);
  }

  /**
   * Get the word that contains the token at the given index
   *
   * The returned index is related to the input sequence that contains the
   * token. In order to determine in which input sequence it belongs, you
   * must call `tokenToSequence`.
   *
   * @param token The index of the token in the encoded sequence
   * @since 0.7.0
   */
  tokenToWord(token: number): number | undefined {
    return this._rawEncoding.tokenToWord(token);
  }

  /**
   * Find the index of the token at the position of the given char
   * @param pos The position of a char in one of the input strings
   * @param seqId The index of the input sequence that contains said char
   * @since 0.6.0
   */
  charToToken(pos: number, seqId?: number): number | undefined {
    return this._rawEncoding.charToToken(pos, seqId);
  }

  /**
   * Get the word that contains the given char
   * @param pos The position of a char in the input string
   * @param seqId The index of the input sequence that contains said char
   * @since 0.7.0
   */
  charToWord(pos: number, seqId?: number): number | undefined {
    return this._rawEncoding.charToWord(pos, seqId);
  }

  /**
   * Pad the current Encoding at the given length
   *
   * All cached accessor values are discarded afterwards.
   *
   * @param length The length at which to pad
   * @param [options] Padding options
   */
  pad(length: number, options?: PaddingOptions): void {
    this._rawEncoding.pad(length, options);
    this.resetInternalProperties();
  }

  /**
   * Truncate the current Encoding at the given max length
   *
   * All cached accessor values are discarded afterwards.
   *
   * @param length The maximum length to be kept
   * @param [stride=0] The length of the previous first sequence
   * to be included in the overflowing sequence
   * @param [direction='right'] Truncate direction
   */
  truncate(length: number, stride?: number, direction = "right"): void {
    this._rawEncoding.truncate(length, stride, direction);
    this.resetInternalProperties();
  }

  /**
   * Drop every lazily cached value so the next accessor read goes back to the
   * raw encoding. Must cover ALL cache fields declared on this class; a field
   * left out here keeps serving its pre-mutation value after `pad`/`truncate`.
   */
  private resetInternalProperties(): void {
    // Explicit assignments (rather than deleting by string key) keep this list
    // checked by the compiler. Every getter treats `undefined` as "not cached".
    this._attentionMask = undefined;
    this._ids = undefined;
    this._length = undefined;
    this._offsets = undefined;
    this._overflowing = undefined;
    this._specialTokensMask = undefined;
    this._tokens = undefined;
    this._typeIds = undefined;
    this._wordIndexes = undefined;
    this._sequenceIndexes = undefined;
  }
}