diff --git a/bindings/python/test.py b/bindings/python/test.py deleted file mode 100644 index 28f313a2..00000000 --- a/bindings/python/test.py +++ /dev/null @@ -1,12 +0,0 @@ -from tokenizers import ByteLevelBPETokenizer -from tokenizers import pre_tokenizers, models, Tokenizer, trainers - -tokenizer = Tokenizer(models.Unigram()) -tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit() -trainer = trainers.UnigramTrainer( - vocab_size=400000000, - show_progress=True, - special_tokens=["", "", "", "", "mask"] - ) -tokenizer.train(["data/big.txt"], trainer) - diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index d5315bc6..d57f5792 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -8,6 +8,11 @@ pub enum TruncationDirection { Left, Right, } +impl Default for TruncationDirection { + fn default() -> Self { + TruncationDirection::Right + } +} impl std::convert::AsRef for TruncationDirection { fn as_ref(&self) -> &str { @@ -20,6 +25,7 @@ impl std::convert::AsRef for TruncationDirection { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TruncationParams { + #[serde(default)] pub direction: TruncationDirection, pub max_length: usize, pub strategy: TruncationStrategy, @@ -30,9 +36,9 @@ impl Default for TruncationParams { fn default() -> Self { Self { max_length: 512, - strategy: TruncationStrategy::LongestFirst, + strategy: TruncationStrategy::default(), stride: 0, - direction: TruncationDirection::Right, + direction: TruncationDirection::default(), } } } @@ -68,6 +74,12 @@ pub enum TruncationStrategy { OnlySecond, } +impl Default for TruncationStrategy { + fn default() -> Self { + TruncationStrategy::LongestFirst + } +} + impl std::convert::AsRef for TruncationStrategy { fn as_ref(&self) -> &str { match self { @@ -325,4 +337,13 @@ mod tests { truncate_and_assert(get_medium(), get_medium(), ¶ms, 0, 0); truncate_and_assert(get_long(), get_long(), ¶ms, 0, 0); } + + #[test] + fn test_deserialize_defaults() { + let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#; + + let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap(); + + assert_eq!(params.direction, TruncationDirection::Right); + } }