diff --git a/bindings/python/test.py b/bindings/python/test.py
deleted file mode 100644
index 28f313a2..00000000
--- a/bindings/python/test.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from tokenizers import ByteLevelBPETokenizer
-from tokenizers import pre_tokenizers, models, Tokenizer, trainers
-
-tokenizer = Tokenizer(models.Unigram())
-tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
-trainer = trainers.UnigramTrainer(
- vocab_size=400000000,
- show_progress=True,
- special_tokens=["", "", "", "", "mask"]
- )
-tokenizer.train(["data/big.txt"], trainer)
-
diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs
index d5315bc6..d57f5792 100644
--- a/tokenizers/src/utils/truncation.rs
+++ b/tokenizers/src/utils/truncation.rs
@@ -8,6 +8,11 @@ pub enum TruncationDirection {
Left,
Right,
}
+impl Default for TruncationDirection {
+ fn default() -> Self {
+ TruncationDirection::Right
+ }
+}
impl std::convert::AsRef for TruncationDirection {
fn as_ref(&self) -> &str {
@@ -20,6 +25,7 @@ impl std::convert::AsRef for TruncationDirection {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationParams {
+ #[serde(default)]
pub direction: TruncationDirection,
pub max_length: usize,
pub strategy: TruncationStrategy,
@@ -30,9 +36,9 @@ impl Default for TruncationParams {
fn default() -> Self {
Self {
max_length: 512,
- strategy: TruncationStrategy::LongestFirst,
+ strategy: TruncationStrategy::default(),
stride: 0,
- direction: TruncationDirection::Right,
+ direction: TruncationDirection::default(),
}
}
}
@@ -68,6 +74,12 @@ pub enum TruncationStrategy {
OnlySecond,
}
+impl Default for TruncationStrategy {
+ fn default() -> Self {
+ TruncationStrategy::LongestFirst
+ }
+}
+
impl std::convert::AsRef for TruncationStrategy {
fn as_ref(&self) -> &str {
match self {
@@ -325,4 +337,13 @@ mod tests {
truncate_and_assert(get_medium(), get_medium(), ¶ms, 0, 0);
truncate_and_assert(get_long(), get_long(), ¶ms, 0, 0);
}
+
+ #[test]
+ fn test_deserialize_defaults() {
+ let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;
+
+ let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();
+
+ assert_eq!(params.direction, TruncationDirection::Right);
+ }
}