diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index 78422049..f8b61db6 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - [#652]: Fix offsets for `Precompiled` corner case +- [#656]: Fix BPE `continuing_subword_prefix` ## [0.10.1] @@ -307,6 +308,7 @@ delimiter (Works like `.split(delimiter)`) - Fix a bug that was causing crashes in Python 3.5 +[#656]: https://github.com/huggingface/tokenizers/pull/656 [#652]: https://github.com/huggingface/tokenizers/pull/652 [#621]: https://github.com/huggingface/tokenizers/pull/621 [#620]: https://github.com/huggingface/tokenizers/pull/620 diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 2ffba711..963c0ec9 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -335,23 +335,29 @@ impl BPE { let mut word = Word::with_capacity(w.len()); let mut unk: Option<(u32, usize)> = None; while let Some(i) = indices.next() { - let (s, byte_len) = if let Some(&end) = indices.peek() { - match (i, self.continuing_subword_prefix.as_ref()) { - (0, Some(prefix)) => ( - Cow::Owned(format!("{}{}", prefix, &w[i..end])), - (i..end).len(), - ), - _ => (Cow::Borrowed(&w[i..end]), (i..end).len()), - } + let end = indices.peek(); + let is_first = i == 0; + let is_last = end.is_none(); + + let mut s = if let Some(e) = end { + Cow::Borrowed(&w[i..*e]) } else { - ( - self.end_of_word_suffix - .as_ref() - .map(|suffix| format!("{}{}", &w[i..], suffix).into()) - .unwrap_or_else(|| Cow::Borrowed(&w[i..])), - w[i..].len(), - ) + Cow::Borrowed(&w[i..]) }; + let byte_len = s.len(); + + // Add the `continuing_subword_prefix` if relevant + if !is_first { + if let Some(ref prefix) = self.continuing_subword_prefix { + s = format!("{}{}", prefix, s).into() + } + } + // Add the `end_of_word_suffix` if relevant + if is_last { + if let Some(ref suffix) = self.end_of_word_suffix { + s = format!("{}{}", s, suffix).into() + } + } if let Some(id) = self.vocab.get(s.as_ref()) { if let Some((unk_id, unk_len)) = unk { @@ -684,17 +690,41 @@ mod tests { ("##b".to_string(), 1), ("##c".to_string(), 2), ("ab".to_string(), 3), + ("abc".to_string(), 4), ] .into_iter() .collect(); - let merges = vec![("a".to_string(), "##b".to_string())]; + let merges = vec![ + ("a".to_string(), "##b".to_string()), + ("ab".to_string(), "##c".to_string()), + ]; - BPE::builder() + let bpe = BPE::builder() .vocab_and_merges(vocab, merges) + .unk_token("[UNK]".to_string()) .continuing_subword_prefix("##".to_string()) .build() .unwrap(); + + let res = bpe.tokenize("ab"); + assert_eq!( + res.unwrap(), + vec![Token { + id: 3, + value: "ab".to_string(), + offsets: (0, 2) + }] + ); + let res = bpe.tokenize("abc"); + assert_eq!( + res.unwrap(), + vec![Token { + id: 4, + value: "abc".to_string(), + offsets: (0, 3) + }] + ); } #[test]