words -> word_ids & sequences -> sequence_ids

This commit is contained in:
Anthony MOI
2020-11-09 15:14:14 -05:00
committed by Anthony MOI
parent 57d162b269
commit d3d9f2c76b
11 changed files with 70 additions and 44 deletions

View File

@ -121,12 +121,12 @@ export interface RawEncoding {
* The tokenized words indexes
* @since 0.6.0
*/
getWords(): (number | undefined)[];
getWordIds(): (number | undefined)[];
/**
* The sequences indices
*/
getSequences(): (number | undefined)[];
getSequenceIds(): (number | undefined)[];
/**
* Pad the current Encoding at the given length

View File

@ -94,7 +94,8 @@ describe("RawEncoding", () => {
expect(typeof encoding.getSpecialTokensMask).toBe("function");
expect(typeof encoding.getTokens).toBe("function");
expect(typeof encoding.getTypeIds).toBe("function");
expect(typeof encoding.getWords).toBe("function");
expect(typeof encoding.getWordIds).toBe("function");
expect(typeof encoding.getSequenceIds).toBe("function");
expect(typeof encoding.pad).toBe("function");
expect(typeof encoding.truncate).toBe("function");
});
@ -105,17 +106,17 @@ describe("RawEncoding", () => {
});
});
describe("getWords", () => {
describe("getWordIds", () => {
it("returns the correct list of indexes", () => {
const indexes = encoding.getWords();
const indexes = encoding.getWordIds();
expect(indexes).toEqual([0, 1, 2, 3, 3]);
});
});
describe("getSequences", () => {
describe("getSequenceIds", () => {
it("returns the correct list of indexes", () => {
expect(encoding.getSequences()).toEqual([0, 0, 0, 0, 0]);
expect(encodingDual.getSequences()).toEqual([0, 0, 0, 0, 0, 1, 1, 1, 1]);
expect(encoding.getSequenceIds()).toEqual([0, 0, 0, 0, 0]);
expect(encodingDual.getSequenceIds()).toEqual([0, 0, 0, 0, 0, 1, 1, 1, 1]);
});
});

View File

@ -149,7 +149,7 @@ export class Encoding {
return this._wordIndexes;
}
return (this._wordIndexes = this._rawEncoding.getWords());
return (this._wordIndexes = this._rawEncoding.getWordIds());
}
get sequenceIndexes(): (number | undefined)[] {
@ -157,7 +157,7 @@ export class Encoding {
return this._sequenceIndexes;
}
return (this._sequenceIndexes = this._rawEncoding.getSequences());
return (this._sequenceIndexes = this._rawEncoding.getSequenceIds());
}
/**

View File

@ -114,27 +114,27 @@ declare_types! {
Ok(neon_serde::to_value(&mut cx, &tokens)?)
}
method getWords(mut cx) {
// getWords(): (number | undefined)[]
method getWordIds(mut cx) {
// getWordIds(): (number | undefined)[]
let this = cx.this();
let guard = cx.lock();
let ids = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding")
.get_words()
.get_word_ids()
.to_vec();
Ok(neon_serde::to_value(&mut cx, &ids)?)
}
method getSequences(mut cx) {
// getSequences(): (number | undefined)[]
method getSequenceIds(mut cx) {
// getSequenceIds(): (number | undefined)[]
let this = cx.this();
let guard = cx.lock();
let ids = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding")
.get_sequences();
.get_sequence_ids();
Ok(neon_serde::to_value(&mut cx, &ids)?)
}

View File

@ -5,7 +5,7 @@ use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::{Offsets, PaddingDirection};
use tokenizers as tk;
use crate::error::PyError;
use crate::error::{deprecation_warning, PyError};
/// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`.
#[pyclass(dict, module = "tokenizers", name=Encoding)]
@ -137,6 +137,10 @@ impl PyEncoding {
/// The generated word indices.
///
/// .. warning::
/// This is deprecated and will be removed in a future version.
/// Please use :obj:`~tokenizers.Encoding.word_ids` instead.
///
/// They represent the index of the word associated to each token.
/// When the input is pre-tokenized, they correspond to the ID of the given input label,
/// otherwise they correspond to the words indices as defined by the
@ -148,8 +152,29 @@ impl PyEncoding {
/// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index.
#[getter]
fn get_words(&self) -> Vec<Option<u32>> {
self.encoding.get_words().to_vec()
fn get_words(&self) -> PyResult<Vec<Option<u32>>> {
deprecation_warning(
"0.9.4",
"Encoding.words is deprecated, please use Encoding.word_ids instead.",
)?;
Ok(self.get_word_ids())
}
/// The generated word indices.
///
/// They represent the index of the word associated to each token.
/// When the input is pre-tokenized, they correspond to the ID of the given input label,
/// otherwise they correspond to the words indices as defined by the
/// :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used.
///
/// For special tokens and such (any token that was generated from something that was
/// not part of the input), the output is :obj:`None`
///
/// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index.
#[getter]
fn get_word_ids(&self) -> Vec<Option<u32>> {
self.encoding.get_word_ids().to_vec()
}
/// The generated sequence indices.
@ -161,8 +186,8 @@ impl PyEncoding {
/// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index.
#[getter]
fn get_sequences(&self) -> Vec<Option<usize>> {
self.encoding.get_sequences()
fn get_sequence_ids(&self) -> Vec<Option<usize>> {
self.encoding.get_sequence_ids()
}
/// The generated type IDs

View File

@ -12,11 +12,11 @@ class TestEncoding:
pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
return single_encoding, pair_encoding
def test_sequences(self, encodings):
def test_sequence_ids(self, encodings):
single, pair = encodings
assert single.sequences == [None, 0, 0, 0, 0, None]
assert pair.sequences == [None, 0, 0, 0, 0, None, 1, 1, 1, None]
assert single.sequence_ids == [None, 0, 0, 0, 0, None]
assert pair.sequence_ids == [None, 0, 0, 0, 0, None, 1, 1, 1, None]
def test_n_sequences(self, encodings):
single, pair = encodings

View File

@ -52,7 +52,7 @@ impl PostProcessor for BertProcessing {
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat();
let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
@ -80,7 +80,7 @@ impl PostProcessor for BertProcessing {
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat();
let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
@ -109,7 +109,7 @@ impl PostProcessor for BertProcessing {
let pair_ids = [&encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
let pair_tokens = [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
let pair_words = [&encoding.get_words()[..], &[None]].concat();
let pair_words = [&encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
@ -134,7 +134,7 @@ impl PostProcessor for BertProcessing {
let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
let pair_tokens =
[&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
let pair_words = [&encoding.get_words()[..], &[None]].concat();
let pair_words = [&encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();

View File

@ -85,7 +85,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat();
let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
@ -113,7 +113,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat();
let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
@ -147,7 +147,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()],
]
.concat();
let pair_words = [&[None], &encoding.get_words()[..], &[None]].concat();
let pair_words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
@ -177,7 +177,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()],
]
.concat();
let pair_words = [&[None], &encoding.get_words()[..], &[None]].concat();
let pair_words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets =
[&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens =

View File

@ -571,7 +571,7 @@ impl TemplateProcessing {
ids.extend(encoding.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
words.extend(encoding.get_words());
words.extend(encoding.get_word_ids());
offsets.extend(encoding.get_offsets());
special_tokens_mask.extend(encoding.get_special_tokens_mask());
attention_mask.extend(encoding.get_attention_mask());
@ -587,7 +587,7 @@ impl TemplateProcessing {
ids.extend(pair.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
words.extend(pair.get_words());
words.extend(pair.get_word_ids());
offsets.extend(pair.get_offsets());
special_tokens_mask.extend(pair.get_special_tokens_mask());
attention_mask.extend(pair.get_attention_mask());

View File

@ -125,11 +125,15 @@ impl Encoding {
&self.tokens[..]
}
pub fn get_words(&self) -> &[Option<u32>] {
pub fn get_word_ids(&self) -> &[Option<u32>] {
&self.words
}
pub fn get_sequences(&self) -> Vec<Option<usize>> {
pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words
}
pub fn get_sequence_ids(&self) -> Vec<Option<usize>> {
let mut sequences = vec![None; self.len()];
for seq_id in 0..self.n_sequences() {
let range = self.sequence_range(seq_id);
@ -139,10 +143,6 @@ impl Encoding {
sequences
}
pub fn get_words_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words
}
pub fn get_ids(&self) -> &[u32] {
&self.ids
}

View File

@ -78,7 +78,7 @@ fn byte_level_double_sequence() {
]
);
assert_eq!(
output.get_words(),
output.get_word_ids(),
&[
Some(0),
Some(1),
@ -126,7 +126,7 @@ fn byte_level_pre_tokenized_sequence() {
&["ĠMy", "Ġname", "Ġis", "ĠAnth", "on", "ino"]
);
assert_eq!(
output.get_words(),
output.get_word_ids(),
&[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)]
);
assert_eq!(
@ -145,7 +145,7 @@ fn byte_level_pre_tokenized_sequence_with_trimming() {
let output = tokenizer.encode(&input[..], false).unwrap();
assert_eq!(
output.get_words(),
output.get_word_ids(),
&[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)]
);
assert_eq!(
@ -179,7 +179,7 @@ fn split_on_added_tokens_bert() {
&["yesterday", "i", "saw", "a", "[MASK]", "far", "away"]
);
assert_eq!(
output.get_words(),
output.get_word_ids(),
&[
Some(0),
Some(1),