mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 04:29:21 +00:00
Add an Encoding.sequences to allow masking
This commit is contained in:
5
bindings/node/lib/bindings/raw-encoding.d.ts
vendored
5
bindings/node/lib/bindings/raw-encoding.d.ts
vendored
@ -123,6 +123,11 @@ export interface RawEncoding {
|
|||||||
*/
|
*/
|
||||||
getWords(): (number | undefined)[];
|
getWords(): (number | undefined)[];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The sequences indices
|
||||||
|
*/
|
||||||
|
getSequences(): (number | undefined)[];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pad the current Encoding at the given length
|
* Pad the current Encoding at the given length
|
||||||
*
|
*
|
||||||
|
@ -112,6 +112,13 @@ describe("RawEncoding", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("getSequences", () => {
|
||||||
|
it("returns the correct list of indexes", () => {
|
||||||
|
expect(encoding.getSequences()).toEqual([0, 0, 0, 0, 0]);
|
||||||
|
expect(encodingDual.getSequences()).toEqual([0, 0, 0, 0, 0, 1, 1, 1, 1]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe("wordToTokens", () => {
|
describe("wordToTokens", () => {
|
||||||
it("returns the correct indexes", () => {
|
it("returns the correct indexes", () => {
|
||||||
const indexes = encoding.wordToTokens(3);
|
const indexes = encoding.wordToTokens(3);
|
||||||
|
@ -11,6 +11,7 @@ export class Encoding {
|
|||||||
private _tokens?: string[];
|
private _tokens?: string[];
|
||||||
private _typeIds?: number[];
|
private _typeIds?: number[];
|
||||||
private _wordIndexes?: (number | undefined)[];
|
private _wordIndexes?: (number | undefined)[];
|
||||||
|
private _sequenceIndexes?: (number | undefined)[];
|
||||||
|
|
||||||
constructor(private _rawEncoding: RawEncoding) {}
|
constructor(private _rawEncoding: RawEncoding) {}
|
||||||
|
|
||||||
@ -151,6 +152,14 @@ export class Encoding {
|
|||||||
return (this._wordIndexes = this._rawEncoding.getWords());
|
return (this._wordIndexes = this._rawEncoding.getWords());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
get sequenceIndexes(): (number | undefined)[] {
|
||||||
|
if (this._sequenceIndexes) {
|
||||||
|
return this._sequenceIndexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (this._sequenceIndexes = this._rawEncoding.getSequences());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the encoded tokens corresponding to the word at the given index in one of the input
|
* Get the encoded tokens corresponding to the word at the given index in one of the input
|
||||||
* sequences, with the form [startToken, endToken+1]
|
* sequences, with the form [startToken, endToken+1]
|
||||||
|
@ -115,7 +115,7 @@ declare_types! {
|
|||||||
}
|
}
|
||||||
|
|
||||||
method getWords(mut cx) {
|
method getWords(mut cx) {
|
||||||
// getWords(): number[]
|
// getWords(): (number | undefined)[]
|
||||||
|
|
||||||
let this = cx.this();
|
let this = cx.this();
|
||||||
let guard = cx.lock();
|
let guard = cx.lock();
|
||||||
@ -127,6 +127,18 @@ declare_types! {
|
|||||||
Ok(neon_serde::to_value(&mut cx, &ids)?)
|
Ok(neon_serde::to_value(&mut cx, &ids)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
method getSequences(mut cx) {
|
||||||
|
// getSequences(): (number | undefined)[]
|
||||||
|
|
||||||
|
let this = cx.this();
|
||||||
|
let guard = cx.lock();
|
||||||
|
let ids = this.borrow(&guard)
|
||||||
|
.encoding.as_ref().expect("Uninitialized Encoding")
|
||||||
|
.get_sequences();
|
||||||
|
|
||||||
|
Ok(neon_serde::to_value(&mut cx, &ids)?)
|
||||||
|
}
|
||||||
|
|
||||||
method getOffsets(mut cx) {
|
method getOffsets(mut cx) {
|
||||||
// getOffsets(): [number, number][]
|
// getOffsets(): [number, number][]
|
||||||
|
|
||||||
|
@ -327,6 +327,17 @@ class Encoding:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@property
|
@property
|
||||||
|
def sequences(self) -> List[Optional[int]]:
|
||||||
|
"""The generated sequence indices.
|
||||||
|
|
||||||
|
They represent the index of the input sequence associated to each token.
|
||||||
|
The sequence id can be None if the token is not related to any input sequence,
|
||||||
|
like for example with special tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index.
|
||||||
|
"""
|
||||||
|
@property
|
||||||
def type_ids(self) -> List[int]:
|
def type_ids(self) -> List[int]:
|
||||||
"""The generated type IDs
|
"""The generated type IDs
|
||||||
|
|
||||||
|
@ -152,6 +152,19 @@ impl PyEncoding {
|
|||||||
self.encoding.get_words().to_vec()
|
self.encoding.get_words().to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The generated sequence indices.
|
||||||
|
///
|
||||||
|
/// They represent the index of the input sequence associated to each token.
|
||||||
|
/// The sequence id can be None if the token is not related to any input sequence,
|
||||||
|
/// like for example with special tokens.
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index.
|
||||||
|
#[getter]
|
||||||
|
fn get_sequences(&self) -> Vec<Option<usize>> {
|
||||||
|
self.encoding.get_sequences()
|
||||||
|
}
|
||||||
|
|
||||||
/// The generated type IDs
|
/// The generated type IDs
|
||||||
///
|
///
|
||||||
/// Generally used for tasks like sequence classification or question answering,
|
/// Generally used for tasks like sequence classification or question answering,
|
||||||
|
@ -12,6 +12,12 @@ class TestEncoding:
|
|||||||
pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
|
pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
|
||||||
return single_encoding, pair_encoding
|
return single_encoding, pair_encoding
|
||||||
|
|
||||||
|
def test_sequences(self, encodings):
|
||||||
|
single, pair = encodings
|
||||||
|
|
||||||
|
assert single.sequences == [None, 0, 0, 0, 0, None]
|
||||||
|
assert pair.sequences == [None, 0, 0, 0, 0, None, 1, 1, 1, None]
|
||||||
|
|
||||||
def test_n_sequences(self, encodings):
|
def test_n_sequences(self, encodings):
|
||||||
single, pair = encodings
|
single, pair = encodings
|
||||||
assert single.n_sequences == 1
|
assert single.n_sequences == 1
|
||||||
|
@ -129,6 +129,16 @@ impl Encoding {
|
|||||||
&self.words
|
&self.words
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_sequences(&self) -> Vec<Option<usize>> {
|
||||||
|
let mut sequences = vec![None; self.len()];
|
||||||
|
for seq_id in 0..self.n_sequences() {
|
||||||
|
let range = self.sequence_range(seq_id);
|
||||||
|
let seq_len = range.len();
|
||||||
|
sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len));
|
||||||
|
}
|
||||||
|
sequences
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_words_mut(&mut self) -> &mut [Option<u32>] {
|
pub fn get_words_mut(&mut self) -> &mut [Option<u32>] {
|
||||||
&mut self.words
|
&mut self.words
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user