diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs
index f14f5f8..2e262fe 100644
--- a/src/models/svtr/impl.rs
+++ b/src/models/svtr/impl.rs
@@ -2,6 +2,7 @@ use aksr::Builder;
 use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;
+use rayon::prelude::*;
 
 use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y};
 
@@ -35,6 +36,7 @@ impl SVTR {
         if processor.vocab().is_empty() {
             anyhow::bail!("No vocab file found")
         }
+        log::info!("Vocab size: {}", processor.vocab().len());
 
         Ok(Self {
             engine,
@@ -69,40 +71,38 @@ impl SVTR {
     }
 
+    /// Decodes the first output in `xs` into one text `Y` per entry along axis 0.
+    ///
+    /// Entries are processed in parallel. Per entry: take the highest-scoring
+    /// index at every step, collapse consecutive repeats of the same index, drop
+    /// index 0 and picks scoring below `self.confs[0]`, then concatenate the
+    /// matching vocab entries.
+    ///
+    /// NOTE(review): the vocab lookup indexes directly, so it presumably panics if
+    /// the model emits an index beyond the vocab length — confirm they match.
     pub fn postprocess(&self, xs: Xs) -> Result<Ys> {
-        let mut ys: Vec<Y> = Vec::new();
-        for batch in xs[0].axis_iter(Axis(0)) {
-            let preds = batch
-                .axis_iter(Axis(0))
-                .filter_map(|x| {
-                    x.into_iter()
-                        .enumerate()
-                        .max_by(|(_, x), (_, y)| x.total_cmp(y))
-                })
-                .collect::<Vec<_>>();
+        let ys: Vec<Y> = xs[0]
+            .axis_iter(Axis(0))
+            .into_par_iter()
+            .map(|preds| {
+                // For every step, keep the (index, score) pair with the highest score.
+                let mut preds: Vec<_> = preds
+                    .axis_iter(Axis(0))
+                    .filter_map(|x| x.into_iter().enumerate().max_by(|a, b| a.1.total_cmp(b.1)))
+                    .collect();
 
-            let text = preds
-                .iter()
-                .enumerate()
-                .fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| {
-                    if *text_id == 0 || confidence < self.confs[0] {
-                        return text_ids;
-                    }
+                // Collapse runs of the same index; the first element of each run is kept.
+                preds.dedup_by(|a, b| a.0 == b.0);
 
-                    if idx == 0 || idx == self.processor.vocab().len() - 1 {
-                        return text_ids;
-                    }
+                // Filter after collapsing, so a run is judged by its first element's score.
+                let text: String = preds
+                    .into_iter()
+                    .filter(|(id, &conf)| *id != 0 && conf >= self.confs[0])
+                    .map(|(id, _)| self.processor.vocab()[id].clone())
+                    .collect();
 
-                    if *text_id != preds[idx - 1].0 {
-                        text_ids.push(*text_id);
-                    }
-                    text_ids
-                })
-                .into_iter()
-                .map(|idx| self.processor.vocab()[idx].to_owned())
-                .collect::<String>();
-
-            ys.push(Y::default().with_texts(&[text.into()]))
-        }
+                Y::default().with_texts(&[text.into()])
+            })
+            .collect();
 
         Ok(ys.into())
     }