Fix decode error in SVTR model

This commit is contained in:
Jamjamjon
2025-02-08 21:24:45 +08:00
committed by GitHub
parent e2347353ba
commit 018b884992

View File

@@ -2,6 +2,7 @@ use aksr::Builder;
use anyhow::Result; use anyhow::Result;
use image::DynamicImage; use image::DynamicImage;
use ndarray::Axis; use ndarray::Axis;
use rayon::prelude::*;
use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y}; use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y};
@@ -35,6 +36,7 @@ impl SVTR {
if processor.vocab().is_empty() { if processor.vocab().is_empty() {
anyhow::bail!("No vocab file found") anyhow::bail!("No vocab file found")
} }
log::info!("Vacab size: {}", processor.vocab().len());
Ok(Self { Ok(Self {
engine, engine,
@@ -69,40 +71,26 @@ impl SVTR {
} }
pub fn postprocess(&self, xs: Xs) -> Result<Ys> { pub fn postprocess(&self, xs: Xs) -> Result<Ys> {
let mut ys: Vec<Y> = Vec::new(); let ys: Vec<Y> = xs[0]
for batch in xs[0].axis_iter(Axis(0)) {
let preds = batch
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.filter_map(|x| { .into_par_iter()
x.into_iter() .map(|preds| {
.enumerate() let mut preds: Vec<_> = preds
.max_by(|(_, x), (_, y)| x.total_cmp(y)) .axis_iter(Axis(0))
}) .filter_map(|x| x.into_iter().enumerate().max_by(|a, b| a.1.total_cmp(b.1)))
.collect::<Vec<_>>(); .collect();
let text = preds preds.dedup_by(|a, b| a.0 == b.0);
.iter()
.enumerate()
.fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| {
if *text_id == 0 || confidence < self.confs[0] {
return text_ids;
}
if idx == 0 || idx == self.processor.vocab().len() - 1 { let text: String = preds
return text_ids;
}
if *text_id != preds[idx - 1].0 {
text_ids.push(*text_id);
}
text_ids
})
.into_iter() .into_iter()
.map(|idx| self.processor.vocab()[idx].to_owned()) .filter(|(id, &conf)| *id != 0 && conf >= self.confs[0])
.collect::<String>(); .map(|(id, _)| self.processor.vocab()[id].clone())
.collect();
ys.push(Y::default().with_texts(&[text.into()])) Y::default().with_texts(&[text.into()])
} })
.collect();
Ok(ys.into()) Ok(ys.into())
} }