fix: svtr deduplication bug (#7)

This commit is contained in:
oatiz
2024-04-16 18:59:42 +08:00
committed by GitHub
parent c05836f02e
commit a86000d107
4 changed files with 43 additions and 33 deletions

View File

@ -25,6 +25,7 @@ cargo run -r --example svtr
```shell
[Texts] from the background, but also separate text instances which
[Texts] are closely jointed. Some examples are ilustrated in Fig.7.
[Texts] are closely jointed. Some examples are illustrated in Fig.7.
[Texts] 你有这么高速运转的机械进入中国,记住我给出的原理
[Texts] 110022345
```

View File

@ -15,10 +15,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
DataLoader::try_read("./examples/svtr/text1.png")?,
DataLoader::try_read("./examples/svtr/text2.png")?,
DataLoader::try_read("./examples/svtr/text3.png")?,
DataLoader::try_read("./examples/svtr/text4.png")?,
];
// run
model.run(&xs)?;
for text in model.run(&xs)?.into_iter() {
println!("[Texts] {text}")
}
Ok(())
}

BIN
examples/svtr/text4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

View File

@ -41,46 +41,52 @@ impl SVTR {
})
}
// NOTE(review): two `run` signatures appear back to back. The first returns `Result<()>`,
// the second `Result<Vec<String>>` — presumably the before/after lines of this change,
// shown without +/- markers. Confirm against the repository before editing.
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<()> {
/// Runs the model on a batch of images and returns whatever `postprocess` yields.
///
/// Steps visible here: resize `xs` through `ops::resize_with_fixed_height`, normalize the
/// result, feed it to `self.engine`, then hand the first engine output to `postprocess`.
///
/// # Errors
/// Propagates errors from the resize step, from `self.engine.run`, and from `postprocess`.
///
/// # Panics
/// `ys[0]` panics if the engine returns an empty output list.
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<String>> {
// `self.height.opt` / `self.width.opt` are cast with `as u32`; the meaning of the trailing
// `0.0` argument is not visible here — TODO confirm in `ops`.
let xs_ =
ops::resize_with_fixed_height(xs, self.height.opt as u32, self.width.opt as u32, 0.0)?;
// Presumably rescales pixel values from the 0..255 range — confirm in `ops::normalize`.
let xs_ = ops::normalize(xs_, 0.0, 255.0);
let ys: Vec<Array<f32, IxDyn>> = self.engine.run(&[xs_])?;
// Only the first output is used; `to_owned` copies it out of the vector.
let ys = ys[0].to_owned();
// Tail of the `Result<()>` variant: the result of `postprocess` is discarded.
self.postprocess(&ys)?;
Ok(())
// Tail of the `Result<Vec<String>>` variant: the result of `postprocess` is returned as-is.
self.postprocess(&ys)
}
// NOTE(review): this span interleaves two versions of `postprocess` with no +/- markers.
// The first (returning `Result<()>`) builds `texts` per batch item and prints it; the second
// (returning `Result<Vec<String>>`) collects one `String` per batch item. Presumably these are
// the removed and added sides of the change — confirm against the repository.
pub fn postprocess(&self, xs: &Array<f32, IxDyn>) -> Result<()> {
for batch in xs.axis_iter(Axis(0)) {
let mut texts: Vec<String> = Vec::new();
for (i, seq) in batch.axis_iter(Axis(0)).enumerate() {
// Arg-max over the row; `x.1 > max.1` keeps the earliest element on ties.
// The `unwrap` panics if the row is empty.
let (id, &confidence) = seq
.into_iter()
.enumerate()
.reduce(|max, x| if x.1 > max.1 { x } else { max })
.unwrap();
// Skip class id 0 (presumably the blank symbol — confirm) and low-confidence rows.
if id == 0 || confidence < self.confs[0] {
continue;
}
// Skip only when the first row predicts the last vocab entry.
if i == 0 && id == self.vocab.len() - 1 {
continue;
}
texts.push(self.vocab[id].to_owned());
}
// `Vec::dedup` removes consecutive equal entries. Skipped rows are never pushed, so two
// identical symbols separated only by a skipped row also end up adjacent and are merged
// (e.g. a genuine double letter collapses to one).
texts.dedup();
/// Decodes the model output into one string per batch item.
///
/// For every item along axis 0, takes the arg-max `(class id, score)` of each row, drops
/// rows per the filters below, collapses consecutive repeats of the same class id, maps the
/// remaining ids through `self.vocab`, and concatenates them.
///
/// # Panics
/// `self.confs[0]` and `self.vocab[idx]` index directly and panic when out of range;
/// `self.vocab.len() - 1` underflows if `vocab` is empty.
pub fn postprocess(&self, output: &Array<f32, IxDyn>) -> Result<Vec<String>> {
let mut texts: Vec<String> = Vec::new();
for batch in output.axis_iter(Axis(0)) {
// Per-row arg-max. `total_cmp` gives a total order over `f32`, so NaN cannot panic here.
// `max_by` returns the LAST maximal element on ties (the `reduce` form above keeps the
// first). `filter_map` drops rows that have no elements.
let preds = batch
.axis_iter(Axis(0))
.filter_map(|x| {
x.into_iter()
.enumerate()
.max_by(|(_, x), (_, y)| x.total_cmp(y))
})
.collect::<Vec<_>>();
// NOTE(review): the printing lines below treat `texts` as the per-item list of symbols and
// appear to belong to the printing version of this function, not to the collecting one.
print!("[Texts] ");
if texts.is_empty() {
println!("Nothing detected!");
} else {
for text in texts.into_iter() {
print!("{text}");
}
println!();
}
let text = preds
.iter()
.enumerate()
.fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| {
// Skip class id 0 (presumably the blank symbol — confirm) and scores below `confs[0]`.
if *text_id == 0 || confidence < self.confs[0] {
return text_ids;
}
// NOTE(review): `idx` is the position within `preds`, yet it is compared with
// `self.vocab.len() - 1`; the other version compared the class `id` against that value,
// and only at position 0. The `idx == 0` early return also keeps `preds[idx - 1]` below
// from underflowing, but it means a symbol predicted at position 0 is never emitted
// (and if position 1 repeats it, nothing is pushed either). Confirm this is intended.
if idx == 0 || idx == self.vocab.len() - 1 {
return text_ids;
}
// Keep the id only when it differs from the previous position's arg-max id. This
// collapses consecutive repeats but keeps repeats separated by any other id, id 0 included.
if *text_id != preds[idx - 1].0 {
text_ids.push(*text_id);
}
text_ids
})
.into_iter()
// Map each kept class id to its vocab entry and concatenate into one `String`.
.map(|idx| self.vocab[idx].to_owned())
.collect::<String>();
texts.push(text);
}
// Return of the `Result<()>` variant.
Ok(())
// Return of the `Result<Vec<String>>` variant: one decoded string per batch item.
Ok(texts)
}
}