Rust | Python | Node - Also add char_to_word
bindings/node/lib/bindings/raw-encoding.d.ts (vendored, 7 changed lines)
@@ -40,6 +40,13 @@ export interface RawEncoding {
    */
   charToToken(pos: number): number | undefined;

+  /**
+   * Get the word that contains the given char
+   * @param pos The position of a char in the input string
+   * @since 0.7.0
+   */
+  charToWord(pos: number): number | undefined;
+
   /**
    * Returns the attention mask
    */
@@ -39,6 +39,7 @@ describe("RawEncoding", () => {
     expect(typeof encoding.tokenToChars).toBe("function");
     expect(typeof encoding.tokenToWord).toBe("function");
     expect(typeof encoding.charToToken).toBe("function");
+    expect(typeof encoding.charToWord).toBe("function");
     expect(typeof encoding.getAttentionMask).toBe("function");
     expect(typeof encoding.getIds).toBe("function");
     expect(typeof encoding.getLength).toBe("function");
@@ -124,4 +125,16 @@ describe("RawEncoding", () => {
       expect(index).toBeUndefined();
     });
   });
+
+  describe("charToWord", () => {
+    it("returns the correct index", () => {
+      const index = encoding.charToWord(3);
+      expect(index).toEqual(1);
+    });
+
+    it("returns undefined when out of range char", () => {
+      const index = encoding.charToWord(100);
+      expect(index).toBeUndefined();
+    });
+  });
 });
@@ -289,6 +289,24 @@ declare_types! {
             }
         }

+        method charToWord(mut cx) {
+            // charToWord(pos: number): number | undefined
+
+            let pos = cx.argument::<JsNumber>(0)?.value() as usize;
+
+            let this = cx.this();
+            let guard = cx.lock();
+            let index = this.borrow(&guard).encoding.execute(|encoding| {
+                encoding.unwrap().char_to_word(pos)
+            });
+
+            if let Some(index) = index {
+                Ok(cx.number(index as f64).upcast())
+            } else {
+                Ok(cx.undefined().upcast())
+            }
+        }
+
         method pad(mut cx) {
             // pad(length: number, options?: {
             //   direction?: 'left' | 'right' = 'right',
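In the Node binding above, the Option<u32> coming back from the core char_to_word has to surface in JavaScript as either a number or undefined. Below is a minimal, dependency-free Rust sketch of that mapping in isolation; JsValueLike is a hypothetical stand-in for the upcast Neon JsValue handle, used only for illustration.

// Hypothetical stand-in for the upcast JS value returned by the binding;
// Neon's real JsNumber/JsUndefined handles are not used here.
#[derive(Debug, PartialEq)]
enum JsValueLike {
    Number(f64),
    Undefined,
}

// Mirrors the `if let Some(index) = index { .. } else { .. }` branch above:
// a found word index becomes a JS number (widened to f64), a miss becomes undefined.
fn word_index_to_js(index: Option<u32>) -> JsValueLike {
    match index {
        Some(index) => JsValueLike::Number(index as f64),
        None => JsValueLike::Undefined,
    }
}

fn main() {
    assert_eq!(word_index_to_js(Some(1)), JsValueLike::Number(1.0));
    assert_eq!(word_index_to_js(None), JsValueLike::Undefined);
}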
@@ -131,6 +131,18 @@ class Encoding:
             The index of the token that contains this char
         """
         pass
+    def char_to_word(self, pos: int) -> Optional[int]:
+        """
+        Get the word that contains the given char.
+
+        Args:
+            pos: int:
+                The position of a char in the input string
+
+        Returns:
+            The index of the word that contains this char
+        """
+        pass
     def pad(
         self,
         length: int,
@@ -120,7 +120,8 @@ impl Encoding {
         std::mem::replace(&mut self.overflowing, vec![])
     }

-    /// Convert the given word index, to the corresponding tokens [start, end)
+    /// Get the encoded tokens corresponding to the word at the given index in the input sequence,
+    /// with the form (start_token, end_token + 1)
     pub fn word_to_tokens(&self, word: u32) -> Option<(usize, usize)> {
         let (mut start, mut end) = (None, None);
         self.words
@@ -144,7 +145,7 @@ impl Encoding {
         }
     }

-    /// Find the offsets of the given word
+    /// Get the offsets of the word at the given index in the input sequence.
     pub fn word_to_chars(&self, word: u32) -> Option<Offsets> {
         self.word_to_tokens(word)
             .map(|(start, end)| {
@@ -157,23 +158,30 @@ impl Encoding {
             .flatten()
     }

-    /// Find the offsets of the given token
+    /// Get the offsets of the token at the given index.
     pub fn token_to_chars(&self, token: usize) -> Option<Offsets> {
         self.offsets.get(token).copied()
     }

-    /// Find the index of the word that contains the token at the given index
+    /// Get the word that contains the token at the given index.
     pub fn token_to_word(&self, token: usize) -> Option<u32> {
         self.words.get(token).copied().flatten()
     }

-    /// Return the index of the token at position of the given char.
+    /// Get the token that contains the given char.
     pub fn char_to_token(&self, pos: usize) -> Option<usize> {
         self.offsets
             .iter()
             .position(|(start, end)| pos >= *start && pos < *end)
     }

+    /// Get the word that contains the given char.
+    pub fn char_to_word(&self, pos: usize) -> Option<u32> {
+        self.char_to_token(pos)
+            .map(|token| self.token_to_word(token))
+            .flatten()
+    }
+
     /// Truncate the current `Encoding`.
     ///
     /// Panic if `stride >= max_len` or `max_len == 0`.
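The new char_to_word above is simply the composition of the two existing lookups: char_to_token finds the token whose character span contains the position, and token_to_word maps that token to its word index. Below is a minimal, self-contained Rust sketch of that composition; the free functions and the toy offsets/words data are illustrative stand-ins, not the crate's API.

// Illustrative stand-ins for Encoding's `offsets` (char span of each token)
// and `words` (word index of each token); not the tokenizers crate API.
fn char_to_token(offsets: &[(usize, usize)], pos: usize) -> Option<usize> {
    // Index of the token whose [start, end) span contains the char position.
    offsets
        .iter()
        .position(|(start, end)| pos >= *start && pos < *end)
}

fn token_to_word(words: &[Option<u32>], token: usize) -> Option<u32> {
    // Special tokens carry no word index, hence the nested Option.
    words.get(token).copied().flatten()
}

fn char_to_word(offsets: &[(usize, usize)], words: &[Option<u32>], pos: usize) -> Option<u32> {
    // char -> token -> word, mirroring the composition in the new method.
    char_to_token(offsets, pos).and_then(|token| token_to_word(words, token))
}

fn main() {
    // Toy encoding of "Hello there": one token per word.
    let offsets = [(0, 5), (6, 11)];
    let words = [Some(0), Some(1)];

    assert_eq!(char_to_word(&offsets, &words, 3), Some(0)); // inside "Hello"
    assert_eq!(char_to_word(&offsets, &words, 8), Some(1)); // inside "there"
    assert_eq!(char_to_word(&offsets, &words, 5), None); // the space between the words
    assert_eq!(char_to_word(&offsets, &words, 100), None); // out of range
}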
@@ -550,5 +558,10 @@ mod tests {
         assert_eq!(encoding.char_to_token(8), Some(2));
         assert_eq!(encoding.char_to_token(16), None);
         assert_eq!(encoding.char_to_token(23), Some(6));
+
+        assert_eq!(encoding.char_to_word(3), Some(0));
+        assert_eq!(encoding.char_to_word(8), Some(1));
+        assert_eq!(encoding.char_to_word(16), None);
+        assert_eq!(encoding.char_to_word(23), Some(3));
     }
 }