Add an Encoding.sequences to allow masking

This commit is contained in:
Anthony MOI
2020-11-05 13:12:15 -05:00
committed by Anthony MOI
parent 385d25720a
commit 57d162b269
8 changed files with 74 additions and 1 deletions

View File

@ -123,6 +123,11 @@ export interface RawEncoding {
*/ */
getWords(): (number | undefined)[]; getWords(): (number | undefined)[];
/**
 * The sequence indices: for each token, the index of the input sequence it
 * belongs to, or `undefined` when the token is not related to any input sequence
 */
getSequences(): (number | undefined)[];
/** /**
* Pad the current Encoding at the given length * Pad the current Encoding at the given length
* *

View File

@ -112,6 +112,13 @@ describe("RawEncoding", () => {
}); });
}); });
describe("getSequences", () => {
  it("returns the correct list of indexes", () => {
    // Single input: every token is expected to map to sequence 0
    const expectedSingle = [0, 0, 0, 0, 0];
    // Pair input: the trailing tokens are expected to map to sequence 1
    const expectedDual = [0, 0, 0, 0, 0, 1, 1, 1, 1];

    expect(encoding.getSequences()).toEqual(expectedSingle);
    expect(encodingDual.getSequences()).toEqual(expectedDual);
  });
});
describe("wordToTokens", () => { describe("wordToTokens", () => {
it("returns the correct indexes", () => { it("returns the correct indexes", () => {
const indexes = encoding.wordToTokens(3); const indexes = encoding.wordToTokens(3);

View File

@ -11,6 +11,7 @@ export class Encoding {
private _tokens?: string[]; private _tokens?: string[];
private _typeIds?: number[]; private _typeIds?: number[];
private _wordIndexes?: (number | undefined)[]; private _wordIndexes?: (number | undefined)[];
private _sequenceIndexes?: (number | undefined)[];
constructor(private _rawEncoding: RawEncoding) {} constructor(private _rawEncoding: RawEncoding) {}
@ -151,6 +152,14 @@ export class Encoding {
return (this._wordIndexes = this._rawEncoding.getWords()); return (this._wordIndexes = this._rawEncoding.getWords());
} }
/**
 * The sequence indices, fetched from the raw encoding on first access and
 * cached for subsequent reads
 */
get sequenceIndexes(): (number | undefined)[] {
  if (!this._sequenceIndexes) {
    this._sequenceIndexes = this._rawEncoding.getSequences();
  }
  return this._sequenceIndexes;
}
/** /**
* Get the encoded tokens corresponding to the word at the given index in one of the input * Get the encoded tokens corresponding to the word at the given index in one of the input
* sequences, with the form [startToken, endToken+1] * sequences, with the form [startToken, endToken+1]

View File

@ -115,7 +115,7 @@ declare_types! {
} }
method getWords(mut cx) { method getWords(mut cx) {
// getWords(): number[] // getWords(): (number | undefined)[]
let this = cx.this(); let this = cx.this();
let guard = cx.lock(); let guard = cx.lock();
@ -127,6 +127,18 @@ declare_types! {
Ok(neon_serde::to_value(&mut cx, &ids)?) Ok(neon_serde::to_value(&mut cx, &ids)?)
} }
method getSequences(mut cx) {
    // getSequences(): (number | undefined)[]
    //
    // Exposes the sequence indices of the wrapped Encoding to JavaScript.
    let this = cx.this();
    let guard = cx.lock();
    // Panics with "Uninitialized Encoding" when `encoding` holds no value.
    let ids = this.borrow(&guard)
        .encoding.as_ref().expect("Uninitialized Encoding")
        .get_sequences();
    // `ids` is an owned value, converted to a JS value through neon_serde.
    Ok(neon_serde::to_value(&mut cx, &ids)?)
}
method getOffsets(mut cx) { method getOffsets(mut cx) {
// getOffsets(): [number, number][] // getOffsets(): [number, number][]

View File

@ -327,6 +327,17 @@ class Encoding:
""" """
pass pass
@property @property
def sequences(self) -> List[Optional[int]]:
    """The generated sequence indices.

    They represent the index of the input sequence associated with each token.
    The sequence id can be None if the token is not related to any input sequence,
    for example with special tokens.

    Returns:
        A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence indices.
    """
@property
def type_ids(self) -> List[int]: def type_ids(self) -> List[int]:
"""The generated type IDs """The generated type IDs

View File

@ -152,6 +152,19 @@ impl PyEncoding {
self.encoding.get_words().to_vec() self.encoding.get_words().to_vec()
} }
/// The generated sequence indices.
///
/// They represent the index of the input sequence associated with each token.
/// The sequence id can be None if the token is not related to any input sequence,
/// for example with special tokens.
///
/// Returns:
///     A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence indices.
#[getter]
fn get_sequences(&self) -> Vec<Option<usize>> {
    self.encoding.get_sequences()
}
/// The generated type IDs /// The generated type IDs
/// ///
/// Generally used for tasks like sequence classification or question answering, /// Generally used for tasks like sequence classification or question answering,

View File

@ -12,6 +12,12 @@ class TestEncoding:
pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?") pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
return single_encoding, pair_encoding return single_encoding, pair_encoding
def test_sequences(self, encodings):
single, pair = encodings
assert single.sequences == [None, 0, 0, 0, 0, None]
assert pair.sequences == [None, 0, 0, 0, 0, None, 1, 1, 1, None]
def test_n_sequences(self, encodings): def test_n_sequences(self, encodings):
single, pair = encodings single, pair = encodings
assert single.n_sequences == 1 assert single.n_sequences == 1

View File

@ -129,6 +129,16 @@ impl Encoding {
&self.words &self.words
} }
/// Returns, for each token, the index of the input sequence it belongs to.
///
/// The result has one entry per token (`self.len()` entries). An entry is
/// `None` when the token falls outside the range of every sequence, which is
/// the case for tokens not related to any input sequence such as special tokens.
pub fn get_sequences(&self) -> Vec<Option<usize>> {
    let mut sequences = vec![None; self.len()];
    for seq_id in 0..self.n_sequences() {
        // The vector is already sized, so each sequence range is simply
        // overwritten in place.
        for slot in &mut sequences[self.sequence_range(seq_id)] {
            *slot = Some(seq_id);
        }
    }
    sequences
}
pub fn get_words_mut(&mut self) -> &mut [Option<u32>] { pub fn get_words_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words &mut self.words
} }