Simplify the API for Encoding.token_to_XXX

This commit is contained in:
Anthony MOI
2020-11-05 12:11:18 -05:00
committed by Anthony MOI
parent 51dbf0b6df
commit 385d25720a
7 changed files with 70 additions and 116 deletions

View File

@ -29,23 +29,27 @@ export interface RawEncoding {
/**
* Get the offsets of the token at the given index
* If this encoding represents only one sequence, then only the offsets are returned.
* If this encoding represents more than one sequence, then it returns a tuple with the sequence
* id in the first part
*
* The returned offsets are relative to the input sequence that contains the
* token. To determine which input sequence the token belongs to, call
* `tokenToSequence`.
*
* @param token The index of the token in the encoded sequence
* @since 0.7.0
*/
tokenToChars(token: number): [number, number] | [number, [number, number]] | undefined;
tokenToChars(token: number): [number, number] | undefined;
/**
* Get the word that contains the token at the given index
* If this encoding represents only one sequence, then only the offsets are returned.
* If this encoding represents more than one sequence, then it returns a tuple with the sequence
* id in the first part
*
* The returned index is relative to the input sequence that contains the
* token. To determine which input sequence the token belongs to, call
* `tokenToSequence`.
*
* @param token The index of the token in the encoded sequence
* @since 0.7.0
*/
tokenToWord(token: number): number | [number, number] | undefined;
tokenToWord(token: number): number | undefined;
/**
* Find the index of the token at the position of the given char

View File

@ -160,8 +160,8 @@ describe("RawEncoding", () => {
});
it("returns the correct offsets with pair sequences", () => {
expect(encodingDual.tokenToChars(3)).toEqual([0, [11, 13]]);
expect(encodingDual.tokenToChars(7)).toEqual([1, [8, 13]]);
expect(encodingDual.tokenToChars(3)).toEqual([11, 13]);
expect(encodingDual.tokenToChars(7)).toEqual([8, 13]);
});
it("returns undefined when out of range token", () => {
@ -177,8 +177,8 @@ describe("RawEncoding", () => {
});
it("returns the correct index with pair sequences", () => {
expect(encodingDual.tokenToWord(3)).toEqual([0, 3]);
expect(encodingDual.tokenToWord(7)).toEqual([1, 2]);
expect(encodingDual.tokenToWord(3)).toEqual(3);
expect(encodingDual.tokenToWord(7)).toEqual(2);
});
it("returns undefined when out of range token", () => {

View File

@ -182,25 +182,29 @@ export class Encoding {
/**
* Get the offsets of the token at the given index
* If this encoding represents only one sequence, then only the offsets are returned.
* If this encoding represents more than one sequence, then it returns a tuple with the sequence
* id in the first part
*
* The returned offsets are relative to the input sequence that contains the
* token. To determine which input sequence the token belongs to, call
* `tokenToSequence`.
*
* @param token The index of the token in the encoded sequence
* @since 0.7.0
*/
tokenToChars(token: number): [number, number] | [number, [number, number]] | undefined {
tokenToChars(token: number): [number, number] | undefined {
return this._rawEncoding.tokenToChars(token);
}
/**
* Get the word that contains the token at the given index
* If this encoding represents only one sequence, then only the offsets are returned.
* If this encoding represents more than one sequence, then it returns a tuple with the sequence
* id in the first part
*
* The returned index is relative to the input sequence that contains the
* token. To determine which input sequence the token belongs to, call
* `tokenToSequence`.
*
* @param token The index of the token in the encoded sequence
* @since 0.7.0
*/
tokenToWord(token: number): number | [number, number] | undefined {
tokenToWord(token: number): number | undefined {
return this._rawEncoding.tokenToWord(token);
}

View File

@ -233,21 +233,12 @@ declare_types! {
let this = cx.this();
let guard = cx.lock();
let (res, n_seq) = {
let borrowed = this.borrow(&guard);
let encoding = borrowed.encoding.as_ref().expect("Uninitialized Encoding");
let res = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding")
.token_to_chars(token);
let res = encoding.token_to_chars(token);
let n_seq = encoding.n_sequences();
(res, n_seq)
};
if let Some((seq_id, offsets)) = res {
if n_seq > 1 {
Ok(neon_serde::to_value(&mut cx, &(seq_id, offsets))?)
} else {
if let Some((_, offsets)) = res {
Ok(neon_serde::to_value(&mut cx, &offsets)?)
}
} else {
Ok(cx.undefined().upcast())
}
@ -261,21 +252,12 @@ declare_types! {
let this = cx.this();
let guard = cx.lock();
let (res, n_seq) = {
let borrowed = this.borrow(&guard);
let encoding = borrowed.encoding.as_ref().expect("Uninitialized Encoding");
let res = this.borrow(&guard)
.encoding.as_ref().expect("Uninitialized Encoding")
.token_to_word(token);
let res = encoding.token_to_word(token);
let n_seq = encoding.n_sequences();
(res, n_seq)
};
if let Some((seq_id, index)) = res {
if n_seq > 1 {
Ok(neon_serde::to_value(&mut cx, &(seq_id, index))?)
} else {
if let Some((_, index)) = res {
Ok(cx.number(index as f64).upcast())
}
} else {
Ok(cx.undefined().upcast())
}

View File

@ -424,46 +424,34 @@ class Encoding:
:obj:`int`: The sequence id of the given token
"""
pass
def token_to_chars(self, token_index: int) -> Optional[Union[Offsets, Tuple[int, Offsets]]]:
def token_to_chars(self, token_index: int) -> Optional[Offsets]:
"""Get the offsets of the token at the given index.
If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
a pair of sequences), then this method returns a Tuple with both the relevant
sequence index, and the offsets.
The returned offsets are relative to the input sequence that contains the
token. To determine which input sequence the token belongs to, call
:meth:`~tokenizers.Encoding.token_to_sequence()`.
Args:
token_index (:obj:`int`):
The index of a token in the encoded sequence.
Returns:
:obj:`Tuple[int, int]` or :obj:`Tuple[int, Tuple[int, int]]`:
- For a single sequence: the token offsets:
:obj:`Tuple[int, int]` of the form :obj:`(first, last + 1)`
- For pairs of sequence: A tuple with the sequence index, and the token offsets:
:obj:`Tuple[int, Tuple[int, int]]` with offsets of the form :obj:`(first, last + 1)`
:obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
"""
pass
def token_to_word(self, token_index: int) -> Optional[Union[int, Tuple[int, int]]]:
"""Get the word that contains the token at the given index
def token_to_word(self, token_index: int) -> Optional[int]:
"""Get the index of the word that contains the token in one of the input sequences.
If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
a pair of sequences), then this method returns a Tuple with both the relevant
sequence index, and the word index.
The returned word index is relative to the input sequence that contains
the token. To determine which input sequence the token belongs to, call
:meth:`~tokenizers.Encoding.token_to_sequence()`.
Args:
token_index (:obj:`int`):
The index of a token in the encoded sequence.
Returns:
:obj:`int` or :obj:`Tuple[int, int]`:
- For a single sequence: The index of the word in the input sequence: :obj:`int`
- For pairs of sequence: A tuple with the sequence index, and the index of the word
in the said sequence: :obj:`Tuple[int, int]`
:obj:`int`: The index of the word in the relevant input sequence.
"""
pass
def char_to_token(self, pos: int, sequence_index: int = 0) -> Optional[int]:

View File

@ -270,62 +270,38 @@ impl PyEncoding {
/// Get the offsets of the token at the given index.
///
/// If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
/// a pair of sequences), then this method returns a Tuple with both the relevant
/// sequence index, and the offsets.
/// The returned offsets are relative to the input sequence that contains the
/// token. To determine which input sequence the token belongs to, call
/// :meth:`~tokenizers.Encoding.token_to_sequence()`.
///
/// Args:
/// token_index (:obj:`int`):
/// The index of a token in the encoded sequence.
///
/// Returns:
/// :obj:`Tuple[int, int]` or :obj:`Tuple[int, Tuple[int, int]]`:
///
/// - For a single sequence: the token offsets:
/// :obj:`Tuple[int, int]` of the form :obj:`(first, last + 1)`
///
/// - For pairs of sequence: A tuple with the sequence index, and the token offsets:
/// :obj:`Tuple[int, Tuple[int, int]]` with offsets of the form :obj:`(first, last + 1)`
///
/// :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
#[text_signature = "($self, token_index)"]
fn token_to_chars(&self, token_index: usize) -> Option<PyObject> {
let (seq_idx, offsets) = self.encoding.token_to_chars(token_index)?;
Python::with_gil(|py| {
if self.encoding.n_sequences() > 1 {
Some((seq_idx, offsets).to_object(py))
} else {
Some(offsets.to_object(py))
}
})
fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
let (_, offsets) = self.encoding.token_to_chars(token_index)?;
Some(offsets)
}
/// Get the word that contains the token at the given index
/// Get the index of the word that contains the token in one of the input sequences.
///
/// If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
/// a pair of sequences), then this method returns a Tuple with both the relevant
/// sequence index, and the word index.
/// The returned word index is relative to the input sequence that contains
/// the token. To determine which input sequence the token belongs to, call
/// :meth:`~tokenizers.Encoding.token_to_sequence()`.
///
/// Args:
/// token_index (:obj:`int`):
/// The index of a token in the encoded sequence.
///
/// Returns:
/// :obj:`int` or :obj:`Tuple[int, int]`:
///
/// - For a single sequence: The index of the word in the input sequence: :obj:`int`
/// - For pairs of sequence: A tuple with the sequence index, and the index of the word
/// in the said sequence: :obj:`Tuple[int, int]`
///
/// :obj:`int`: The index of the word in the relevant input sequence.
#[text_signature = "($self, token_index)"]
fn token_to_word(&self, token_index: usize) -> Option<PyObject> {
let (seq_idx, word_idx) = self.encoding.token_to_word(token_index)?;
Python::with_gil(|py| {
if self.encoding.n_sequences() > 1 {
Some((seq_idx, word_idx).to_object(py))
} else {
Some(word_idx.to_object(py))
}
})
fn token_to_word(&self, token_index: usize) -> Option<u32> {
let (_, word_idx) = self.encoding.token_to_word(token_index)?;
Some(word_idx)
}
/// Get the token that contains the char at the given position in the input sequence.

View File

@ -65,9 +65,9 @@ class TestEncoding:
assert single.token_to_chars(0) == None
assert single.token_to_chars(2) == (2, 6)
assert pair.token_to_chars(2) == (0, (2, 6))
assert pair.token_to_chars(2) == (2, 6)
assert pair.token_to_chars(5) == None
assert pair.token_to_chars(6) == (1, (0, 2))
assert pair.token_to_chars(6) == (0, 2)
def test_token_to_word(self, encodings):
single, pair = encodings
@ -75,11 +75,11 @@ class TestEncoding:
assert single.token_to_word(0) == None
assert single.token_to_word(1) == 0
assert single.token_to_word(4) == 2
assert pair.token_to_word(1) == (0, 0)
assert pair.token_to_word(4) == (0, 2)
assert pair.token_to_word(1) == 0
assert pair.token_to_word(4) == 2
assert pair.token_to_word(5) == None
assert pair.token_to_word(6) == (1, 0)
assert pair.token_to_word(7) == (1, 1)
assert pair.token_to_word(6) == 0
assert pair.token_to_word(7) == 1
def test_char_to_token(self, encodings):
single, pair = encodings