diff --git a/README.md b/README.md
index 4754f86..5c23bfe 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,30 @@
 # usls
 
-A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) and others. Many execution providers are supported, sunch as `CUDA`, `TensorRT` and `CoreML`.
+A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) and others. Many execution providers are supported, such as `CUDA`, `TensorRT` and `CoreML`.
 
 ## Supported Models
 
-| Model | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
-| :---------------------------------------------------------------: | :----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
-| **YOLOv8-detection** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| **YOLOv8-pose** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| **YOLOv8-classification** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| **YOLOv8-segmentation** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| **YOLOv8-OBB** | TODO | TODO | TODO | TODO | TODO |
-| **YOLOv9** | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
-| **RT-DETR** | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
-| **FastSAM** | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
-| **YOLO-World** | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
-| **DINOv2** | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
-| **CLIP** | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
-| **BLIP** | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
-| [**DB(Text Detection)**](https://arxiv.org/abs/1911.08947) | [demo](examples/db) | ✅ | ❌ | ✅ | ✅ |
-| [**SVTR(Text Recognition)**](https://arxiv.org/abs/2205.00159) | [demo](examples/svtr) | ✅ | ❌ | ✅ | ✅ |
+| Model | Task / Type | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
+| :---------------------------------------------------------------: | :----------------------: | :----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
+| **[YOLOv8-detection](https://github.com/ultralytics/ultralytics)** | Object Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **[YOLOv8-pose](https://github.com/ultralytics/ultralytics)** | Keypoint Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **[YOLOv8-classification](https://github.com/ultralytics/ultralytics)** | Classification | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **[YOLOv8-segmentation](https://github.com/ultralytics/ultralytics)** | Instance Segmentation | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **[YOLOv9](https://github.com/WongKinYiu/yolov9)** | Object Detection | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
+| **[RT-DETR](https://arxiv.org/abs/2304.08069)** | Object Detection | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
+| **[FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)** | Instance Segmentation | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
+| **[YOLO-World](https://github.com/AILab-CVC/YOLO-World)** | Object Detection | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
+| **[DINOv2](https://github.com/facebookresearch/dinov2)** | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
+| **[CLIP](https://github.com/openai/CLIP)** | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
+| **[BLIP](https://github.com/salesforce/BLIP)** | Vision-Language | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
+| [**DB**](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | ❌ | ✅ | ✅ |
+| [**SVTR**](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ❌ | ✅ | ✅ |
+| [**RTMO**](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
+
 
 ## Solution Models
 
-Additionally, this repo also provides some solution models such as pedestrian `fall detection`, `head detection`, `trash detection`, and more.
+Additionally, this repo provides some solution models.
 
 | Model | Example |
 | :--------------------------------------------------------------------------------: | :------------------------------: |
diff --git a/examples/assets/bus.jpg b/examples/assets/bus.jpg
deleted file mode 100644
index 40eaaf5..0000000
Binary files a/examples/assets/bus.jpg and /dev/null differ
diff --git a/examples/assets/falldown.jpg b/examples/assets/falldown.jpg
deleted file mode 100644
index 1492401..0000000
Binary files a/examples/assets/falldown.jpg and /dev/null differ
diff --git a/examples/assets/kids.jpg b/examples/assets/kids.jpg
deleted file mode 100644
index 7eda4f3..0000000
Binary files a/examples/assets/kids.jpg and /dev/null differ
diff --git a/examples/assets/math.jpg b/examples/assets/math.jpg
deleted file mode 100644
index 0b5b656..0000000
Binary files a/examples/assets/math.jpg and /dev/null differ
diff --git a/examples/assets/trash.jpg b/examples/assets/trash.jpg
deleted file mode 100644
index 2ead8d5..0000000
Binary files a/examples/assets/trash.jpg and /dev/null differ
diff --git a/examples/blip/README.md b/examples/blip/README.md
index 823fdb5..dda286b 100644
--- a/examples/blip/README.md
+++ b/examples/blip/README.md
@@ -47,7 +47,6 @@ cargo run -r --example blip
 
 ## TODO
 
-* [ ] text decode with Top-p sample
 * [ ] VQA
 * [ ] Retrival
 * [ ] TensorRT support for textual model
diff --git a/examples/blip/main.rs b/examples/blip/main.rs
index a21c89a..5c11034 100644
--- a/examples/blip/main.rs
+++ b/examples/blip/main.rs
@@ -10,6 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // textual
     let options_textual = Options::default()
         .with_model("../models/blip-textual-base.onnx")
+        .with_tokenizer("tokenizer-blip.json")
         .with_i00((1, 1, 4).into()) // input_id: batch
         .with_i01((1, 1, 4).into()) // input_id: seq_len
         .with_i10((1, 1, 4).into()) // attention_mask: batch
diff --git a/examples/clip/main.rs b/examples/clip/main.rs
index 2b0aaf8..04d5a67 100644
--- a/examples/clip/main.rs
+++ b/examples/clip/main.rs
@@ -10,6 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // textual
     let options_textual = Options::default()
         .with_model("../models/clip-b32-textual-dyn.onnx")
+        .with_tokenizer("tokenizer-clip.json")
         .with_i00((1, 1, 4).into())
         .with_profile(false);
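The two `main.rs` hunks above pair with the `Options::with_tokenizer` change at the bottom of this diff: the CLIP and BLIP textual pipelines no longer fall back to a bundled tokenizer inside `Clip::new`/`Blip::new`, so the tokenizer file must now be passed explicitly. A minimal sketch of the resulting builder call, with the model path and axis bounds taken from the CLIP example purely as an illustration:

```Rust
// Sketch only: `with_tokenizer` routes through `auto_load`, and leaving
// `options.tokenizer` as `None` now panics at model construction on the
// `options_textual.tokenizer.unwrap()` shown later in this diff.
let options_textual = Options::default()
    .with_model("../models/clip-b32-textual-dyn.onnx")
    .with_tokenizer("tokenizer-clip.json")
    .with_i00((1, 1, 4).into()); // dynamic batch axis: (min, opt, max)
```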
diff --git a/examples/db/README.md b/examples/db/README.md
index e393e08..50a065f 100644
--- a/examples/db/README.md
+++ b/examples/db/README.md
@@ -8,7 +8,7 @@ cargo run -r --example db
 
 ### 1. Donwload ONNX Model
 
-[ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx)
+[ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx) [ppocr-v4-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-db-dyn.onnx)
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/fastsam/README.md b/examples/fastsam/README.md
index d2ecc03..00b5a4b 100644
--- a/examples/fastsam/README.md
+++ b/examples/fastsam/README.md
@@ -25,7 +25,6 @@ cargo run -r --example fastsam
 ```Rust
 let options = Options::default()
     .with_model("../models/FastSAM-s-dyn-f16.onnx") // <= modify this
-    .with_saveout("FastSAM")
     .with_profile(false);
 let mut model = YOLO::new(&options)?;
 ```
diff --git a/examples/rtdetr/README.md b/examples/rtdetr/README.md
index 9d5921a..4bfe671 100644
--- a/examples/rtdetr/README.md
+++ b/examples/rtdetr/README.md
@@ -23,7 +23,6 @@ cargo run -r --example rtdetr
 ```Rust
 let options = Options::default()
     .with_model("ONNX_MODEL") // <= modify this
-    .with_saveout("RT-DETR");
 ```
 
 ### 3. Then, run
diff --git a/examples/rtmo/README.md b/examples/rtmo/README.md
new file mode 100644
index 0000000..14c3a57
--- /dev/null
+++ b/examples/rtmo/README.md
@@ -0,0 +1,35 @@
+## Quick Start
+
+```shell
+cargo run -r --example rtmo
+```
+
+## Or you can do it manually
+
+### 1. Download ONNX Model
+
+[rtmo-s-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn.onnx)  
+[rtmo-m-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn.onnx)  
+[rtmo-l-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn.onnx)  
+[rtmo-s-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn-f16.onnx)  
+[rtmo-m-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn-f16.onnx)  
+[rtmo-l-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn-f16.onnx)  
+
+
+### 2. Specify the ONNX model path in `main.rs`
+
+```Rust
+let options = Options::default()
+    .with_model("ONNX_MODEL") // <= modify this
+```
+
+### 3. Then, run
+
+```bash
+cargo run -r --example rtmo
+```
+
+## Results
+
+![](./demo.jpg)
diff --git a/examples/rtmo/demo.jpg b/examples/rtmo/demo.jpg
new file mode 100644
index 0000000..4dca7ef
Binary files /dev/null and b/examples/rtmo/demo.jpg differ
diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs
new file mode 100644
index 0000000..b16dd60
--- /dev/null
+++ b/examples/rtmo/main.rs
@@ -0,0 +1,26 @@
+use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETON_17};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // build model
+    let options = Options::default()
+        .with_model("../rtmo-l-dyn-f16.onnx")
+        .with_i00((1, 2, 8).into())
+        .with_nk(17)
+        .with_confs(&[0.3])
+        .with_kconfs(&[0.5]);
+    let mut model = RTMO::new(&options)?;
+
+    // load image
+    let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
+
+    // run
+    let y = model.run(&x)?;
+
+    // annotate
+    let annotator = Annotator::default()
+        .with_saveout("RTMO")
+        .with_skeletons(&COCO_SKELETON_17);
+    annotator.annotate(&x, &y);
+
+    Ok(())
+}
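A reading note for the `with_iXY` calls used in this and the BLIP example: judging by the comments in `examples/blip/main.rs` above (`input_id: batch`, `input_id: seq_len`, `attention_mask: batch`), `X` indexes the model input and `Y` the dimension, and the `(min, opt, max)` tuple converts into a `MinOptMax` for that dynamic axis. A hedged sketch of the RTMO batch spec under that assumption:

```Rust
// Assumed semantics: input 0, dimension 0 (the batch axis) may range from
// 1 to 8, with 2 as the preferred size; the models read back `.opt` (see
// `RTMO::batch()` later in this diff). Whether min/max also feed a TensorRT
// optimization profile is an assumption, not something this diff states.
let options = Options::default()
    .with_model("../rtmo-l-dyn-f16.onnx")
    .with_i00((1, 2, 8).into()); // batch: min=1, opt=2, max=8
```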
diff --git a/examples/svtr/README.md b/examples/svtr/README.md
index 8304a5d..ef8f26f 100644
--- a/examples/svtr/README.md
+++ b/examples/svtr/README.md
@@ -8,9 +8,9 @@ cargo run -r --example svtr
 ```
 
 ### 1. Donwload ONNX Model
 
-[ppocr-v4-server-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-server-svtr-ch-dyn.onnx)
-[ppocr-v4-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-svtr-ch-dyn.onnx)
-[ppocr-v3-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-svtr-ch-dyn.onnx)
+[ppocr-v4-server-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-server-svtr-ch-dyn.onnx)  
+[ppocr-v4-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-svtr-ch-dyn.onnx)  
+[ppocr-v3-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-svtr-ch-dyn.onnx)  
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/yolo-world/README.md b/examples/yolo-world/README.md
index f3081a2..8f006c9 100644
--- a/examples/yolo-world/README.md
+++ b/examples/yolo-world/README.md
@@ -10,7 +10,7 @@ cargo run -r --example yolo-world
 
 - Download
 
-  [yolov8s-world-v2-shoes](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8s-world-v2-shoes.onnx)
+  [yolov8s-world-v2-shoes](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8s-world-v2-shoes.onnx)  
 
 - Or generate your own `yolo-world` model and then Export
 
   - Installation
diff --git a/examples/yolov8-face/README.md b/examples/yolov8-face/README.md
index 8b741df..ff98d3d 100644
--- a/examples/yolov8-face/README.md
+++ b/examples/yolov8-face/README.md
@@ -8,7 +8,7 @@ cargo run -r --example yolov8-face
 
 ### 1. Donwload ONNX Model
 
-[yolov8-face-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-face-dyn-f16.onnx)
+[yolov8-face-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-face-dyn-f16.onnx)  
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/yolov8-falldown/README.md b/examples/yolov8-falldown/README.md
index 1cc6699..a7f3a26 100644
--- a/examples/yolov8-falldown/README.md
+++ b/examples/yolov8-falldown/README.md
@@ -8,7 +8,7 @@ cargo run -r --example yolov8-falldown
 
 ### 1.Donwload ONNX Model
 
-[yolov8-falldown-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-falldown-f16.onnx)
+[yolov8-falldown-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-falldown-f16.onnx)  
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/yolov8-head/README.md b/examples/yolov8-head/README.md
index 2ef3bd7..57e30f7 100644
--- a/examples/yolov8-head/README.md
+++ b/examples/yolov8-head/README.md
@@ -8,7 +8,7 @@ cargo run -r --example yolov8-head
 
 ### 1. Donwload ONNX Model
 
-[yolov8-head-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-head-f16.onnx)
+[yolov8-head-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-head-f16.onnx)  
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/yolov8-trash/README.md b/examples/yolov8-trash/README.md
index 27c8c1c..d610a44 100644
--- a/examples/yolov8-trash/README.md
+++ b/examples/yolov8-trash/README.md
@@ -10,7 +10,7 @@ cargo run -r --example yolov8-trash
 
 ### 1. Donwload ONNX Model
 
-[yolov8-plastic-bag-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-plastic-bag-f16.onnx)
+[yolov8-plastic-bag-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-plastic-bag-f16.onnx)  
 
 ### 2. Specify the ONNX model path in `main.rs`
diff --git a/examples/yolov8/README.md b/examples/yolov8/README.md
index 8b65881..b923a3a 100644
--- a/examples/yolov8/README.md
+++ b/examples/yolov8/README.md
@@ -38,7 +38,6 @@ yolo export model=yolov8m-seg.pt format=onnx simplify
 ```Rust
 let options = Options::default()
     .with_model("ONNX_PATH") // <= modify this
     .with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
-    .with_saveout("YOLOv8");
 let mut model = YOLO::new(&options)?;
 ```
diff --git a/examples/yolov8/main.rs b/examples/yolov8/main.rs
index 45f41d8..ee2ade3 100644
--- a/examples/yolov8/main.rs
+++ b/examples/yolov8/main.rs
@@ -33,6 +33,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // run & annotate
     for (xs, _paths) in dl {
         let ys = model.run(&xs)?;
+        println!("{:?}", ys);
         annotator.annotate(&xs, &ys);
     }
 
diff --git a/examples/yolov9/README.md b/examples/yolov9/README.md
index 5ce2bfb..d90549d 100644
--- a/examples/yolov9/README.md
+++ b/examples/yolov9/README.md
@@ -10,7 +10,7 @@ cargo run -r --example yolov9
 
 - **Download**
 
-  [yolov9-c-dyn-fp16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov9-c-dyn-f16.onnx)
+  [yolov9-c-dyn-fp16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov9-c-dyn-f16.onnx)  
 
 - **Export**
 
   ```shell
@@ -31,7 +31,6 @@ cargo run -r --example yolov9
 ```Rust
 let options = Options::default()
     .with_model("ONNX_PATH") // <= modify this
-    .with_saveout("YOLOv9");
 ```
 
 ### 3. Run
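The `with_saveout` removals in the FastSAM, RT-DETR, YOLOv8, and YOLOv9 examples are all the same migration: saving annotated results is no longer an `Options` concern but belongs to `Annotator`, as the new RTMO example above already shows. A before/after sketch using the YOLOv8 names from this diff (`dl` and `model` are the loader and model already built in that example's `main`):

```Rust
// Before: `.with_saveout("YOLOv8")` lived on the model's Options.
// After: the model only predicts; the Annotator draws and writes output.
let annotator = Annotator::default().with_saveout("YOLOv8");
for (xs, _paths) in dl {
    let ys = model.run(&xs)?;
    annotator.annotate(&xs, &ys);
}
```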
diff --git a/src/models/blip.rs b/src/models/blip.rs
index 0e8aea3..d902d75 100644
--- a/src/models/blip.rs
+++ b/src/models/blip.rs
@@ -4,7 +4,7 @@ use ndarray::{s, Array, Axis, IxDyn};
 use std::io::Write;
 use tokenizers::Tokenizer;
 
-use crate::{auto_load, ops, LogitsSampler, MinOptMax, Options, OrtEngine, TokenizerStream};
+use crate::{ops, LogitsSampler, MinOptMax, Options, OrtEngine, TokenizerStream};
 
 #[derive(Debug)]
 pub struct Blip {
@@ -27,11 +27,7 @@ impl Blip {
             visual.height().to_owned(),
             visual.width().to_owned(),
         );
-        let tokenizer = match &options_textual.tokenizer {
-            None => auto_load("tokenizer-blip.json")?,
-            Some(tokenizer) => tokenizer.into(),
-        };
-        let tokenizer = Tokenizer::from_file(tokenizer).unwrap();
+        let tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
         let tokenizer = TokenizerStream::new(tokenizer);
         visual.dry_run()?;
         textual.dry_run()?;
diff --git a/src/models/clip.rs b/src/models/clip.rs
index 2f19cad..3845651 100644
--- a/src/models/clip.rs
+++ b/src/models/clip.rs
@@ -1,7 +1,6 @@
-use crate::{auto_load, ops, MinOptMax, Options, OrtEngine};
+use crate::{ops, MinOptMax, Options, OrtEngine};
 use anyhow::Result;
 use image::DynamicImage;
-// use itertools::Itertools;
 use ndarray::{Array, Array2, Axis, IxDyn};
 use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
 
@@ -28,11 +27,7 @@ impl Clip {
             visual.inputs_minoptmax()[0][2].to_owned(),
             visual.inputs_minoptmax()[0][3].to_owned(),
         );
-        let tokenizer = match &options_textual.tokenizer {
-            None => auto_load("tokenizer-clip.json").unwrap(),
-            Some(tokenizer) => tokenizer.into(),
-        };
-        let mut tokenizer = Tokenizer::from_file(tokenizer).unwrap();
+        let mut tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
         tokenizer.with_padding(Some(PaddingParams {
             strategy: PaddingStrategy::Fixed(context_length),
             direction: PaddingDirection::Right,
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 16e652d..72f96bd 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -3,6 +3,7 @@ mod clip;
 mod db;
 mod dinov2;
 mod rtdetr;
+mod rtmo;
 mod svtr;
 mod yolo;
 
@@ -11,5 +12,6 @@ pub use clip::Clip;
 pub use db::DB;
 pub use dinov2::Dinov2;
 pub use rtdetr::RTDETR;
+pub use rtmo::RTMO;
 pub use svtr::SVTR;
 pub use yolo::YOLO;
diff --git a/src/models/rtdetr.rs b/src/models/rtdetr.rs
index 2c436b3..b618799 100644
--- a/src/models/rtdetr.rs
+++ b/src/models/rtdetr.rs
@@ -41,7 +41,6 @@ impl RTDETR {
             .expect("Failed to get num_classes, make it explicit with `--nc`")
             .len(),
         );
-        // let annotator = Annotator::default();
         let confs = DynConf::new(&options.confs, nc);
         engine.dry_run()?;
 
diff --git a/src/models/rtmo.rs b/src/models/rtmo.rs
new file mode 100644
index 0000000..fea580e
--- /dev/null
+++ b/src/models/rtmo.rs
@@ -0,0 +1,133 @@
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::{Array, Axis, IxDyn};
+
+use crate::{ops, Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, Ys};
+
+#[derive(Debug)]
+pub struct RTMO {
+    engine: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+    confs: DynConf,
+    kconfs: DynConf,
+}
+
+impl RTMO {
+    pub fn new(options: &Options) -> Result<Self> {
+        let engine = OrtEngine::new(options)?;
+        let (batch, height, width) = (
+            engine.batch().to_owned(),
+            engine.height().to_owned(),
+            engine.width().to_owned(),
+        );
+        let nc = 1;
+        let nk = options.nk.unwrap_or(17);
+        let confs = DynConf::new(&options.confs, nc);
+        let kconfs = DynConf::new(&options.kconfs, nk);
+        engine.dry_run()?;
+
+        Ok(Self {
+            engine,
+            confs,
+            kconfs,
+            height,
+            width,
+            batch,
+        })
+    }
+
+    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Ys>> {
+        let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32, 114.0)?;
+        let ys = self.engine.run(&[xs_])?;
+        let ys = self.postprocess(ys, xs)?;
+        Ok(ys)
+    }
+
+    pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Ys>> {
+        let mut ys: Vec<Ys> = Vec::new();
+        let (preds_bboxes, preds_kpts) = if xs[0].ndim() == 3 {
+            (&xs[0], &xs[1])
+        } else {
+            (&xs[1], &xs[0])
+        };
+
+        for (idx, (batch_bboxes, batch_kpts)) in preds_bboxes
+            .axis_iter(Axis(0))
+            .zip(preds_kpts.axis_iter(Axis(0)))
+            .enumerate()
+        {
+            let width_original = xs0[idx].width() as f32;
+            let height_original = xs0[idx].height() as f32;
+            let ratio =
+                (self.width() as f32 / width_original).min(self.height() as f32 / height_original);
+
+            let mut y_bboxes = Vec::new();
+            let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
+            for (xyxyc, kpts) in batch_bboxes
+                .axis_iter(Axis(0))
+                .zip(batch_kpts.axis_iter(Axis(0)))
+            {
+                // bbox
+                let x1 = xyxyc[0] / ratio;
+                let y1 = xyxyc[1] / ratio;
+                let x2 = xyxyc[2] / ratio;
+                let y2 = xyxyc[3] / ratio;
+                let confidence = xyxyc[4];
+                if confidence < self.confs[0] {
+                    continue;
+                }
+                let y_bbox = Bbox::new(
+                    (
+                        (
+                            x1.max(0.0f32).min(width_original),
+                            y1.max(0.0f32).min(height_original),
+                        ),
+                        (x2, y2),
+                    )
+                        .into(),
+                    0,
+                    confidence,
+                    Some(String::from("Person")),
+                );
+                y_bboxes.push(y_bbox);
+
+                // keypoints
+                let mut kpts_ = Vec::new();
+                for (i, kpt) in kpts.axis_iter(Axis(0)).enumerate() {
+                    let x = kpt[0] / ratio;
+                    let y = kpt[1] / ratio;
+                    let c = kpt[2];
+                    if c < self.kconfs[i] {
+                        kpts_.push(Keypoint::default());
+                    } else {
+                        kpts_.push(Keypoint::new(
+                            (
+                                x.max(0.0f32).min(width_original),
+                                y.max(0.0f32).min(height_original),
+                            )
+                                .into(),
+                            c,
+                        ));
+                    }
+                }
+                y_kpts.push(kpts_);
+            }
+            ys.push(Ys::default().with_bboxes(&y_bboxes).with_keypoints(&y_kpts));
+        }
+        Ok(ys)
+    }
+
+    pub fn batch(&self) -> isize {
+        self.batch.opt
+    }
+
+    pub fn width(&self) -> isize {
+        self.width.opt
+    }
+
+    pub fn height(&self) -> isize {
+        self.height.opt
+    }
+}
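The coordinate math in `RTMO::postprocess` above undoes the letterbox preprocessing: the image is scaled by `ratio = min(W/w0, H/h0)` to fit the WxH network canvas, so predictions come back to image space by dividing by that same ratio and clamping to the original bounds. Dividing without subtracting any offset implies the padding (fill value 114) is anchored at the top-left rather than centered, an inference from the code rather than a documented guarantee. A tiny worked sketch with illustrative numbers:

```Rust
// A 1280x960 image letterboxed into a 640x640 network input:
let (w0, h0) = (1280.0_f32, 960.0_f32);
let (w, h) = (640.0_f32, 640.0_f32);
let ratio = (w / w0).min(h / h0); // 0.5: the image becomes 640x480, padded to 640x640
// A network-space x of 320.0 maps back to 640.0 in the original image:
let x_img = (320.0_f32 / ratio).max(0.0).min(w0);
assert_eq!(x_img, 640.0);
```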
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index 9ce9a1e..a765694 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -34,7 +34,6 @@ pub struct YOLO {
     confs: DynConf,
     kconfs: DynConf,
     iou: f32,
-    // saveout: Option<String>,
     names: Option<Vec<String>>,
     apply_nms: bool,
     anchors_first: bool,
diff --git a/src/options.rs b/src/options.rs
index 9149cbc..6e8c995 100644
--- a/src/options.rs
+++ b/src/options.rs
@@ -156,6 +156,11 @@ impl Options {
         self
     }
 
+    pub fn with_tokenizer(mut self, tokenizer: &str) -> Self {
+        self.tokenizer = Some(auto_load(tokenizer).unwrap());
+        self
+    }
+
     pub fn with_unclip_ratio(mut self, x: f32) -> Self {
         self.unclip_ratio = x;
         self
@@ -206,11 +211,6 @@ impl Options {
         self
     }
 
-    pub fn with_tokenizer(mut self, tokenizer: String) -> Self {
-        self.tokenizer = Some(tokenizer);
-        self
-    }
-
     pub fn with_i00(mut self, x: MinOptMax) -> Self {
         self.i00 = Some(x);
         self
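Finally, note what the relocated `with_tokenizer` changes for callers: it now takes `&str` instead of `String` and resolves the file eagerly through `auto_load`, panicking via `.unwrap()` if resolution fails, a trade-off that keeps the model constructors simple. A migration sketch (the old call shape is inferred from the removed signature):

```Rust
// Old API: stored the raw String; each model decided later how to locate it.
// let options = Options::default().with_tokenizer("tokenizer-clip.json".to_string());
// New API: borrows the name and resolves it immediately via `auto_load`.
let options = Options::default().with_tokenizer("tokenizer-clip.json");
```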