From a16daa78f1469001b5f17ef4fbf21a4ae7234a63 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 19 Dec 2019 14:45:38 -0800
Subject: [PATCH] add test for word merge

---
 tokenizers/src/models/bpe/word.rs | 50 ++++++++++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 541ddac5..b61b5b49 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -1,6 +1,5 @@
 use super::Pair;
 
-// TODO: Add tests
 #[derive(Clone, Default)]
 pub struct Word {
     chars: Vec<u32>,
@@ -75,3 +74,52 @@ impl Word {
         offsets
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_merge() {
+        // Let's say we have the word 'hello' and a word-to-id vocab that looks
+        // like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
+        let mut word = Word::new();
+        word.add(0); // 'h'
+        word.add(1); // 'e'
+        word.add(2); // 'l'
+        word.add(2); // 'l'
+        word.add(3); // 'o'
+
+        // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
+        // say that 'll' has the ID of 4 in the updated word-to-id vocab.
+        let changes = word.merge(2, 2, 4);
+
+        // So the word should now look like this:
+        assert_eq!(
+            word.get_chars(),
+            &[
+                0u32, // 'h'
+                1u32, // 'e'
+                4u32, // 'll'
+                3u32, // 'o'
+            ]
+        );
+
+        // The return value `changes` will be used to update the pair counts during
+        // training. This merge affects the counts for the pairs
+        // ('e', 'l') ~= (1, 2),
+        // ('e', 'll') ~= (1, 4),
+        // ('ll', 'o') ~= (4, 3), and
+        // ('l', 'o') ~= (2, 3).
+        // So the changes should reflect that:
+        assert_eq!(
+            changes,
+            &[
+                ((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
+                ((1u32, 4u32), 1i32), // count for ('e', 'll') should be increased by 1.
+                ((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
+                ((4u32, 3u32), 1i32), // count for ('ll', 'o') should be increased by 1.
+            ]
+        );
+    }
+}