Fixing infinite loop in UnigramTrainer. (#1182)

* Fixing infinite loop in UnigramTrainer.

* Newer clippy.
This commit is contained in:
Nicolas Patry
2023-03-15 14:59:01 +01:00
committed by GitHub
parent 9c0e700212
commit 5ecd329503
2 changed files with 11 additions and 6 deletions

View File

@ -29,6 +29,12 @@ fn digamma(mut x: f64) -> f64 {
result
}
/// Errors raised by the Unigram trainer.
#[derive(thiserror::Error, Debug)]
pub enum UnigramTrainerError {
/// Returned when the number of required characters exceeds the
/// configured `vocab_size`, so the vocabulary cannot hold them all.
#[error("The vocabulary is not large enough to contain all chars")]
VocabularyTooSmall,
}
fn to_log_prob(pieces: &mut [SentencePiece]) {
let sum: f64 = pieces.iter().map(|(_, score)| score).sum();
let logsum = sum.ln();
@ -504,6 +510,9 @@ impl UnigramTrainer {
let expected_updates = expected_loops * self.n_sub_iterations as usize;
self.update_progress(&progress, expected_updates, "EM training");
let required_chars = self.required_chars(&sentences);
if required_chars.len() as u32 > self.vocab_size {
return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
}
let mut new_model = Unigram::from(pieces.clone(), Some(0))?;
loop {
// Sub-EM iteration.

View File

@ -3,16 +3,12 @@ use serde::{Deserialize, Serialize};
use std::cmp;
use std::mem;
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
/// Selects the direction used for truncation; `Right` is the default variant.
pub enum TruncationDirection {
// NOTE(review): presumably removes content from the start of the sequence —
// confirm against the truncation implementation.
Left,
/// Default variant (marked with `#[default]`).
// NOTE(review): presumably removes content from the end of the sequence —
// confirm against the truncation implementation.
#[default]
Right,
}
/// The default direction is `TruncationDirection::Right`.
impl Default for TruncationDirection {
/// Returns `TruncationDirection::Right`.
fn default() -> Self {
TruncationDirection::Right
}
}
impl std::convert::AsRef<str> for TruncationDirection {
fn as_ref(&self) -> &str {