mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Fix decode error in SVTR model
This commit is contained in:
@ -2,6 +2,7 @@ use aksr::Builder;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use image::DynamicImage;
|
use image::DynamicImage;
|
||||||
use ndarray::Axis;
|
use ndarray::Axis;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y};
|
use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y};
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ impl SVTR {
|
|||||||
if processor.vocab().is_empty() {
|
if processor.vocab().is_empty() {
|
||||||
anyhow::bail!("No vocab file found")
|
anyhow::bail!("No vocab file found")
|
||||||
}
|
}
|
||||||
|
log::info!("Vacab size: {}", processor.vocab().len());
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
engine,
|
engine,
|
||||||
@ -69,40 +71,26 @@ impl SVTR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn postprocess(&self, xs: Xs) -> Result<Ys> {
|
pub fn postprocess(&self, xs: Xs) -> Result<Ys> {
|
||||||
let mut ys: Vec<Y> = Vec::new();
|
let ys: Vec<Y> = xs[0]
|
||||||
for batch in xs[0].axis_iter(Axis(0)) {
|
.axis_iter(Axis(0))
|
||||||
let preds = batch
|
.into_par_iter()
|
||||||
.axis_iter(Axis(0))
|
.map(|preds| {
|
||||||
.filter_map(|x| {
|
let mut preds: Vec<_> = preds
|
||||||
x.into_iter()
|
.axis_iter(Axis(0))
|
||||||
.enumerate()
|
.filter_map(|x| x.into_iter().enumerate().max_by(|a, b| a.1.total_cmp(b.1)))
|
||||||
.max_by(|(_, x), (_, y)| x.total_cmp(y))
|
.collect();
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let text = preds
|
preds.dedup_by(|a, b| a.0 == b.0);
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| {
|
|
||||||
if *text_id == 0 || confidence < self.confs[0] {
|
|
||||||
return text_ids;
|
|
||||||
}
|
|
||||||
|
|
||||||
if idx == 0 || idx == self.processor.vocab().len() - 1 {
|
let text: String = preds
|
||||||
return text_ids;
|
.into_iter()
|
||||||
}
|
.filter(|(id, &conf)| *id != 0 && conf >= self.confs[0])
|
||||||
|
.map(|(id, _)| self.processor.vocab()[id].clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
if *text_id != preds[idx - 1].0 {
|
Y::default().with_texts(&[text.into()])
|
||||||
text_ids.push(*text_id);
|
})
|
||||||
}
|
.collect();
|
||||||
text_ids
|
|
||||||
})
|
|
||||||
.into_iter()
|
|
||||||
.map(|idx| self.processor.vocab()[idx].to_owned())
|
|
||||||
.collect::<String>();
|
|
||||||
|
|
||||||
ys.push(Y::default().with_texts(&[text.into()]))
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ys.into())
|
Ok(ys.into())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user