words -> word_ids & sequences -> sequence_ids

This commit is contained in:
Anthony MOI
2020-11-09 15:14:14 -05:00
committed by Anthony MOI
parent 57d162b269
commit d3d9f2c76b
11 changed files with 70 additions and 44 deletions

View File

@ -121,12 +121,12 @@ export interface RawEncoding {
* The tokenized words indexes * The tokenized words indexes
* @since 0.6.0 * @since 0.6.0
*/ */
getWords(): (number | undefined)[]; getWordIds(): (number | undefined)[];
/** /**
* The sequences indices * The sequences indices
*/ */
getSequences(): (number | undefined)[]; getSequenceIds(): (number | undefined)[];
/** /**
* Pad the current Encoding at the given length * Pad the current Encoding at the given length

View File

@ -94,7 +94,8 @@ describe("RawEncoding", () => {
expect(typeof encoding.getSpecialTokensMask).toBe("function"); expect(typeof encoding.getSpecialTokensMask).toBe("function");
expect(typeof encoding.getTokens).toBe("function"); expect(typeof encoding.getTokens).toBe("function");
expect(typeof encoding.getTypeIds).toBe("function"); expect(typeof encoding.getTypeIds).toBe("function");
expect(typeof encoding.getWords).toBe("function"); expect(typeof encoding.getWordIds).toBe("function");
expect(typeof encoding.getSequenceIds).toBe("function");
expect(typeof encoding.pad).toBe("function"); expect(typeof encoding.pad).toBe("function");
expect(typeof encoding.truncate).toBe("function"); expect(typeof encoding.truncate).toBe("function");
}); });
@ -105,17 +106,17 @@ describe("RawEncoding", () => {
}); });
}); });
describe("getWords", () => { describe("getWordIds", () => {
it("returns the correct list of indexes", () => { it("returns the correct list of indexes", () => {
const indexes = encoding.getWords(); const indexes = encoding.getWordIds();
expect(indexes).toEqual([0, 1, 2, 3, 3]); expect(indexes).toEqual([0, 1, 2, 3, 3]);
}); });
}); });
describe("getSequences", () => { describe("getSequenceIds", () => {
it("returns the correct list of indexes", () => { it("returns the correct list of indexes", () => {
expect(encoding.getSequences()).toEqual([0, 0, 0, 0, 0]); expect(encoding.getSequenceIds()).toEqual([0, 0, 0, 0, 0]);
expect(encodingDual.getSequences()).toEqual([0, 0, 0, 0, 0, 1, 1, 1, 1]); expect(encodingDual.getSequenceIds()).toEqual([0, 0, 0, 0, 0, 1, 1, 1, 1]);
}); });
}); });

View File

@ -149,7 +149,7 @@ export class Encoding {
return this._wordIndexes; return this._wordIndexes;
} }
return (this._wordIndexes = this._rawEncoding.getWords()); return (this._wordIndexes = this._rawEncoding.getWordIds());
} }
get sequenceIndexes(): (number | undefined)[] { get sequenceIndexes(): (number | undefined)[] {
@ -157,7 +157,7 @@ export class Encoding {
return this._sequenceIndexes; return this._sequenceIndexes;
} }
return (this._sequenceIndexes = this._rawEncoding.getSequences()); return (this._sequenceIndexes = this._rawEncoding.getSequenceIds());
} }
/** /**

View File

@ -114,27 +114,27 @@ declare_types! {
Ok(neon_serde::to_value(&mut cx, &tokens)?) Ok(neon_serde::to_value(&mut cx, &tokens)?)
} }
method getWords(mut cx) { method getWordIds(mut cx) {
// getWords(): (number | undefined)[] // getWordIds(): (number | undefined)[]
let this = cx.this(); let this = cx.this();
let guard = cx.lock(); let guard = cx.lock();
let ids = this.borrow(&guard) let ids = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding") .encoding.as_ref().expect("Uninitialized Encoding")
.get_words() .get_word_ids()
.to_vec(); .to_vec();
Ok(neon_serde::to_value(&mut cx, &ids)?) Ok(neon_serde::to_value(&mut cx, &ids)?)
} }
method getSequences(mut cx) { method getSequenceIds(mut cx) {
// getSequences(): (number | undefined)[] // getSequenceIds(): (number | undefined)[]
let this = cx.this(); let this = cx.this();
let guard = cx.lock(); let guard = cx.lock();
let ids = this.borrow(&guard) let ids = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding") .encoding.as_ref().expect("Uninitialized Encoding")
.get_sequences(); .get_sequence_ids();
Ok(neon_serde::to_value(&mut cx, &ids)?) Ok(neon_serde::to_value(&mut cx, &ids)?)
} }

View File

@ -5,7 +5,7 @@ use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::{Offsets, PaddingDirection}; use tk::tokenizer::{Offsets, PaddingDirection};
use tokenizers as tk; use tokenizers as tk;
use crate::error::PyError; use crate::error::{deprecation_warning, PyError};
/// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`. /// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`.
#[pyclass(dict, module = "tokenizers", name=Encoding)] #[pyclass(dict, module = "tokenizers", name=Encoding)]
@ -137,6 +137,10 @@ impl PyEncoding {
/// The generated word indices. /// The generated word indices.
/// ///
/// .. warning::
/// This is deprecated and will be removed in a future version.
/// Please use :obj:`~tokenizers.Encoding.word_ids` instead.
///
/// They represent the index of the word associated to each token. /// They represent the index of the word associated to each token.
/// When the input is pre-tokenized, they correspond to the ID of the given input label, /// When the input is pre-tokenized, they correspond to the ID of the given input label,
/// otherwise they correspond to the words indices as defined by the /// otherwise they correspond to the words indices as defined by the
@ -148,8 +152,29 @@ impl PyEncoding {
/// Returns: /// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index. /// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index.
#[getter] #[getter]
fn get_words(&self) -> Vec<Option<u32>> { fn get_words(&self) -> PyResult<Vec<Option<u32>>> {
self.encoding.get_words().to_vec() deprecation_warning(
"0.9.4",
"Encoding.words is deprecated, please use Encoding.word_ids instead.",
)?;
Ok(self.get_word_ids())
}
/// The generated word indices.
///
/// They represent the index of the word associated to each token.
/// When the input is pre-tokenized, they correspond to the ID of the given input label,
/// otherwise they correspond to the words indices as defined by the
/// :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used.
///
/// For special tokens and such (any token that was generated from something that was
/// not part of the input), the output is :obj:`None`
///
/// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index.
#[getter]
fn get_word_ids(&self) -> Vec<Option<u32>> {
self.encoding.get_word_ids().to_vec()
} }
/// The generated sequence indices. /// The generated sequence indices.
@ -161,8 +186,8 @@ impl PyEncoding {
/// Returns: /// Returns:
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index. /// A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index.
#[getter] #[getter]
fn get_sequences(&self) -> Vec<Option<usize>> { fn get_sequence_ids(&self) -> Vec<Option<usize>> {
self.encoding.get_sequences() self.encoding.get_sequence_ids()
} }
/// The generated type IDs /// The generated type IDs

View File

@ -12,11 +12,11 @@ class TestEncoding:
pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?") pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
return single_encoding, pair_encoding return single_encoding, pair_encoding
def test_sequences(self, encodings): def test_sequence_ids(self, encodings):
single, pair = encodings single, pair = encodings
assert single.sequences == [None, 0, 0, 0, 0, None] assert single.sequence_ids == [None, 0, 0, 0, 0, None]
assert pair.sequences == [None, 0, 0, 0, 0, None, 1, 1, 1, None] assert pair.sequence_ids == [None, 0, 0, 0, 0, None, 1, 1, 1, None]
def test_n_sequences(self, encodings): def test_n_sequences(self, encodings):
single, pair = encodings single, pair = encodings

View File

@ -52,7 +52,7 @@ impl PostProcessor for BertProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat(); let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat(); let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()]; let attention_mask = vec![1; ids.len()];
@ -80,7 +80,7 @@ impl PostProcessor for BertProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat(); let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat(); [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
@ -109,7 +109,7 @@ impl PostProcessor for BertProcessing {
let pair_ids = [&encoding.get_ids()[..], &[self.sep.1]].concat(); let pair_ids = [&encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat(); let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
let pair_tokens = [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat(); let pair_tokens = [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
let pair_words = [&encoding.get_words()[..], &[None]].concat(); let pair_words = [&encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat(); let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens = let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat(); [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
@ -134,7 +134,7 @@ impl PostProcessor for BertProcessing {
let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat(); let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
let pair_tokens = let pair_tokens =
[&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat(); [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
let pair_words = [&encoding.get_words()[..], &[None]].concat(); let pair_words = [&encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat(); let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens = let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat(); [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();

View File

@ -85,7 +85,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat(); let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat(); let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()]; let attention_mask = vec![1; ids.len()];
@ -113,7 +113,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let words = [&[None], &encoding.get_words()[..], &[None]].concat(); let words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let special_tokens = let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat(); [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
@ -147,7 +147,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let pair_words = [&[None], &encoding.get_words()[..], &[None]].concat(); let pair_words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); let pair_offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens = let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat(); [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
@ -177,7 +177,7 @@ impl PostProcessor for RobertaProcessing {
&[self.sep.0.clone()], &[self.sep.0.clone()],
] ]
.concat(); .concat();
let pair_words = [&[None], &encoding.get_words()[..], &[None]].concat(); let pair_words = [&[None], &encoding.get_word_ids()[..], &[None]].concat();
let pair_offsets = let pair_offsets =
[&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat(); [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens = let pair_special_tokens =

View File

@ -571,7 +571,7 @@ impl TemplateProcessing {
ids.extend(encoding.get_ids()); ids.extend(encoding.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(encoding.len())); type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned())); tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
words.extend(encoding.get_words()); words.extend(encoding.get_word_ids());
offsets.extend(encoding.get_offsets()); offsets.extend(encoding.get_offsets());
special_tokens_mask.extend(encoding.get_special_tokens_mask()); special_tokens_mask.extend(encoding.get_special_tokens_mask());
attention_mask.extend(encoding.get_attention_mask()); attention_mask.extend(encoding.get_attention_mask());
@ -587,7 +587,7 @@ impl TemplateProcessing {
ids.extend(pair.get_ids()); ids.extend(pair.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(pair.len())); type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned())); tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
words.extend(pair.get_words()); words.extend(pair.get_word_ids());
offsets.extend(pair.get_offsets()); offsets.extend(pair.get_offsets());
special_tokens_mask.extend(pair.get_special_tokens_mask()); special_tokens_mask.extend(pair.get_special_tokens_mask());
attention_mask.extend(pair.get_attention_mask()); attention_mask.extend(pair.get_attention_mask());

View File

@ -125,11 +125,15 @@ impl Encoding {
&self.tokens[..] &self.tokens[..]
} }
pub fn get_words(&self) -> &[Option<u32>] { pub fn get_word_ids(&self) -> &[Option<u32>] {
&self.words &self.words
} }
pub fn get_sequences(&self) -> Vec<Option<usize>> { pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words
}
pub fn get_sequence_ids(&self) -> Vec<Option<usize>> {
let mut sequences = vec![None; self.len()]; let mut sequences = vec![None; self.len()];
for seq_id in 0..self.n_sequences() { for seq_id in 0..self.n_sequences() {
let range = self.sequence_range(seq_id); let range = self.sequence_range(seq_id);
@ -139,10 +143,6 @@ impl Encoding {
sequences sequences
} }
pub fn get_words_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words
}
pub fn get_ids(&self) -> &[u32] { pub fn get_ids(&self) -> &[u32] {
&self.ids &self.ids
} }

View File

@ -78,7 +78,7 @@ fn byte_level_double_sequence() {
] ]
); );
assert_eq!( assert_eq!(
output.get_words(), output.get_word_ids(),
&[ &[
Some(0), Some(0),
Some(1), Some(1),
@ -126,7 +126,7 @@ fn byte_level_pre_tokenized_sequence() {
&["ĠMy", "Ġname", "Ġis", "ĠAnth", "on", "ino"] &["ĠMy", "Ġname", "Ġis", "ĠAnth", "on", "ino"]
); );
assert_eq!( assert_eq!(
output.get_words(), output.get_word_ids(),
&[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)] &[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)]
); );
assert_eq!( assert_eq!(
@ -145,7 +145,7 @@ fn byte_level_pre_tokenized_sequence_with_trimming() {
let output = tokenizer.encode(&input[..], false).unwrap(); let output = tokenizer.encode(&input[..], false).unwrap();
assert_eq!( assert_eq!(
output.get_words(), output.get_word_ids(),
&[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)] &[Some(0), Some(1), Some(2), Some(3), Some(3), Some(3)]
); );
assert_eq!( assert_eq!(
@ -179,7 +179,7 @@ fn split_on_added_tokens_bert() {
&["yesterday", "i", "saw", "a", "[MASK]", "far", "away"] &["yesterday", "i", "saw", "a", "[MASK]", "far", "away"]
); );
assert_eq!( assert_eq!(
output.get_words(), output.get_word_ids(),
&[ &[
Some(0), Some(0),
Some(1), Some(1),