Simplify the API for Encoding.token_to_XXX

Author:       Anthony MOI
Date:         2020-11-05 12:11:18 -05:00
Committed by: Anthony MOI
Parent:       51dbf0b6df
Commit:       385d25720a

7 changed files with 70 additions and 116 deletions
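
In short, `token_to_chars` and `token_to_word` no longer change their return shape depending on whether the encoding holds one sequence or a pair: they always return a plain value relative to the token's own input sequence, and `token_to_sequence` reports which sequence that is. A minimal Python sketch of how a caller could recover the old `(sequence_id, value)` pair under the new API (illustrative helper, not part of this commit):

    from typing import Optional, Tuple
    from tokenizers import Encoding

    def token_span(encoding: Encoding, token_index: int) -> Optional[Tuple[int, Tuple[int, int]]]:
        """Rebuild the pre-change (sequence_id, offsets) tuple from the new API."""
        offsets = encoding.token_to_chars(token_index)    # now always plain (first, last + 1)
        if offsets is None:
            return None                                   # special token or out-of-range index
        seq_id = encoding.token_to_sequence(token_index)  # which input sequence the token came from
        return (seq_id, offsets)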


@@ -29,23 +29,27 @@ export interface RawEncoding {
   /**
    * Get the offsets of the token at the given index
-   * If this encoding represents only one sequence, then only the offsets are returned.
-   * If this encoding represents more than one sequence, then it returns a tuple with the sequence
-   * id in the first part
+   *
+   * The returned offsets are related to the input sequence that contains the
+   * token. In order to determine in which input sequence it belongs, you
+   * must call `tokenToSequence`.
+   *
    * @param token The index of the token in the encoded sequence
    * @since 0.7.0
    */
-  tokenToChars(token: number): [number, number] | [number, [number, number]] | undefined;
+  tokenToChars(token: number): [number, number] | undefined;

   /**
    * Get the word that contains the token at the given index
-   * If this encoding represents only one sequence, then only the offsets are returned.
-   * If this encoding represents more than one sequence, then it returns a tuple with the sequence
-   * id in the first part
+   *
+   * The returned index is related to the input sequence that contains the
+   * token. In order to determine in which input sequence it belongs, you
+   * must call `tokenToSequence`.
+   *
    * @param token The index of the token in the encoded sequence
    * @since 0.7.0
    */
-  tokenToWord(token: number): number | [number, number] | undefined;
+  tokenToWord(token: number): number | undefined;

   /**
    * Find the index of the token at the position of the given char


@@ -160,8 +160,8 @@ describe("RawEncoding", () => {
     });

     it("returns the correct offsets with pair sequences", () => {
-      expect(encodingDual.tokenToChars(3)).toEqual([0, [11, 13]]);
-      expect(encodingDual.tokenToChars(7)).toEqual([1, [8, 13]]);
+      expect(encodingDual.tokenToChars(3)).toEqual([11, 13]);
+      expect(encodingDual.tokenToChars(7)).toEqual([8, 13]);
     });

     it("returns undefined when out of range token", () => {
@@ -177,8 +177,8 @@ describe("RawEncoding", () => {
     });

     it("returns the correct index with pair sequences", () => {
-      expect(encodingDual.tokenToWord(3)).toEqual([0, 3]);
-      expect(encodingDual.tokenToWord(7)).toEqual([1, 2]);
+      expect(encodingDual.tokenToWord(3)).toEqual(3);
+      expect(encodingDual.tokenToWord(7)).toEqual(2);
     });

     it("returns undefined when out of range token", () => {


@@ -182,25 +182,29 @@ export class Encoding {
   /**
    * Get the offsets of the token at the given index
-   * If this encoding represents only one sequence, then only the offsets are returned.
-   * If this encoding represents more than one sequence, then it returns a tuple with the sequence
-   * id in the first part
+   *
+   * The returned offsets are related to the input sequence that contains the
+   * token. In order to determine in which input sequence it belongs, you
+   * must call `tokenToSequence`.
+   *
    * @param token The index of the token in the encoded sequence
    * @since 0.7.0
    */
-  tokenToChars(token: number): [number, number] | [number, [number, number]] | undefined {
+  tokenToChars(token: number): [number, number] | undefined {
     return this._rawEncoding.tokenToChars(token);
   }

   /**
    * Get the word that contains the token at the given index
-   * If this encoding represents only one sequence, then only the offsets are returned.
-   * If this encoding represents more than one sequence, then it returns a tuple with the sequence
-   * id in the first part
+   *
+   * The returned index is related to the input sequence that contains the
+   * token. In order to determine in which input sequence it belongs, you
+   * must call `tokenToSequence`.
+   *
    * @param token The index of the token in the encoded sequence
    * @since 0.7.0
    */
-  tokenToWord(token: number): number | [number, number] | undefined {
+  tokenToWord(token: number): number | undefined {
     return this._rawEncoding.tokenToWord(token);
   }


@@ -233,21 +233,12 @@ declare_types! {
             let this = cx.this();
             let guard = cx.lock();
-            let (res, n_seq) = {
-                let borrowed = this.borrow(&guard);
-                let encoding = borrowed.encoding.as_ref().expect("Uninitialized Encoding");
-                let res = encoding.token_to_chars(token);
-                let n_seq = encoding.n_sequences();
-                (res, n_seq)
-            };
-            if let Some((seq_id, offsets)) = res {
-                if n_seq > 1 {
-                    Ok(neon_serde::to_value(&mut cx, &(seq_id, offsets))?)
-                } else {
-                    Ok(neon_serde::to_value(&mut cx, &offsets)?)
-                }
+            let res = this.borrow(&guard)
+                .encoding.as_ref().expect("Uninitialized Encoding")
+                .token_to_chars(token);
+            if let Some((_, offsets)) = res {
+                Ok(neon_serde::to_value(&mut cx, &offsets)?)
             } else {
                 Ok(cx.undefined().upcast())
             }
@@ -261,21 +252,12 @@ declare_types! {
             let this = cx.this();
             let guard = cx.lock();
-            let (res, n_seq) = {
-                let borrowed = this.borrow(&guard);
-                let encoding = borrowed.encoding.as_ref().expect("Uninitialized Encoding");
-                let res = encoding.token_to_word(token);
-                let n_seq = encoding.n_sequences();
-                (res, n_seq)
-            };
-            if let Some((seq_id, index)) = res {
-                if n_seq > 1 {
-                    Ok(neon_serde::to_value(&mut cx, &(seq_id, index))?)
-                } else {
-                    Ok(cx.number(index as f64).upcast())
-                }
+            let res = this.borrow(&guard)
+                .encoding.as_ref().expect("Uninitialized Encoding")
+                .token_to_word(token);
+            if let Some((_, index)) = res {
+                Ok(cx.number(index as f64).upcast())
             } else {
                 Ok(cx.undefined().upcast())
             }


@@ -424,46 +424,34 @@ class Encoding:
             :obj:`int`: The sequence id of the given token
         """
         pass

-    def token_to_chars(self, token_index: int) -> Optional[Union[Offsets, Tuple[int, Offsets]]]:
+    def token_to_chars(self, token_index: int) -> Optional[Offsets]:
         """Get the offsets of the token at the given index.

-        If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
-        a pair of sequences), then this method returns a Tuple with both the relevant
-        sequence index, and the offsets.
+        The returned offsets are related to the input sequence that contains the
+        token. In order to determine in which input sequence it belongs, you
+        must call :meth:`~tokenizers.Encoding.token_to_sequence()`.

         Args:
             token_index (:obj:`int`):
                 The index of a token in the encoded sequence.

         Returns:
-            :obj:`Tuple[int, int]` or :obj:`Tuple[int, Tuple[int, int]]`:
-
-            - For a single sequence: the token offsets:
-              :obj:`Tuple[int, int]` of the form :obj:`(first, last + 1)`
-
-            - For pairs of sequence: A tuple with the sequence index, and the token offsets:
-              :obj:`Tuple[int, Tuple[int, int]]` with offsets of the form :obj:`(first, last + 1)`
+            :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
         """
         pass

-    def token_to_word(self, token_index: int) -> Optional[Union[int, Tuple[int, int]]]:
-        """Get the word that contains the token at the given index
+    def token_to_word(self, token_index: int) -> Optional[int]:
+        """Get the index of the word that contains the token in one of the input sequences.

-        If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
-        a pair of sequences), then this method returns a Tuple with both the relevant
-        sequence index, and the word index.
+        The returned word index is related to the input sequence that contains
+        the token. In order to determine in which input sequence it belongs, you
+        must call :meth:`~tokenizers.Encoding.token_to_sequence()`.

         Args:
             token_index (:obj:`int`):
                 The index of a token in the encoded sequence.

         Returns:
-            :obj:`int` or :obj:`Tuple[int, int]`:
-
-            - For a single sequence: The index of the word in the input sequence: :obj:`int`
-            - For pairs of sequence: A tuple with the sequence index, and the index of the word
-              in the said sequence: :obj:`Tuple[int, int]`
+            :obj:`int`: The index of the word in the relevant input sequence.
         """
         pass

     def char_to_token(self, pos: int, sequence_index: int = 0) -> Optional[int]:
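
Taken together, the updated docstrings describe a per-sequence contract: every `token_to_*` value is expressed relative to the input sequence that contains the token, and `token_to_sequence()` identifies that sequence. A rough usage sketch (hypothetical `describe_tokens` helper, not part of this diff):

    def describe_tokens(encoding):
        # Walk every token and print which sequence, word, and character span it maps to.
        for i, token in enumerate(encoding.tokens):
            seq = encoding.token_to_sequence(i)   # None for special tokens
            word = encoding.token_to_word(i)      # word index within that sequence
            chars = encoding.token_to_chars(i)    # (first, last + 1) within that sequence
            print(f"{token!r}: sequence={seq}, word={word}, chars={chars}")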


@@ -270,62 +270,38 @@ impl PyEncoding {
     /// Get the offsets of the token at the given index.
     ///
-    /// If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
-    /// a pair of sequences), then this method returns a Tuple with both the relevant
-    /// sequence index, and the offsets.
+    /// The returned offsets are related to the input sequence that contains the
+    /// token. In order to determine in which input sequence it belongs, you
+    /// must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
     ///
     /// Args:
     ///     token_index (:obj:`int`):
     ///         The index of a token in the encoded sequence.
     ///
     /// Returns:
-    ///     :obj:`Tuple[int, int]` or :obj:`Tuple[int, Tuple[int, int]]`:
-    ///
-    ///     - For a single sequence: the token offsets:
-    ///       :obj:`Tuple[int, int]` of the form :obj:`(first, last + 1)`
-    ///
-    ///     - For pairs of sequence: A tuple with the sequence index, and the token offsets:
-    ///       :obj:`Tuple[int, Tuple[int, int]]` with offsets of the form :obj:`(first, last + 1)`
-    ///
+    ///     :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
     #[text_signature = "($self, token_index)"]
-    fn token_to_chars(&self, token_index: usize) -> Option<PyObject> {
-        let (seq_idx, offsets) = self.encoding.token_to_chars(token_index)?;
-        Python::with_gil(|py| {
-            if self.encoding.n_sequences() > 1 {
-                Some((seq_idx, offsets).to_object(py))
-            } else {
-                Some(offsets.to_object(py))
-            }
-        })
+    fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
+        let (_, offsets) = self.encoding.token_to_chars(token_index)?;
+        Some(offsets)
     }

-    /// Get the word that contains the token at the given index
+    /// Get the index of the word that contains the token in one of the input sequences.
     ///
-    /// If the :class:`~tokenizers.Encoding` represents multiple sequences (namely
-    /// a pair of sequences), then this method returns a Tuple with both the relevant
-    /// sequence index, and the word index.
+    /// The returned word index is related to the input sequence that contains
+    /// the token. In order to determine in which input sequence it belongs, you
+    /// must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
     ///
     /// Args:
     ///     token_index (:obj:`int`):
     ///         The index of a token in the encoded sequence.
     ///
     /// Returns:
-    ///     :obj:`int` or :obj:`Tuple[int, int]`:
-    ///
-    ///     - For a single sequence: The index of the word in the input sequence: :obj:`int`
-    ///     - For pairs of sequence: A tuple with the sequence index, and the index of the word
-    ///       in the said sequence: :obj:`Tuple[int, int]`
-    ///
+    ///     :obj:`int`: The index of the word in the relevant input sequence.
     #[text_signature = "($self, token_index)"]
-    fn token_to_word(&self, token_index: usize) -> Option<PyObject> {
-        let (seq_idx, word_idx) = self.encoding.token_to_word(token_index)?;
-        Python::with_gil(|py| {
-            if self.encoding.n_sequences() > 1 {
-                Some((seq_idx, word_idx).to_object(py))
-            } else {
-                Some(word_idx.to_object(py))
-            }
-        })
+    fn token_to_word(&self, token_index: usize) -> Option<u32> {
+        let (_, word_idx) = self.encoding.token_to_word(token_index)?;
+        Some(word_idx)
     }

     /// Get the token that contains the char at the given position in the input sequence.


@@ -65,9 +65,9 @@ class TestEncoding:
         assert single.token_to_chars(0) == None
         assert single.token_to_chars(2) == (2, 6)
-        assert pair.token_to_chars(2) == (0, (2, 6))
+        assert pair.token_to_chars(2) == (2, 6)
         assert pair.token_to_chars(5) == None
-        assert pair.token_to_chars(6) == (1, (0, 2))
+        assert pair.token_to_chars(6) == (0, 2)

     def test_token_to_word(self, encodings):
         single, pair = encodings
@@ -75,11 +75,11 @@ class TestEncoding:
         assert single.token_to_word(0) == None
         assert single.token_to_word(1) == 0
         assert single.token_to_word(4) == 2
-        assert pair.token_to_word(1) == (0, 0)
-        assert pair.token_to_word(4) == (0, 2)
+        assert pair.token_to_word(1) == 0
+        assert pair.token_to_word(4) == 2
         assert pair.token_to_word(5) == None
-        assert pair.token_to_word(6) == (1, 0)
-        assert pair.token_to_word(7) == (1, 1)
+        assert pair.token_to_word(6) == 0
+        assert pair.token_to_word(7) == 1

     def test_char_to_token(self, encodings):
         single, pair = encodings
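
With the sequence id gone from the return values, the pair-sequence assertions above could be complemented by explicit `token_to_sequence` checks, along these lines (hypothetical additional assertions, not part of this diff):

        # Token 6 belongs to the second input sequence; the word and offset
        # values are now expressed relative to that sequence.
        assert pair.token_to_sequence(6) == 1
        assert pair.token_to_word(6) == 0
        assert pair.token_to_chars(6) == (0, 2)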