Add DB model for text detection

jamjamjon
2024-03-31 02:14:53 +08:00
parent a5cee66dfd
commit ce9a416b71
11 changed files with 318 additions and 35 deletions

View File

@ -1,35 +1,37 @@
# usls
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), and others. Many execution providers are supported, such as `CUDA`, `TensorRT` and `CoreML`.
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) and others. Many execution providers are supported, such as `CUDA`, `TensorRT` and `CoreML`.
## Supported Models
| Model | Example | CUDA(f32) | CUDA(f16) | TensorRT(f32) | TensorRT(f16) |
| :-------------------: | :----------------------: | :----------------: | :----------------: | :------------------------: | :-----------------------: |
| YOLOv8-detection | [demo](examples/yolov8) | | | ✅ | ✅ |
| YOLOv8-pose | [demo](examples/yolov8) | | | ✅ | ✅ |
| YOLOv8-classification | [demo](examples/yolov8) | | | ✅ | ✅ |
| YOLOv8-segmentation | [demo](examples/yolov8) | | | ✅ | ✅ |
| YOLOv8-OBB | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
| YOLOv9 | [demo](examples/yolov9) | | | ✅ | ✅ |
| RT-DETR | [demo](examples/rtdetr) | | | ✅ | ✅ |
| FastSAM | [demo](examples/fastsam) | | | ✅ | ✅ |
| YOLO-World | [demo](examples/yolo-world) | | | ✅ | ✅ |
| DINOv2 | [demo](examples/dinov2) | | | ✅ | ✅ |
| CLIP | [demo](examples/clip) | | | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| BLIP | [demo](examples/blip) | | | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| OCR(DB, SVTR) | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
| Model | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
| :-----------------------------: | :----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
| **YOLOv8-detection** | [demo](examples/yolov8) | | | ✅ | ✅ |
| **YOLOv8-pose** | [demo](examples/yolov8) | | | ✅ | ✅ |
| **YOLOv8-classification** | [demo](examples/yolov8) | ✅ | | ✅ | ✅ |
| **YOLOv8-segmentation** | [demo](examples/yolov8) | ✅ | | ✅ | ✅ |
| **YOLOv8-OBB** | TODO | TODO | TODO | TODO | TODO |
| **YOLOv9** | [demo](examples/yolov9) | | | ✅ | ✅ |
| **RT-DETR** | [demo](examples/rtdetr) | | | ✅ | ✅ |
| **FastSAM** | [demo](examples/fastsam) | | | ✅ | ✅ |
| **YOLO-World** | [demo](examples/yolo-world) | | | ✅ | ✅ |
| **DINOv2** | [demo](examples/dinov2) | | | ✅ | ✅ |
| **CLIP** | [demo](examples/clip) | | | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| **BLIP** | [demo](examples/blip) | | | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| [**DB(Text Detection)**](https://arxiv.org/abs/1911.08947) | [demo](examples/db) | ✅ | ❌ | ✅ | ✅ |
| **SVTR, TROCR** | TODO | TODO | TODO | TODO | TODO |
## Solution Models
This repo also provides solution models such as pedestrian `fall detection`, `head detection`, `trash detection`, and more.
| Model | Example |
| :---------------------: | :------------------------------: |
| face-landmark detection | [demo](examples/yolov8-face) |
| head detection | [demo](examples/yolov8-head) |
| fall detection | [demo](examples/yolov8-falldown) |
| trash detection | [demo](examples/yolov8-plastic-bag) |
| Model | Example |
| :-------------------------------------------------------: | :------------------------------: |
| **face-landmark detection**<br />**人脸 & 关键点检测** | [demo](examples/yolov8-face) |
| **head detection**<br /> **人头检测** | [demo](examples/yolov8-head) |
| **fall detection**<br /> **摔倒检测** | [demo](examples/yolov8-falldown) |
| **trash detection**<br /> **垃圾检测** | [demo](examples/yolov8-plastic-bag) |
| **text detection(PPOCR-det v3, v4)**<br />**PPOCR文本检测** | [demo](examples/db) |
## Demo
@ -63,8 +65,8 @@ cargo add --git https://github.com/jamjamjon/usls
cargo add usls
```
#### 3. Set `Options` and build model
```Rust
let options = Options::default()
.with_model("../models/yolov8m-seg-dyn-f16.onnx")
@ -100,7 +102,6 @@ let x = DataLoader::try_read("./assets/bus.jpg")?;
let _y = model.run(&[x])?;
```
## Script: convert an ONNX model from `float32` to `float16`
```python

BIN
assets/math.jpg Normal file

Binary file not shown. Size: 54 KiB

39
examples/db/README.md Normal file
View File

@ -0,0 +1,39 @@
## Quick Start
```shell
cargo run -r --example db
```
## Or run it manually
### 1. Download the ONNX model
[ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx)
[ppocr-v4-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-db-dyn.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_PATH") // <= modify this
.with_profile(false);
```
### 3. Run
```bash
cargo run -r --example db
```
### Speed test
| Model | Image size | TensorRT<br />f16 | TensorRT<br />f32 | CUDA<br />f32 |
| --------------- | ---------- | ----------------- | ----------------- | ------------- |
| ppocr-v3-db-dyn | 640x640 | 1.8585ms | 2.5739ms | 4.3314ms |
| ppocr-v4-db-dyn | 640x640 | 2.0507ms | 2.8264ms | 6.6064ms |
***Tested on an RTX 3060***
## Results
![](./demo.jpg)

BIN
examples/db/demo.jpg Normal file

Binary file not shown. Size: 68 KiB

25
examples/db/main.rs Normal file
View File

@ -0,0 +1,25 @@
use usls::{models::DB, DataLoader, Options};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model
let options = Options::default()
.with_model("../models/ppocr-v4-db-dyn.onnx")
.with_i00((1, 1, 4).into())
.with_i02((608, 640, 960).into())
.with_i03((608, 640, 960).into())
.with_confs(&[0.7])
.with_saveout("DB-Text-Detection")
.with_dry_run(5)
// .with_trt(0)
// .with_fp16(true)
.with_profile(true);
let mut model = DB::new(&options)?;
// load image
let x = DataLoader::try_read("./assets/math.jpg")?;
// run
let _y = model.run(&[x])?;
Ok(())
}

View File

@ -12,6 +12,7 @@ pub mod models;
pub mod ops;
mod options;
mod point;
mod polygon;
mod rect;
mod results;
mod rotated_rect;
@ -30,6 +31,7 @@ pub use metric::Metric;
pub use min_opt_max::MinOptMax;
pub use options::Options;
pub use point::Point;
pub use polygon::Polygon;
pub use rect::Rect;
pub use results::Results;
pub use rotated_rect::RotatedRect;

155
src/models/db.rs Normal file
View File

@ -0,0 +1,155 @@
use crate::{
ops, Annotator, Bbox, DynConf, MinOptMax, Options, OrtEngine, Point, Polygon, Results,
};
use anyhow::Result;
use image::{DynamicImage, ImageBuffer};
use ndarray::{Array, Axis, IxDyn};
#[derive(Debug)]
pub struct DB {
engine: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
annotator: Annotator,
confs: DynConf,
saveout: Option<String>,
names: Option<Vec<String>>,
}
impl DB {
pub fn new(options: &Options) -> Result<Self> {
let engine = OrtEngine::new(options)?;
let (batch, height, width) = (
engine.inputs_minoptmax()[0][0].to_owned(),
engine.inputs_minoptmax()[0][2].to_owned(),
engine.inputs_minoptmax()[0][3].to_owned(),
);
let annotator = Annotator::default();
let names = Some(vec!["Text".to_string()]);
let confs = DynConf::new(&options.confs, 1);
engine.dry_run()?;
Ok(Self {
engine,
names,
confs,
height,
width,
batch,
saveout: options.saveout.to_owned(),
annotator,
})
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Results>> {
let xs_ = ops::letterbox(xs, self.height.opt as u32, self.width.opt as u32)?;
let ys = self.engine.run(&[xs_])?;
let ys = self.postprocess(ys, xs)?;
match &self.saveout {
None => {}
Some(saveout) => {
for (img0, y) in xs.iter().zip(ys.iter()) {
let mut img = img0.to_rgb8();
self.annotator.plot(&mut img, y);
self.annotator.save(&img, saveout);
}
}
}
Ok(ys)
}
pub fn postprocess(
&self,
xs: Vec<Array<f32, IxDyn>>,
xs0: &[DynamicImage],
) -> Result<Vec<Results>> {
let mut ys = Vec::new();
for (idx, mask) in xs[0].axis_iter(Axis(0)).enumerate() {
let mut ys_bbox = Vec::new();
// input image
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
// reshape the (1, h, w) probability map to (h, w, 1)
let h = mask.dim()[1];
let w = mask.dim()[2];
let mask = mask.into_shape((h, w, 1))?.into_owned();
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
ImageBuffer::from_raw(w as u32, h as u32, mask.into_raw_vec())
.expect("Failed to create image from ndarray");
let mut mask_im = image::DynamicImage::from(mask_im);
// rescale
let (_, w_mask, h_mask) = ops::scale_wh(image_width, image_height, w as f32, h as f32);
let mask_original = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
let mask_original = mask_original.resize_exact(
image_width as u32,
image_height as u32,
image::imageops::FilterType::Triangle,
);
// contours
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours(&mask_original.into_luma8());
for contour in contours.iter() {
// polygon
let points: Vec<Point> = contour
.points
.iter()
.map(|p| Point::new(p.x as f32, p.y as f32))
.collect();
let polygon = Polygon::new(&points);
let mut rect = polygon.find_min_rect();
// min size filter
if rect.height() < 3.0 || rect.width() < 3.0 {
continue;
}
// confs filter
let confidence = polygon.area() / rect.area();
if confidence < self.confs[0] {
continue;
}
// TODO: expand polygon
let unclip_ratio = 1.5;
let delta = rect.area() * unclip_ratio / rect.perimeter();
// save
let y_bbox = Bbox::new(
rect.expand(delta, delta, image_width, image_height),
0,
confidence,
self.names.as_ref().map(|names| names[0].clone()),
);
ys_bbox.push(y_bbox);
}
let y = Results {
probs: None,
bboxes: Some(ys_bbox),
keypoints: None,
masks: None,
};
ys.push(y);
}
Ok(ys)
}
pub fn batch(&self) -> isize {
self.batch.opt
}
pub fn width(&self) -> isize {
self.width.opt
}
pub fn height(&self) -> isize {
self.height.opt
}
}
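
Note on the box expansion in `postprocess`: each detected box is grown by `delta = area * unclip_ratio / perimeter` on every side and then clamped to the image. The following is a minimal standalone sketch of that arithmetic, not the crate's code: `BoxXyxy` and `unclip_delta` are hypothetical stand-ins for `Rect` and the inline computation, and a single `delta` is used for both axes as in the call above.

```rust
/// Axis-aligned box stored as corner coordinates.
/// Hypothetical stand-in for the crate's `Rect`, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BoxXyxy {
    pub xmin: f32,
    pub ymin: f32,
    pub xmax: f32,
    pub ymax: f32,
}

impl BoxXyxy {
    pub fn width(&self) -> f32 {
        self.xmax - self.xmin
    }

    pub fn height(&self) -> f32 {
        self.ymax - self.ymin
    }

    pub fn area(&self) -> f32 {
        self.width() * self.height()
    }

    pub fn perimeter(&self) -> f32 {
        (self.width() + self.height()) * 2.0
    }

    /// Grow the box by `delta` on every side, clamped to `[0, max_x] x [0, max_y]`.
    pub fn expand(&self, delta: f32, max_x: f32, max_y: f32) -> Self {
        Self {
            xmin: (self.xmin - delta).clamp(0.0, max_x),
            ymin: (self.ymin - delta).clamp(0.0, max_y),
            xmax: (self.xmax + delta).clamp(0.0, max_x),
            ymax: (self.ymax + delta).clamp(0.0, max_y),
        }
    }
}

/// Offset used to grow a shrunken text region: area * unclip_ratio / perimeter.
/// Hypothetical helper; the commit computes this inline.
pub fn unclip_delta(b: &BoxXyxy, unclip_ratio: f32) -> f32 {
    b.area() * unclip_ratio / b.perimeter()
}
```

For a 100x20 box with `unclip_ratio = 1.5`, the area is 2000 and the perimeter is 240, so the offset is 12.5 pixels per side. Long, thin text lines therefore grow by roughly a fixed fraction of their height rather than their length.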

View File

@ -1,11 +1,13 @@
mod blip;
mod clip;
mod db;
mod dinov2;
mod rtdetr;
mod yolo;
pub use blip::Blip;
pub use clip::Clip;
pub use db::DB;
pub use dinov2::Dinov2;
pub use rtdetr::RTDETR;
pub use yolo::YOLO;

View File

@ -128,13 +128,6 @@ impl YOLO {
})
}
// pub fn run_with_dl(&mut self, dl: &Dataloader) -> Result<Vec<Results>> {
// for (images, paths) in dataloader {
// self.run(&images)
// }
// Ok(())
// }
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Results>> {
let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32)?;
let ys = self.engine.run(&[xs_])?;
@ -296,10 +289,9 @@ impl YOLO {
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
Some(image) => image,
None => panic!("can not create image from ndarray"),
};
ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec())
.expect("Failed to create image from ndarray");
let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
// rescale masks

54
src/polygon.rs Normal file
View File

@ -0,0 +1,54 @@
use crate::{Point, Rect, RotatedRect};
#[derive(Default, Debug, PartialOrd, PartialEq, Clone)]
pub struct Polygon {
points: Vec<Point>,
}
impl Polygon {
pub fn new(points: &[Point]) -> Self {
// TODO: refactor
Self {
points: points.to_vec(),
}
}
pub fn area(&self) -> f32 {
// shoelace formula: points must already be ordered along the boundary
// (clockwise or counter-clockwise), otherwise the result is wrong
let mut area = 0.0;
let n = self.points.len();
for i in 0..n {
let j = (i + 1) % n;
area += self.points[i].x * self.points[j].y;
area -= self.points[j].x * self.points[i].y;
}
area.abs() / 2.0
}
pub fn find_min_rect(&self) -> Rect {
let (mut min_x, mut min_y, mut max_x, mut max_y) = (f32::MAX, f32::MAX, f32::MIN, f32::MIN);
for point in self.points.iter() {
if point.x <= min_x {
min_x = point.x
}
if point.x > max_x {
max_x = point.x
}
if point.y <= min_y {
min_y = point.y
}
if point.y > max_y {
max_y = point.y
}
}
((min_x, min_y), (max_x, max_y)).into()
}
pub fn find_min_rotated_rect() -> RotatedRect {
todo!()
}
pub fn expand(&mut self) -> Self {
todo!()
}
}
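
Note on `Polygon::area`: the loop is the shoelace formula, which only gives the true area when the vertices are listed in boundary order. A minimal standalone sketch, assuming plain `(x, y)` tuples instead of the crate's `Point` (the function name `polygon_area` is hypothetical):

```rust
/// Area of a simple polygon via the shoelace formula.
/// `points` must be ordered along the boundary (either orientation).
/// Hypothetical free function; the commit implements this as `Polygon::area`.
pub fn polygon_area(points: &[(f32, f32)]) -> f32 {
    let n = points.len();
    // Fewer than three vertices cannot enclose any area.
    if n < 3 {
        return 0.0;
    }
    let mut twice_area = 0.0f32;
    for i in 0..n {
        let (x0, y0) = points[i];
        // Wrap around so the last vertex connects back to the first.
        let (x1, y1) = points[(i + 1) % n];
        twice_area += x0 * y1 - x1 * y0;
    }
    // The sign encodes orientation; the absolute value gives the area.
    twice_area.abs() / 2.0
}
```

A 4x3 rectangle listed in boundary order yields 12.0 in either orientation, while the same four vertices listed in a criss-cross order yield 0.0, which is why the ordering requirement matters for contour points fed into the confidence ratio.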

View File

@ -120,6 +120,10 @@ impl Rect {
self.height() * self.width()
}
pub fn perimeter(&self) -> f32 {
(self.height() + self.width()) * 2.0
}
pub fn is_empty(&self) -> bool {
self.area() == 0.0
}
@ -150,6 +154,15 @@ impl Rect {
&& self.ymin() <= other.ymin()
&& self.ymax() >= other.ymax()
}
pub fn expand(&self, x: f32, y: f32, max_x: f32, max_y: f32) -> Self {
Self::from_xyxy(
(self.xmin() - x).max(0.0f32).min(max_x),
(self.ymin() - y).max(0.0f32).min(max_y),
(self.xmax() + x).max(0.0f32).min(max_x),
(self.ymax() + y).max(0.0f32).min(max_y),
)
}
}
#[cfg(test)]