From 71b7830d1b4b633e05cfc2b5271f08a215db2a04 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 16 Apr 2020 11:00:04 -0400 Subject: [PATCH] Rust | Python | Node - Also add char_to_word --- bindings/node/lib/bindings/raw-encoding.d.ts | 7 ++++++ .../node/lib/bindings/raw-encoding.test.ts | 13 +++++++++++ bindings/node/native/src/encoding.rs | 18 +++++++++++++++ bindings/python/tokenizers/__init__.pyi | 12 ++++++++++ tokenizers/src/tokenizer/encoding.rs | 23 +++++++++++++++---- 5 files changed, 68 insertions(+), 5 deletions(-) diff --git a/bindings/node/lib/bindings/raw-encoding.d.ts b/bindings/node/lib/bindings/raw-encoding.d.ts index db79ac6a..d4ad6ecc 100644 --- a/bindings/node/lib/bindings/raw-encoding.d.ts +++ b/bindings/node/lib/bindings/raw-encoding.d.ts @@ -40,6 +40,13 @@ export interface RawEncoding { */ charToToken(pos: number): number | undefined; + /** + * Get the word that contains the given char + * @param pos The position of a char in the input string + * @since 0.7.0 + */ + charToWord(pos: number): number | undefined; + /** * Returns the attention mask */ diff --git a/bindings/node/lib/bindings/raw-encoding.test.ts b/bindings/node/lib/bindings/raw-encoding.test.ts index 2a85044e..d38a9950 100644 --- a/bindings/node/lib/bindings/raw-encoding.test.ts +++ b/bindings/node/lib/bindings/raw-encoding.test.ts @@ -39,6 +39,7 @@ describe("RawEncoding", () => { expect(typeof encoding.tokenToChars).toBe("function"); expect(typeof encoding.tokenToWord).toBe("function"); expect(typeof encoding.charToToken).toBe("function"); + expect(typeof encoding.charToWord).toBe("function"); expect(typeof encoding.getAttentionMask).toBe("function"); expect(typeof encoding.getIds).toBe("function"); expect(typeof encoding.getLength).toBe("function"); @@ -124,4 +125,16 @@ describe("RawEncoding", () => { expect(index).toBeUndefined(); }); }); + + describe("charToWord", () => { + it("returns the correct index", () => { + const index = encoding.charToWord(3); + expect(index).toEqual(1); + }); + + it("returns undefined when out of range char", () => { + const index = encoding.charToWord(100); + expect(index).toBeUndefined(); + }); + }); }); diff --git a/bindings/node/native/src/encoding.rs b/bindings/node/native/src/encoding.rs index f9873207..5f06028c 100644 --- a/bindings/node/native/src/encoding.rs +++ b/bindings/node/native/src/encoding.rs @@ -289,6 +289,24 @@ declare_types! { } } + method charToWord(mut cx) { + // charToWord(pos: number): number | undefined + + let pos = cx.argument::(0)?.value() as usize; + + let this = cx.this(); + let guard = cx.lock(); + let index = this.borrow(&guard).encoding.execute(|encoding| { + encoding.unwrap().char_to_word(pos) + }); + + if let Some(index) = index { + Ok(cx.number(index as f64).upcast()) + } else { + Ok(cx.undefined().upcast()) + } + } + method pad(mut cx) { // pad(length: number, options?: { // direction?: 'left' | 'right' = 'right', diff --git a/bindings/python/tokenizers/__init__.pyi b/bindings/python/tokenizers/__init__.pyi index 24c1e241..d54fe5ba 100644 --- a/bindings/python/tokenizers/__init__.pyi +++ b/bindings/python/tokenizers/__init__.pyi @@ -131,6 +131,18 @@ class Encoding: The index of the token that contains this char """ pass + def char_to_word(self, pos: int) -> Optional[int]: + """ + Get the word that contains the given char. + + Args: + pos: int: + The position of a char in the input string + + Returns: + The index of the word that contains this char + """ + pass def pad( self, length: int, diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 8b5caf7b..6cbc557d 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -120,7 +120,8 @@ impl Encoding { std::mem::replace(&mut self.overflowing, vec![]) } - /// Convert the given word index, to the corresponding tokens [start, end) + /// Get the encoded tokens corresponding to the word at the given index in the input sequence, + /// with the form (start_token, end_token + 1) pub fn word_to_tokens(&self, word: u32) -> Option<(usize, usize)> { let (mut start, mut end) = (None, None); self.words @@ -144,7 +145,7 @@ impl Encoding { } } - /// Find the offsets of the given word + /// Get the offsets of the word at the given index in the input sequence. pub fn word_to_chars(&self, word: u32) -> Option { self.word_to_tokens(word) .map(|(start, end)| { @@ -157,23 +158,30 @@ impl Encoding { .flatten() } - /// Find the offsets of the given token + /// Get the offsets of the token at the given index. pub fn token_to_chars(&self, token: usize) -> Option { self.offsets.get(token).copied() } - /// Find the index of the word that contains the token at the given index + /// Get the word that contains the token at the given index. pub fn token_to_word(&self, token: usize) -> Option { self.words.get(token).copied().flatten() } - /// Return the index of the token at position of the given char. + /// Get the token that contains the given char. pub fn char_to_token(&self, pos: usize) -> Option { self.offsets .iter() .position(|(start, end)| pos >= *start && pos < *end) } + /// Get the word that contains the given char. + pub fn char_to_word(&self, pos: usize) -> Option { + self.char_to_token(pos) + .map(|token| self.token_to_word(token)) + .flatten() + } + /// Truncate the current `Encoding`. /// /// Panic if `stride >= max_len` or `max_len == 0`. @@ -550,5 +558,10 @@ mod tests { assert_eq!(encoding.char_to_token(8), Some(2)); assert_eq!(encoding.char_to_token(16), None); assert_eq!(encoding.char_to_token(23), Some(6)); + + assert_eq!(encoding.char_to_word(3), Some(0)); + assert_eq!(encoding.char_to_word(8), Some(1)); + assert_eq!(encoding.char_to_word(16), None); + assert_eq!(encoding.char_to_word(23), Some(3)); } }