Rust | Python | Node - Also add char_to_word

This commit is contained in:
Anthony MOI
2020-04-16 11:00:04 -04:00
parent 4aecd82d07
commit 71b7830d1b
5 changed files with 68 additions and 5 deletions

View File

@@ -40,6 +40,13 @@ export interface RawEncoding {
*/
charToToken(pos: number): number | undefined;
/**
* Get the word that contains the given char
* @param pos The position of a char in the input string
* @returns The index of the word containing this char, or `undefined` if there is none
* @since 0.7.0
*/
charToWord(pos: number): number | undefined;
/**
* Returns the attention mask
*/

View File

@@ -39,6 +39,7 @@ describe("RawEncoding", () => {
expect(typeof encoding.tokenToChars).toBe("function");
expect(typeof encoding.tokenToWord).toBe("function");
expect(typeof encoding.charToToken).toBe("function");
expect(typeof encoding.charToWord).toBe("function");
expect(typeof encoding.getAttentionMask).toBe("function");
expect(typeof encoding.getIds).toBe("function");
expect(typeof encoding.getLength).toBe("function");
@@ -124,4 +125,16 @@ describe("RawEncoding", () => {
expect(index).toBeUndefined();
});
});
describe("charToWord", () => {
  // A char position covered by a word resolves to that word's index.
  it("returns the correct index", () => {
    expect(encoding.charToWord(3)).toEqual(1);
  });

  // A char position beyond the encoded input resolves to nothing.
  it("returns undefined when out of range char", () => {
    expect(encoding.charToWord(100)).toBeUndefined();
  });
});
});

View File

@@ -289,6 +289,24 @@ declare_types! {
}
}
method charToWord(mut cx) {
    // charToWord(pos: number): number | undefined
    //
    // Reads the char position from the first JS argument, asks the wrapped
    // encoding for the word containing it, and hands back either the word
    // index as a JS number or `undefined` when there is none.
    //
    // NOTE(review): the `as usize` cast does not reject negative or NaN
    // input (such values do not become an error or `undefined`) — confirm
    // callers never pass them.
    let char_pos = cx.argument::<JsNumber>(0)?.value() as usize;

    let this = cx.this();
    let guard = cx.lock();
    let word = this
        .borrow(&guard)
        .encoding
        .execute(|encoding| encoding.unwrap().char_to_word(char_pos));

    match word {
        Some(word) => Ok(cx.number(word as f64).upcast()),
        None => Ok(cx.undefined().upcast()),
    }
}
method pad(mut cx) {
// pad(length: number, options?: {
// direction?: 'left' | 'right' = 'right',

View File

@@ -131,6 +131,18 @@ class Encoding:
The index of the token that contains this char
"""
pass
def char_to_word(self, pos: int) -> Optional[int]:
    """
    Find the word in which the given char is located.

    Args:
        pos: int:
            The position of a char in the input string

    Returns:
        The index of the word that contains this char,
        or None if no word contains it
    """
    pass
def pad(
self,
length: int,

View File

@@ -120,7 +120,8 @@ impl Encoding {
std::mem::replace(&mut self.overflowing, vec![])
}
/// Convert the given word index, to the corresponding tokens [start, end)
/// Get the encoded tokens corresponding to the word at the given index in the input sequence,
/// with the form (start_token, end_token + 1)
pub fn word_to_tokens(&self, word: u32) -> Option<(usize, usize)> {
let (mut start, mut end) = (None, None);
self.words
@@ -144,7 +145,7 @@ impl Encoding {
}
}
/// Find the offsets of the given word
/// Get the offsets of the word at the given index in the input sequence.
pub fn word_to_chars(&self, word: u32) -> Option<Offsets> {
self.word_to_tokens(word)
.map(|(start, end)| {
@@ -157,23 +158,30 @@ impl Encoding {
.flatten()
}
/// Find the offsets of the given token
/// Look up the offsets recorded for the token at index `token`.
///
/// Returns `None` when `token` is past the end of the stored offsets.
pub fn token_to_chars(&self, token: usize) -> Option<Offsets> {
    let offsets = self.offsets.get(token)?;
    Some(*offsets)
}
/// Find the index of the word that contains the token at the given index
/// Look up the word that the token at index `token` belongs to.
///
/// Returns `None` when `token` is out of range, or when the entry stored
/// for that token in `words` is itself `None`.
pub fn token_to_word(&self, token: usize) -> Option<u32> {
    // Each entry of `words` is already an `Option<u32>`, so it can be
    // returned as-is once the index is known to be valid.
    let word = self.words.get(token)?;
    *word
}
/// Return the index of the token at position of the given char.
/// Find the token whose offsets cover the char at position `pos`.
///
/// Offsets are treated as half-open `[start, end)` ranges; the index of the
/// first token whose range contains `pos` is returned, or `None` when no
/// token covers that position.
pub fn char_to_token(&self, pos: usize) -> Option<usize> {
    self.offsets
        .iter()
        .position(|&(start, end)| (start..end).contains(&pos))
}
/// Find the word that covers the char at position `pos`.
///
/// Resolves the char to its token first, then the token to its word.
/// Returns `None` if either lookup yields nothing.
pub fn char_to_word(&self, pos: usize) -> Option<u32> {
    let token = self.char_to_token(pos)?;
    self.token_to_word(token)
}
/// Truncate the current `Encoding`.
///
/// Panics if `stride >= max_len` or `max_len == 0`.
@@ -550,5 +558,10 @@ mod tests {
assert_eq!(encoding.char_to_token(8), Some(2));
assert_eq!(encoding.char_to_token(16), None);
assert_eq!(encoding.char_to_token(23), Some(6));
assert_eq!(encoding.char_to_word(3), Some(0));
assert_eq!(encoding.char_to_word(8), Some(1));
assert_eq!(encoding.char_to_word(16), None);
assert_eq!(encoding.char_to_word(23), Some(3));
}
}