From 5682627223c249c17c13c6ba34e204a8cb30b00c Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Sat, 28 Dec 2019 15:21:50 -0500 Subject: [PATCH] BPE handles offsets --- tokenizers/src/models/bpe/model.rs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index c938227e..42fd70bb 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,5 +1,5 @@ use super::{Cache, Error, Pair, Word}; -use crate::tokenizer::{Model, Result, Token}; +use crate::tokenizer::{Model, Offsets, Result, Token}; use serde_json::Value; use std::{ collections::HashMap, @@ -103,15 +103,20 @@ impl Model for BPE { self.vocab.len() } - fn tokenize(&self, sentence: Vec) -> Result> { + fn tokenize(&self, sentence: Vec<(String, Offsets)>) -> Result> { if sentence.is_empty() { return Ok(vec![]); } let mut encoded: Vec = Vec::with_capacity(sentence.len()); - let mut cached_words = self.cache.get_values(&sentence); + let mut cached_words = self.cache.get_values( + &sentence + .iter() + .map(|(s, _)| s.to_owned()) + .collect::>(), + ); - for (i, w) in sentence.iter().enumerate() { + for (i, (w, initial_offsets)) in sentence.iter().enumerate() { if cached_words[i].is_none() { let mut word = Word::new(); for c in w.chars() { @@ -155,9 +160,6 @@ impl Model for BPE { cached_words[i] = Some(word); } - // Offsets are word-based, we need to translate them to be sentence-based - let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0); - let word = cached_words[i].as_ref().unwrap(); let tokens = word .get_chars() @@ -167,7 +169,7 @@ impl Model for BPE { Token::new( *id, self.vocab_r[id].clone(), - (last_offset + offsets.0, last_offset + offsets.1), + (initial_offsets.0 + offsets.0, initial_offsets.0 + offsets.1), ) }) .collect::>(); @@ -180,7 +182,7 @@ impl Model for BPE { .into_iter() .zip(cached_words) .filter(|(_, v)| v.is_some()) - .map(|(k, v)| (k, v.unwrap())) + .map(|(k, v)| (k.0, v.unwrap())) .unzip::<_, _, Vec, Vec>(); self.cache.set_values(keys, values);