diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index d4562b28..ab1bb674 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -29,6 +29,12 @@ fn digamma(mut x: f64) -> f64 {
     result
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum UnigramTrainerError {
+    #[error("The vocabulary is not large enough to contain all chars")]
+    VocabularyTooSmall,
+}
+
 fn to_log_prob(pieces: &mut [SentencePiece]) {
     let sum: f64 = pieces.iter().map(|(_, score)| score).sum();
     let logsum = sum.ln();
@@ -504,6 +510,9 @@ impl UnigramTrainer {
         let expected_updates = expected_loops * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
+        if required_chars.len() as u32 > self.vocab_size {
+            return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
+        }
         let mut new_model = Unigram::from(pieces.clone(), Some(0))?;
         loop {
             // Sub-EM iteration.
diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs
index 64814787..a8ad2a61 100644
--- a/tokenizers/src/utils/truncation.rs
+++ b/tokenizers/src/utils/truncation.rs
@@ -3,16 +3,12 @@ use serde::{Deserialize, Serialize};
 use std::cmp;
 use std::mem;
 
-#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
 pub enum TruncationDirection {
     Left,
+    #[default]
     Right,
 }
-impl Default for TruncationDirection {
-    fn default() -> Self {
-        TruncationDirection::Right
-    }
-}
 
 impl std::convert::AsRef<str> for TruncationDirection {
     fn as_ref(&self) -> &str {
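
For context on the trainer hunk: the new check fails fast when `vocab_size` cannot even hold the characters that must appear in the vocabulary, instead of letting EM training proceed and fail later. Below is a minimal, self-contained sketch of the same pattern. `TrainerError` and `check_vocab` are hypothetical stand-ins, the `Box<dyn Error>` return type is an assumption standing in for the crate's `Result` alias (whose exact trait bounds are not shown in this diff), and `thiserror` is assumed to be a declared dependency.

```rust
use std::error::Error;

// Hypothetical stand-in mirroring UnigramTrainerError; same thiserror derive
// as the hunk above, which generates the Display/Error impls from #[error].
#[derive(thiserror::Error, Debug)]
pub enum TrainerError {
    #[error("The vocabulary is not large enough to contain all chars")]
    VocabularyTooSmall,
}

// Hypothetical helper: the trainer returns a boxed error, so the new variant
// is wrapped in Box::new before being returned, exactly as in the hunk.
fn check_vocab(required_chars: usize, vocab_size: u32) -> Result<(), Box<dyn Error>> {
    if required_chars as u32 > vocab_size {
        return Err(Box::new(TrainerError::VocabularyTooSmall));
    }
    Ok(())
}

fn main() {
    assert!(check_vocab(50, 8000).is_ok());
    // Fewer vocab slots than required characters: training must fail early.
    assert!(check_vocab(9000, 8000).is_err());
}
```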
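The truncation.rs hunk swaps the hand-written `Default` impl for the `#[default]` variant attribute, which has been stable since Rust 1.62. A minimal sketch of the behavior (serde derives dropped so it compiles without extra dependencies):

```rust
// `#[derive(Default)]` on an enum requires exactly one unit variant to be
// marked `#[default]`; the derived impl then returns that variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum TruncationDirection {
    Left,
    #[default]
    Right,
}

fn main() {
    // Identical behavior to the removed manual impl.
    assert_eq!(TruncationDirection::default(), TruncationDirection::Right);
    assert_ne!(TruncationDirection::default(), TruncationDirection::Left);
}
```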