mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Fixing infinite loop in UnigramTrainer. (#1182)
* Fixing infinite loop in UnigramTrainer.
* Newer clippy.
This commit is contained in:
@ -29,6 +29,12 @@ fn digamma(mut x: f64) -> f64 {
|
||||
result
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum UnigramTrainerError {
|
||||
#[error("The vocabulary is not large enough to contain all chars")]
|
||||
VocabularyTooSmall,
|
||||
}
|
||||
|
||||
fn to_log_prob(pieces: &mut [SentencePiece]) {
|
||||
let sum: f64 = pieces.iter().map(|(_, score)| score).sum();
|
||||
let logsum = sum.ln();
|
||||
@ -504,6 +510,9 @@ impl UnigramTrainer {
|
||||
let expected_updates = expected_loops * self.n_sub_iterations as usize;
|
||||
self.update_progress(&progress, expected_updates, "EM training");
|
||||
let required_chars = self.required_chars(&sentences);
|
||||
if required_chars.len() as u32 > self.vocab_size {
|
||||
return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
|
||||
}
|
||||
let mut new_model = Unigram::from(pieces.clone(), Some(0))?;
|
||||
loop {
|
||||
// Sub-EM iteration.
|
||||
|
@ -3,16 +3,12 @@ use serde::{Deserialize, Serialize};
|
||||
use std::cmp;
|
||||
use std::mem;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
|
||||
pub enum TruncationDirection {
|
||||
Left,
|
||||
#[default]
|
||||
Right,
|
||||
}
|
||||
impl Default for TruncationDirection {
|
||||
fn default() -> Self {
|
||||
TruncationDirection::Right
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::AsRef<str> for TruncationDirection {
|
||||
fn as_ref(&self) -> &str {
|
||||
|
Reference in New Issue
Block a user