Jamjamjon
2024-07-31 21:27:41 +08:00
committed by GitHub
parent 0901ab3e3c
commit 1d74085158
28 changed files with 1224 additions and 434 deletions

View File

@ -1,47 +0,0 @@
## v0.0.5 - 2024-07-12
### Changed
- Accelerated `YOLO`'s post-processing using `Rayon`. `YOLOv8-seg` post-processing now takes only **~8ms (~20ms in the previous version)**, depending on your machine. Note that this repo's implementation of `YOLOv8-Segment` saves not only the masks but also their contour points; the official `YOLOv8` Python version only saves the masks, making it appear much faster.
- Merged all `YOLOv8-related` solution models into YOLO examples.
- Consolidated all `YOLO-series` model examples into the YOLO example.
- Refactored the `YOLO` struct to unify all `YOLO versions` and `YOLO tasks`. It now supports user-defined YOLO models with different `Preds Tensor Formats`.
- Introduced a new `Nms` trait, combining `apply_bboxes_nms()` and `apply_mbrs_nms()` into `apply_nms()` (see the sketch after this list).
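A minimal sketch of the shape such a trait enables (the trait surface below is an assumption; the changelog only names `apply_nms()`):
```Rust
/// Anything that carries a confidence score and can compute IoU against its own kind.
trait Nms {
    fn confidence(&self) -> f32;
    fn iou(&self, other: &Self) -> f32;
}

/// Greedy NMS over any `Nms` type: keep the highest-confidence item,
/// drop everything overlapping it beyond `iou_threshold`, repeat.
fn apply_nms<T: Nms>(mut items: Vec<T>, iou_threshold: f32) -> Vec<T> {
    items.sort_by(|a, b| b.confidence().total_cmp(&a.confidence()));
    let mut keep: Vec<T> = Vec::new();
    for item in items {
        if keep.iter().all(|kept| kept.iou(&item) < iou_threshold) {
            keep.push(item);
        }
    }
    keep
}
```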
### Added
- Added support for `YOLOv6` and `YOLOv7`.
- Updated documentation for `y.rs`.
- Updated documentation for `bbox.rs`.
- Updated the `README.md`.
- Added `with_yolo_preds()` to `Options`.
- Added support for `Depth-Anything-v2`.
- Added `RTDETR` to the `YOLOVersion` struct.
### Removed
- Merged the following models' examples into the YOLOv8 example: `yolov8-face`, `yolov8-falldown`, `yolov8-head`, `yolov8-trash`, `fastsam`, and `face-parsing`.
- Removed `anchors_first`, `conf_independent`, and their related methods from `Options`.
## v0.0.4 - 2024-06-30
### Added
- Added `X` struct to handle input and preprocessing.
- Added `Ops` struct to manage common operations.
- Used SIMD (`fast_image_resize`) to accelerate model pre-processing and post-processing: YOLOv8-seg post-processing (~120ms => ~20ms), Depth-Anything post-processing (~23ms => ~2ms).
### Deprecated
- Mark `Ops::descale_mask()` as deprecated.
### Fixed
### Changed
### Removed
### Refactored
### Others

View File

@ -1,13 +1,13 @@
[package]
name = "usls"
version = "0.0.6"
version = "0.0.7"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
authors = ["Jamjamjon <jamjamjon.usls@gmail.com>"]
license = "MIT"
readme = "README.md"
exclude = ["assets/*", "examples/*"]
exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
@ -44,4 +44,15 @@ ab_glyph = "0.2.23"
geo = "0.28.0"
prost = "0.12.4"
human_bytes = "0.4.3"
fast_image_resize = { version = "4.0.0", git = "https://github.com/jamjamjon/fast_image_resize", branch = "dev" , features = ["image"]}
fast_image_resize = { version = "4.2.1", features = ["image"]}
[dev-dependencies]
criterion = "0.5.1"
[[bench]]
name = "yolo"
harness = false
[lib]
bench = false

123
README.md
View File

@ -1,8 +1,21 @@
# usls
[![Static Badge](https://img.shields.io/crates/v/usls.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/usls) ![Static Badge](https://img.shields.io/crates/d/usls?style=for-the-badge) [![Static Badge](https://img.shields.io/badge/Documents-usls-blue?style=for-the-badge&logo=docs.rs)](https://docs.rs/usls) [![Static Badge](https://img.shields.io/badge/GitHub-black?style=for-the-badge&logo=github)](https://github.com/jamjamjon/usls)
[![Static Badge](https://img.shields.io/crates/v/usls.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/usls) [![Static Badge](https://img.shields.io/badge/ONNXRuntime-v1.17.x-yellow?style=for-the-badge&logo=docs.rs)](https://github.com/microsoft/onnxruntime/releases) [![Static Badge](https://img.shields.io/badge/CUDA-11.x-green?style=for-the-badge&logo=docs.rs)](https://developer.nvidia.com/cuda-toolkit-archive) [![Static Badge](https://img.shields.io/badge/TRT-8.6.x.x-blue?style=for-the-badge&logo=docs.rs)](https://developer.nvidia.com/tensorrt)
[![Static Badge](https://img.shields.io/badge/Documents-usls-blue?style=for-the-badge&logo=docs.rs)](https://docs.rs/usls) ![Static Badge](https://img.shields.io/crates/d/usls?style=for-the-badge)
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [MODNet](https://github.com/ZHKKKe/MODNet) and others.
| Segment Anything |
| :------------------------------------------------------: |
| <img src='examples/sam/demo2.png' width="800px"> |
| YOLO + SAM |
| :------------------------------------------------------: |
| <img src='examples/yolo-sam/demo.png' width="800px"> |
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [MODNet](https://github.com/ZHKKKe/MODNet) and others.
| Monocular Depth Estimation |
| :--------------------------------------------------------------: |
@ -13,9 +26,7 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
| :----------------------------------------------------: | :------------------------------------------------: |
| <img src='examples/yolop/demo.png' width="385px"> | <img src='examples/db/demo.png' width="385x"> |
| Portrait Matting |
| :------------------------------------------------------: |
| <img src='examples/modnet/demo.png' width="800px"> |
## Supported Models
@ -30,6 +41,10 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
| [CLIP](https://github.com/openai/CLIP) | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
@ -64,103 +79,13 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...
## Integrate into your own project
### 1. Add `usls` as a dependency to your project's `Cargo.toml`
```Shell
# Add `usls` as a dependency to your project's `Cargo.toml`
cargo add usls
```
Or you can pin a specific commit:
```Toml
# Pin `usls` to a specific commit in your `Cargo.toml`
usls = { git = "https://github.com/jamjamjon/usls", rev = "???sha???"}
```
### 2. Build model
```Rust
let options = Options::default()
.with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
.with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb
.with_model("xxxx.onnx")?;
let mut model = YOLO::new(options)?;
```
- If you want to run your model with TensorRT or CoreML
```Rust
let options = Options::default()
.with_trt(0) // CUDA is used by default; this switches to TensorRT
// .with_coreml(0)
```
- If your model has dynamic shapes
```Rust
let options = Options::default()
.with_i00((1, 2, 4).into()) // dynamic batch
.with_i02((416, 640, 800).into()) // dynamic height
.with_i03((416, 640, 800).into()) // dynamic width
```
- If you want to set a confidence threshold for each category
```Rust
let options = Options::default()
.with_confs(&[0.4, 0.15]) // class_0: 0.4, others: 0.15
```
- Go check [Options](src/core/options.rs) for more model options.
### 3. Load images
- Build `DataLoader` to load images
```Rust
let dl = DataLoader::default()
.with_batch(model.batch.opt as usize)
.load("./assets/")?;
for (xs, _paths) in dl {
let _y = model.run(&xs)?;
}
```
- Or simply read one image
```Rust
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
let y = model.run(&x)?;
```
### 4. Annotate and save
```Rust
let annotator = Annotator::default().with_saveout("YOLO");
annotator.annotate(&x, &y);
```
### 5. Get results
The inference outputs of provided models will be saved to `Vec<Y>`.
- You can get detection bboxes with `y.bboxes()`:
```Rust
let ys = model.run(&xs)?;
for y in ys {
// bboxes
if let Some(bboxes) = y.bboxes() {
for bbox in bboxes {
println!(
"Bbox: {}, {}, {}, {}, {}, {}",
bbox.xmin(),
bbox.ymin(),
bbox.xmax(),
bbox.ymax(),
bbox.confidence(),
bbox.id(),
)
}
}
}
```
- Other: [Docs](https://docs.rs/usls/latest/usls/struct.Y.html)
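Putting the five steps together, a minimal sketch (the model file name and option values are placeholders taken from the snippets above, not a verified configuration):
```Rust
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Build a YOLOv8 detector (the model file name is a placeholder).
    let options = Options::default()
        .with_yolo_version(YOLOVersion::V8)
        .with_yolo_task(YOLOTask::Detect)
        .with_model("yolov8m-dyn.onnx")?
        .with_i00((1, 1, 4).into()) // dynamic batch
        .with_confs(&[0.4, 0.15]); // class_0: 0.4, others: 0.15
    let mut model = YOLO::new(options)?;

    // 2. Load one image.
    let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];

    // 3. Run inference.
    let ys = model.run(&xs)?;

    // 4. Annotate and save.
    let annotator = Annotator::default().with_saveout("YOLO");
    annotator.annotate(&xs, &ys);

    // 5. Inspect results.
    for y in &ys {
        if let Some(bboxes) = y.bboxes() {
            println!("detected {} boxes", bboxes.len());
        }
    }
    Ok(())
}
```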

BIN assets/dog.jpg Normal file (217 KiB)

BIN assets/truck.jpg Normal file (265 KiB)

96
benches/yolo.rs Normal file
View File

@ -0,0 +1,96 @@
use anyhow::Result;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use usls::{coco, models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion};
enum Stage {
Pre,
Run,
Post,
Pipeline,
}
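// Run the full pipeline `n` times (Criterion's `iter_custom` supplies `n`),
// accumulating wall time per stage, and return only the requested stage's total.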
fn yolo_stage_bench(
model: &mut YOLO,
x: &[image::DynamicImage],
stage: Stage,
n: u64,
) -> std::time::Duration {
let mut t_pre = std::time::Duration::new(0, 0);
let mut t_run = std::time::Duration::new(0, 0);
let mut t_post = std::time::Duration::new(0, 0);
let mut t_pipeline = std::time::Duration::new(0, 0);
for _ in 0..n {
let t0 = std::time::Instant::now();
let xs = model.preprocess(x).unwrap();
t_pre += t0.elapsed();
let t = std::time::Instant::now();
let xs = model.inference(xs).unwrap();
t_run += t.elapsed();
let t = std::time::Instant::now();
let _ys = black_box(model.postprocess(xs, x).unwrap());
t_post += t.elapsed();
t_pipeline += t0.elapsed();
}
match stage {
Stage::Pre => t_pre,
Stage::Run => t_run,
Stage::Post => t_post,
Stage::Pipeline => t_pipeline,
}
}
pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> {
let mut group = c.benchmark_group(format!("YOLO ({}-{})", w, h));
group
.significance_level(0.05)
.sample_size(80)
.measurement_time(std::time::Duration::new(20, 0));
let options = Options::default()
.with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
.with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
.with_model("yolov8m-dyn.onnx")?
.with_cuda(0)
// .with_cpu()
.with_dry_run(0)
.with_i00((1, 1, 4).into())
.with_i02((320, h, 1280).into())
.with_i03((320, w, 1280).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.2, others: 0.15
.with_names2(&coco::KEYPOINTS_NAMES_17);
let mut model = YOLO::new(options)?;
let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];
group.bench_function("pre-process", |b| {
b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pre, n))
});
group.bench_function("run", |b| {
b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Run, n))
});
group.bench_function("post-process", |b| {
b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Post, n))
});
group.bench_function("pipeline", |b| {
b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pipeline, n))
});
group.finish();
Ok(())
}
pub fn criterion_benchmark(c: &mut Criterion) {
// benchmark_cuda(c, 416, 416).unwrap();
benchmark_cuda(c, 640, 640).unwrap();
benchmark_cuda(c, 448, 768).unwrap();
// benchmark_cuda(c, 800, 800).unwrap();
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

5
build.rs Normal file
View File

@ -0,0 +1,5 @@
fn main() {
// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
#[cfg(target_os = "macos")]
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
}

21
examples/sam/README.md Normal file
View File

@ -0,0 +1,21 @@
## Quick Start
```Shell
# SAM
cargo run -r --example sam
# MobileSAM
cargo run -r --example sam -- --kind mobile-sam
# EdgeSAM
cargo run -r --example sam -- --kind edge-sam
# SAM-HQ
cargo run -r --example sam -- --kind sam-hq
```
## Results
![](./demo.png)

BIN examples/sam/demo.png Normal file (326 KiB)

106
examples/sam/main.rs Normal file
View File

@ -0,0 +1,106 @@
use clap::Parser;
use usls::{
models::{SamKind, SamPrompt, SAM},
Annotator, DataLoader, Options,
};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[arg(long, value_enum, default_value_t = SamKind::Sam)]
pub kind: SamKind,
#[arg(long, default_value_t = 0)]
pub device_id: usize,
#[arg(long)]
pub use_low_res_mask: bool,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
// Options
let (options_encoder, options_decoder, saveout) = match args.kind {
SamKind::Sam => {
let options_encoder = Options::default()
// .with_model("sam-vit-b-encoder.onnx")?;
.with_model("sam-vit-b-encoder-u8.onnx")?;
let options_decoder = Options::default()
.with_i00((1, 1, 1).into())
.with_i11((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_sam_kind(SamKind::Sam)
// .with_model("sam-vit-b-decoder.onnx")?;
// .with_model("sam-vit-b-decoder-singlemask.onnx")?;
.with_model("sam-vit-b-decoder-u8.onnx")?;
(options_encoder, options_decoder, "SAM")
}
SamKind::MobileSam => {
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;
let options_decoder = Options::default()
.with_i00((1, 1, 1).into())
.with_i11((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_sam_kind(SamKind::MobileSam)
.with_model("mobile-sam-vit-t-decoder.onnx")?;
(options_encoder, options_decoder, "Mobile-SAM")
}
SamKind::SamHq => {
let options_encoder = Options::default().with_model("sam-hq-vit-t-encoder.onnx")?;
let options_decoder = Options::default()
.with_i00((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_i31((1, 1, 1).into())
.with_sam_kind(SamKind::SamHq)
.with_model("sam-hq-vit-t-decoder.onnx")?;
(options_encoder, options_decoder, "SAM-HQ")
}
SamKind::EdgeSam => {
let options_encoder = Options::default().with_model("edge-sam-3x-encoder.onnx")?;
let options_decoder = Options::default()
.with_i00((1, 1, 1).into())
.with_i11((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_sam_kind(SamKind::EdgeSam)
.with_model("edge-sam-3x-decoder.onnx")?;
(options_encoder, options_decoder, "Edge-SAM")
}
};
let options_encoder = options_encoder
.with_cuda(args.device_id)
.with_i00((1, 1, 1).into())
.with_i02((800, 1024, 1024).into())
.with_i03((800, 1024, 1024).into());
let options_decoder = options_decoder
.with_cuda(args.device_id)
.use_low_res_mask(args.use_low_res_mask)
.with_find_contours(true);
// Build model
let mut model = SAM::new(options_encoder, options_decoder)?;
// Load image
let xs = vec![DataLoader::try_read("./assets/truck.jpg")?];
// Build annotator
let annotator = Annotator::default().with_saveout(saveout);
// Prompt
let prompts = vec![
SamPrompt::default()
// .with_postive_point(500., 375.), // positive point
// .with_negative_point(774., 366.), // negative point
.with_bbox(215., 297., 643., 459.), // bbox
];
// Run & Annotate
let ys = model.run(&xs, &prompts)?;
annotator.annotate(&xs, &ys);
Ok(())
}

BIN examples/yolo-sam/demo.png Normal file (133 KiB)

63
examples/yolo-sam/main.rs Normal file
View File

@ -0,0 +1,63 @@
use usls::{
models::{SamKind, SamPrompt, YOLOTask, YOLOVersion, SAM, YOLO},
Annotator, DataLoader, Options, Vision,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build SAM
let options_encoder = Options::default()
.with_i00((1, 1, 1).into())
.with_model("mobile-sam-vit-t-encoder.onnx")?;
let options_decoder = Options::default()
.with_i11((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_find_contours(true)
.with_sam_kind(SamKind::Sam)
.with_model("mobile-sam-vit-t-decoder.onnx")?;
let mut sam = SAM::new(options_encoder, options_decoder)?;
// build YOLOv8-Det
let options_yolo = Options::default()
.with_yolo_version(YOLOVersion::V8)
.with_yolo_task(YOLOTask::Detect)
.with_model("yolov8m-dyn.onnx")?
.with_cuda(0)
.with_i00((1, 1, 4).into())
.with_i02((416, 640, 800).into())
.with_i03((416, 640, 800).into())
.with_find_contours(false)
.with_confs(&[0.45]);
let mut yolo = YOLO::new(options_yolo)?;
// load one image
let xs = vec![DataLoader::try_read("./assets/dog.jpg")?];
// build annotator
let annotator = Annotator::default()
.with_bboxes_thickness(7)
.without_bboxes_name(true)
.without_bboxes_conf(true)
.without_mbrs(true)
.with_saveout("YOLO+SAM");
// run & annotate
let ys_det = yolo.run(&xs)?;
for y_det in ys_det {
if let Some(bboxes) = y_det.bboxes() {
for bbox in bboxes {
let ys_sam = sam.run(
&xs,
&[SamPrompt::default().with_bbox(
bbox.xmin(),
bbox.ymin(),
bbox.xmax(),
bbox.ymax(),
)],
)?;
annotator.annotate(&xs, &ys_sam);
}
}
}
Ok(())
}

View File

@ -25,29 +25,29 @@
```Shell
# Classify
cargo run -r --example yolo -- --task classify --version v5 # YOLOv5
cargo run -r --example yolo -- --task classify --version v8 # YOLOv8
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
# Detect
cargo run -r --example yolo -- --task detect --version v5 # YOLOv5
cargo run -r --example yolo -- --task detect --version v6 # YOLOv6
cargo run -r --example yolo -- --task detect --version v7 # YOLOv7
cargo run -r --example yolo -- --task detect --version v8 # YOLOv8
cargo run -r --example yolo -- --task detect --version v9 # YOLOv9
cargo run -r --example yolo -- --task detect --version v10 # YOLOv10
cargo run -r --example yolo -- --task detect --version rtdetr # YOLOv8-RTDETR
cargo run -r --example yolo -- --task detect --version v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
# Pose
cargo run -r --example yolo -- --task pose --version v8 # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
# Segment
cargo run -r --example yolo -- --task segment --version v5 # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --version v8 # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --version v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
# Obb
cargo run -r --example yolo -- --task obb --version v8 # YOLOv8-Obb
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
```
<details close>
@ -175,7 +175,3 @@ yolo export model=yolov8m-obb.pt format=onnx simplify
[Here](https://github.com/THU-MIG/yolov10#export)
</details>

View File

@ -16,7 +16,7 @@ pub struct Args {
pub task: YOLOTask,
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
pub version: YOLOVersion,
pub ver: YOLOVersion,
#[arg(long, default_value_t = 224)]
pub width_min: isize,
@ -59,6 +59,9 @@ pub struct Args {
#[arg(long)]
pub no_plot: bool,
#[arg(long)]
pub no_contours: bool,
}
fn main() -> Result<()> {
@ -68,66 +71,87 @@ fn main() -> Result<()> {
let options = Options::default();
// version & task
let options =
match args.version {
YOLOVersion::V5 => {
match args.task {
YOLOTask::Classify => options
.with_model(&args.model.unwrap_or("yolov5n-cls-dyn.onnx".to_string()))?,
YOLOTask::Detect => {
options.with_model(&args.model.unwrap_or("yolov5n-dyn.onnx".to_string()))?
}
YOLOTask::Segment => options
.with_model(&args.model.unwrap_or("yolov5n-seg-dyn.onnx".to_string()))?,
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
}
}
YOLOVersion::V6 => match args.task {
YOLOTask::Detect => options
let (options, saveout) = match args.ver {
YOLOVersion::V5 => match args.task {
YOLOTask::Classify => (
options.with_model(&args.model.unwrap_or("yolov5n-cls-dyn.onnx".to_string()))?,
"YOLOv5-Classify",
),
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolov5n-dyn.onnx".to_string()))?,
"YOLOv5-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolov5n-seg-dyn.onnx".to_string()))?,
"YOLOv5-Segment",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V6 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolov6n-dyn.onnx".to_string()))?
.with_nc(args.nc),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
},
YOLOVersion::V7 => match args.task {
YOLOTask::Detect => options
"YOLOv6-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V7 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolov7-tiny-dyn.onnx".to_string()))?
.with_nc(args.nc),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
},
YOLOVersion::V8 => {
match args.task {
YOLOTask::Classify => options
.with_model(&args.model.unwrap_or("yolov8m-cls-dyn.onnx".to_string()))?,
YOLOTask::Detect => {
options.with_model(&args.model.unwrap_or("yolov8m-dyn.onnx".to_string()))?
}
YOLOTask::Segment => options
.with_model(&args.model.unwrap_or("yolov8m-seg-dyn.onnx".to_string()))?,
YOLOTask::Pose => options
.with_model(&args.model.unwrap_or("yolov8m-pose-dyn.onnx".to_string()))?,
YOLOTask::Obb => options
.with_model(&args.model.unwrap_or("yolov8m-obb-dyn.onnx".to_string()))?,
}
}
YOLOVersion::V9 => match args.task {
YOLOTask::Detect => options
.with_model(&args.model.unwrap_or("yolov9-c-dyn-f16.onnx".to_string()))?,
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
},
YOLOVersion::V10 => match args.task {
YOLOTask::Detect => {
options.with_model(&args.model.unwrap_or("yolov10n.onnx".to_string()))?
}
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
},
YOLOVersion::RTDETR => match args.task {
YOLOTask::Detect => {
options.with_model(&args.model.unwrap_or("rtdetr-l-f16.onnx".to_string()))?
}
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version),
},
}
.with_yolo_version(args.version)
"YOLOv7-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V8 => match args.task {
YOLOTask::Classify => (
options.with_model(&args.model.unwrap_or("yolov8m-cls-dyn.onnx".to_string()))?,
"YOLOv8-Classify",
),
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolov8m-dyn.onnx".to_string()))?,
"YOLOv8-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolov8m-seg-dyn.onnx".to_string()))?,
"YOLOv8-Segment",
),
YOLOTask::Pose => (
options.with_model(&args.model.unwrap_or("yolov8m-pose-dyn.onnx".to_string()))?,
"YOLOv8-Pose",
),
YOLOTask::Obb => (
options.with_model(&args.model.unwrap_or("yolov8m-obb-dyn.onnx".to_string()))?,
"YOLOv8-Obb",
),
},
YOLOVersion::V9 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolov9-c-dyn-f16.onnx".to_string()))?,
"YOLOv9-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V10 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolov10n.onnx".to_string()))?,
"YOLOv10-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::RTDETR => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("rtdetr-l-f16.onnx".to_string()))?,
"RTDETR",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
};
let options = options
.with_yolo_version(args.ver)
.with_yolo_task(args.task);
// device
@ -152,6 +176,7 @@ fn main() -> Result<()> {
.with_confs(&[0.2, 0.15]) // class_0: 0.2, others: 0.15
// .with_names(&coco::NAMES_80)
.with_names2(&coco::KEYPOINTS_NAMES_17)
.with_find_contours(!args.no_contours) // find contours or not
.with_profile(args.profile);
let mut model = YOLO::new(options)?;
@ -163,9 +188,9 @@ fn main() -> Result<()> {
// build annotator
let annotator = Annotator::default()
.with_skeletons(&coco::SKELETONS_16)
.with_bboxes_thickness(7)
.without_masks(true) // No masks plotting.
.with_saveout("YOLO-Series");
.with_bboxes_thickness(4)
.without_masks(true) // No mask plotting for segmentation tasks.
.with_saveout(saveout);
// run & annotate
for (xs, _paths) in dl {

View File

@ -340,13 +340,6 @@ impl Annotator {
}
}
// masks
if !self.without_masks {
if let Some(xs) = &y.masks() {
self.plot_masks(&mut img_rgba, xs);
}
}
// bboxes
if !self.without_bboxes {
if let Some(xs) = &y.bboxes() {
@ -368,6 +361,13 @@ impl Annotator {
}
}
// masks
if !self.without_masks {
if let Some(xs) = &y.masks() {
self.plot_masks(&mut img_rgba, xs);
}
}
// probs
if let Some(xs) = &y.probs() {
self.plot_probs(&mut img_rgba, xs);

View File

@ -53,49 +53,58 @@ impl Default for DataLoader {
}
impl DataLoader {
pub fn load<P: AsRef<Path>>(&mut self, source: P) -> Result<Self> {
let source = source.as_ref();
let mut paths = VecDeque::new();
match source {
s if s.is_file() => paths.push_back(s.to_path_buf()),
s if s.is_dir() => {
for entry in WalkDir::new(s)
.into_iter()
.filter_entry(|e| !Self::_is_hidden(e))
{
let entry = entry.unwrap();
if entry.file_type().is_dir() {
continue;
pub fn load<P: AsRef<Path>>(mut self, source: P) -> Result<Self> {
self.paths = match source.as_ref() {
s if s.is_file() => VecDeque::from([s.to_path_buf()]),
s if s.is_dir() => WalkDir::new(s)
.into_iter()
.filter_entry(|e| !Self::_is_hidden(e))
.filter_map(|entry| match entry {
Err(_) => None,
Ok(entry) => {
if entry.file_type().is_dir() {
return None;
}
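// Non-recursive mode: only take files directly under the source directory.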
if !self.recursive && entry.depth() > 1 {
return None;
}
Some(entry.path().to_path_buf())
}
if !self.recursive && entry.depth() > 1 {
continue;
}
paths.push_back(entry.path().to_path_buf());
}
}
})
.collect::<VecDeque<_>>(),
// s if s.starts_with("rtsp://") || s.starts_with("rtmp://") || s.starts_with("http://")|| s.starts_with("https://") => todo!(),
s if !s.exists() => bail!("{s:?} does not exist"),
_ => todo!(),
}
let n_new = paths.len();
self.paths.append(&mut paths);
println!(
"{CHECK_MARK} Found images x{n_new} ({} total)",
self.paths.len()
);
Ok(Self {
paths: self.paths.to_owned(),
batch: self.batch,
recursive: self.recursive,
})
};
println!("{CHECK_MARK} Found file x{}", self.paths.len());
Ok(self)
}
pub fn try_read<P: AsRef<Path>>(path: P) -> Result<DynamicImage> {
let img = image::ImageReader::open(&path)
.map_err(|_| anyhow!("Failed to open image at {:?}", path.as_ref()))?
.map_err(|err| {
anyhow!(
"Failed to open image at {:?}. Error: {:?}",
path.as_ref(),
err
)
})?
.with_guessed_format()
.map_err(|err| {
anyhow!(
"Failed to make a format guess based on the content: {:?}. Error: {:?}",
path.as_ref(),
err
)
})?
.decode()
.map_err(|_| anyhow!("Failed to decode image at {:?}", path.as_ref()))?
.map_err(|err| {
anyhow!(
"Failed to decode image at {:?}. Error: {:?}",
path.as_ref(),
err
)
})?
.into_rgb8();
Ok(DynamicImage::from(img))
}

View File

@ -4,7 +4,6 @@ use human_bytes::human_bytes;
use ndarray::{Array, IxDyn};
use ort::{
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
MINOR_VERSION,
};
use prost::Message;
use std::collections::HashSet;
@ -41,7 +40,7 @@ impl OrtEngine {
let model_proto = Self::load_onnx(&config.onnx_path)?;
let graph = match &model_proto.graph {
Some(graph) => graph,
None => anyhow::bail!("No graph found in this proto"),
None => anyhow::bail!("No graph found in this proto. Failed to parse ONNX model."),
};
// model params & mems
@ -101,6 +100,30 @@ impl OrtEngine {
(3, 3) => Self::_set_ixx(x, &config.i33, i, ii).unwrap_or(x_default),
(3, 4) => Self::_set_ixx(x, &config.i34, i, ii).unwrap_or(x_default),
(3, 5) => Self::_set_ixx(x, &config.i35, i, ii).unwrap_or(x_default),
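// (input index, axis index) => user-supplied MinOptMax override for that
// dynamic dimension; falls back to the model's default when unset.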
(4, 0) => Self::_set_ixx(x, &config.i40, i, ii).unwrap_or(x_default),
(4, 1) => Self::_set_ixx(x, &config.i41, i, ii).unwrap_or(x_default),
(4, 2) => Self::_set_ixx(x, &config.i42, i, ii).unwrap_or(x_default),
(4, 3) => Self::_set_ixx(x, &config.i43, i, ii).unwrap_or(x_default),
(4, 4) => Self::_set_ixx(x, &config.i44, i, ii).unwrap_or(x_default),
(4, 5) => Self::_set_ixx(x, &config.i45, i, ii).unwrap_or(x_default),
(5, 0) => Self::_set_ixx(x, &config.i50, i, ii).unwrap_or(x_default),
(5, 1) => Self::_set_ixx(x, &config.i51, i, ii).unwrap_or(x_default),
(5, 2) => Self::_set_ixx(x, &config.i52, i, ii).unwrap_or(x_default),
(5, 3) => Self::_set_ixx(x, &config.i53, i, ii).unwrap_or(x_default),
(5, 4) => Self::_set_ixx(x, &config.i54, i, ii).unwrap_or(x_default),
(5, 5) => Self::_set_ixx(x, &config.i55, i, ii).unwrap_or(x_default),
(6, 0) => Self::_set_ixx(x, &config.i60, i, ii).unwrap_or(x_default),
(6, 1) => Self::_set_ixx(x, &config.i61, i, ii).unwrap_or(x_default),
(6, 2) => Self::_set_ixx(x, &config.i62, i, ii).unwrap_or(x_default),
(6, 3) => Self::_set_ixx(x, &config.i63, i, ii).unwrap_or(x_default),
(6, 4) => Self::_set_ixx(x, &config.i64_, i, ii).unwrap_or(x_default),
(6, 5) => Self::_set_ixx(x, &config.i65, i, ii).unwrap_or(x_default),
(7, 0) => Self::_set_ixx(x, &config.i70, i, ii).unwrap_or(x_default),
(7, 1) => Self::_set_ixx(x, &config.i71, i, ii).unwrap_or(x_default),
(7, 2) => Self::_set_ixx(x, &config.i72, i, ii).unwrap_or(x_default),
(7, 3) => Self::_set_ixx(x, &config.i73, i, ii).unwrap_or(x_default),
(7, 4) => Self::_set_ixx(x, &config.i74, i, ii).unwrap_or(x_default),
(7, 5) => Self::_set_ixx(x, &config.i75, i, ii).unwrap_or(x_default),
_ => todo!(),
};
v_.push(x);
@ -146,7 +169,7 @@ impl OrtEngine {
// summary
println!(
"{CHECK_MARK} ORT: 1.{MINOR_VERSION}.x | Opset: {} | EP: {:?} | Dtype: {:?} | Parameters: {}",
"{CHECK_MARK} Backend: ONNXRuntime | OpSet: {} | EP: {:?} | DType: {:?} | Params: {}",
model_proto.opset_import[0].version,
device,
inputs_attrs.dtypes,
@ -291,6 +314,12 @@ impl OrtEngine {
TensorElementType::Int64 => {
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
}
TensorElementType::Uint8 => {
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
}
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
_ => todo!(),
};
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
@ -499,14 +528,12 @@ impl OrtEngine {
let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) {
Some(dtype) => dtype,
None => continue,
// None => anyhow::bail!("DType not supported"),
};
dtypes.push(tensor_type);
let shapes = match &tensor.shape {
Some(shapes) => shapes,
None => continue,
// None => anyhow::bail!("DType has no shapes"),
};
let mut shape_: Vec<isize> = Vec::new();
for shape in shapes.dim.iter() {

View File

@ -1,7 +1,6 @@
//! Some processing functions for images and ndarrays.
use anyhow::Result;
use fast_image_resize as fir;
use fast_image_resize::{
images::{CroppedImageMut, Image},
pixels::PixelType,
@ -11,8 +10,6 @@ use image::{DynamicImage, GenericImageView};
use ndarray::{s, Array, Axis, IxDyn};
use rayon::prelude::*;
use crate::X;
pub enum Ops<'a> {
Resize(&'a [DynamicImage], u32, u32, &'a str),
Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool),
@ -26,30 +23,13 @@ pub enum Ops<'a> {
}
impl Ops<'_> {
pub fn apply(ops: &[Self]) -> Result<X> {
let mut y = X::default();
for op in ops {
y = match op {
Self::Resize(xs, h, w, filter) => X::resize(xs, *h, *w, filter)?,
Self::Letterbox(xs, h, w, filter, bg, resize_by, center) => {
X::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)?
}
Self::Normalize(min_, max_) => y.normalize(*min_, *max_)?,
Self::Standardize(mean, std, d) => y.standardize(mean, std, *d)?,
Self::Permute(shape) => y.permute(shape)?,
Self::InsertAxis(d) => y.insert_axis(*d)?,
Self::Nhwc2nchw => y.nhwc2nchw()?,
Self::Nchw2nhwc => y.nchw2nhwc()?,
_ => todo!(),
}
}
Ok(y)
}
pub fn normalize(x: Array<f32, IxDyn>, min: f32, max: f32) -> Result<Array<f32, IxDyn>> {
if min > max {
anyhow::bail!("Input `min` is greater than `max`");
if min >= max {
anyhow::bail!(
"Invalid range in `normalize`: `min` ({}) must be less than `max` ({}).",
min,
max
);
}
Ok((x - min) / (max - min))
}
@ -61,11 +41,11 @@ impl Ops<'_> {
dim: usize,
) -> Result<Array<f32, IxDyn>> {
if mean.len() != std.len() {
anyhow::bail!("The lengths of mean and std are not equal.");
anyhow::bail!("`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", mean.len(), std.len());
}
let shape = x.shape();
if dim >= shape.len() || shape[dim] != mean.len() {
anyhow::bail!("The specified dimension or mean/std length is inconsistent with the input dimensions.");
anyhow::bail!("`standardize`: Dimension mismatch. `dim` is {} but shape length is {} or `mean` length is {}.", dim, shape.len(), mean.len());
}
let mut shape = vec![1; shape.len()];
shape[dim] = mean.len();
@ -77,11 +57,11 @@ impl Ops<'_> {
pub fn permute(x: Array<f32, IxDyn>, shape: &[usize]) -> Result<Array<f32, IxDyn>> {
if shape.len() != x.shape().len() {
anyhow::bail!(
"Shape inconsistent. Target: {:?}, {}, got: {:?}, {}",
x.shape(),
"`permute`: Shape length mismatch. Expected: {}, got: {}. Target shape: {:?}, provided shape: {:?}.",
x.shape().len(),
shape,
shape.len()
shape.len(),
x.shape(),
shape
);
}
Ok(x.permuted_axes(shape.to_vec()).into_dyn())
@ -98,7 +78,7 @@ impl Ops<'_> {
pub fn insert_axis(x: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> {
if x.shape().len() < d {
anyhow::bail!(
"The specified axis insertion position {} exceeds the shape's maximum limit of {}.",
"`insert_axis`: The specified axis position {} exceeds the maximum shape length {}.",
d,
x.shape().len()
);
@ -109,7 +89,7 @@ impl Ops<'_> {
pub fn norm(xs: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> {
if xs.shape().len() < d {
anyhow::bail!(
"The specified axis {} exceeds the shape's maximum limit of {}.",
"`norm`: Specified axis {} exceeds the maximum dimension length {}.",
d,
xs.shape().len()
);
@ -149,22 +129,22 @@ impl Ops<'_> {
crop_src: bool,
filter: &str,
) -> Result<Vec<u8>> {
let src_mask = fir::images::Image::from_vec_u8(
let src = Image::from_vec_u8(
w0 as _,
h0 as _,
v.iter().flat_map(|x| x.to_le_bytes()).collect(),
fir::PixelType::F32,
PixelType::F32,
)?;
let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type());
let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type());
let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
options = options.crop(0., 0., w.into(), h.into());
};
resizer.resize(&src_mask, &mut dst_mask, &options)?;
resizer.resize(&src, &mut dst, &options)?;
// u8*4 -> f32
let mask_f32: Vec<f32> = dst_mask
let mask_f32: Vec<f32> = dst
.into_vec()
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
@ -184,16 +164,15 @@ impl Ops<'_> {
crop_src: bool,
filter: &str,
) -> Result<Vec<u8>> {
let src_mask =
fir::images::Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), fir::PixelType::U8)?;
let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type());
let src = Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), PixelType::U8)?;
let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type());
let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
options = options.crop(0., 0., w.into(), h.into());
};
resizer.resize(&src_mask, &mut dst_mask, &options)?;
Ok(dst_mask.into_vec())
resizer.resize(&src, &mut dst, &options)?;
Ok(dst.into_vec())
}
pub fn build_resizer_filter(ty: &str) -> Result<(Resizer, ResizeOptions)> {
@ -205,7 +184,7 @@ impl Ops<'_> {
"Mitchell" => FilterType::Mitchell,
"Gaussian" => FilterType::Gaussian,
"Lanczos3" => FilterType::Lanczos3,
_ => anyhow::bail!("Unsupported resize filter type: {ty}"),
_ => anyhow::bail!("Unsupported resizer's filter type: {ty}"),
};
Ok((
Resizer::new(),
@ -215,22 +194,22 @@ impl Ops<'_> {
pub fn resize(
xs: &[DynamicImage],
height: u32,
width: u32,
th: u32,
tw: u32,
filter: &str,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn();
let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn();
let (mut resizer, options) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
let buffer = if x.dimensions() == (width, height) {
let buffer = if x.dimensions() == (tw, th) {
x.to_rgb8().into_raw()
} else {
let mut dst_image = Image::new(width, height, PixelType::U8x3);
resizer.resize(x, &mut dst_image, &options)?;
dst_image.into_vec()
let mut dst = Image::new(tw, th, PixelType::U8x3);
resizer.resize(x, &mut dst, &options)?;
dst.into_vec()
};
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)?
.mapv(|x| x as f32);
let y_ =
Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32);
ys.slice_mut(s![idx, .., .., ..]).assign(&y_);
}
Ok(ys)
@ -238,55 +217,55 @@ impl Ops<'_> {
pub fn letterbox(
xs: &[DynamicImage],
height: u32,
width: u32,
th: u32,
tw: u32,
filter: &str,
bg: u8,
resize_by: &str,
center: bool,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn();
let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn();
let (mut resizer, options) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
let (w0, h0) = x.dimensions();
let buffer = if w0 == width && h0 == height {
let buffer = if w0 == tw && h0 == th {
x.to_rgb8().into_raw()
} else {
let (w, h) = match resize_by {
"auto" => {
let r = (width as f32 / w0 as f32).min(height as f32 / h0 as f32);
let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32);
(
(w0 as f32 * r).round() as u32,
(h0 as f32 * r).round() as u32,
)
}
"height" => (height * w0 / h0, height),
"width" => (width, width * h0 / w0),
_ => anyhow::bail!("Option: width, height, auto"),
"height" => (th * w0 / h0, th),
"width" => (tw, tw * h0 / w0),
_ => anyhow::bail!("Options for `letterbox`: width, height, auto"),
};
let mut dst_image = Image::from_vec_u8(
width,
height,
vec![bg; 3 * height as usize * width as usize],
let mut dst = Image::from_vec_u8(
tw,
th,
vec![bg; 3 * th as usize * tw as usize],
PixelType::U8x3,
)?;
let (l, t) = if center {
if w == width {
(0, (height - h) / 2)
if w == tw {
(0, (th - h) / 2)
} else {
((width - w) / 2, 0)
((tw - w) / 2, 0)
}
} else {
(0, 0)
};
let mut cropped_dst_image = CroppedImageMut::new(&mut dst_image, l, t, w, h)?;
resizer.resize(x, &mut cropped_dst_image, &options)?;
dst_image.into_vec()
let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?;
resizer.resize(x, &mut dst_cropped, &options)?;
dst.into_vec()
};
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)?
.mapv(|x| x as f32);
let y_ =
Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32);
ys.slice_mut(s![idx, .., .., ..]).assign(&y_);
}
Ok(ys)

View File

@ -4,7 +4,7 @@ use anyhow::Result;
use crate::{
auto_load,
models::{YOLOPreds, YOLOTask, YOLOVersion},
models::{SamKind, YOLOPreds, YOLOTask, YOLOVersion},
Device, MinOptMax,
};
@ -39,7 +39,30 @@ pub struct Options {
pub i33: Option<MinOptMax>,
pub i34: Option<MinOptMax>,
pub i35: Option<MinOptMax>,
pub i40: Option<MinOptMax>,
pub i41: Option<MinOptMax>,
pub i42: Option<MinOptMax>,
pub i43: Option<MinOptMax>,
pub i44: Option<MinOptMax>,
pub i45: Option<MinOptMax>,
pub i50: Option<MinOptMax>,
pub i51: Option<MinOptMax>,
pub i52: Option<MinOptMax>,
pub i53: Option<MinOptMax>,
pub i54: Option<MinOptMax>,
pub i55: Option<MinOptMax>,
pub i60: Option<MinOptMax>,
pub i61: Option<MinOptMax>,
pub i62: Option<MinOptMax>,
pub i63: Option<MinOptMax>,
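// Named `i64_` to avoid confusion with the primitive type `i64`.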
pub i64_: Option<MinOptMax>,
pub i65: Option<MinOptMax>,
pub i70: Option<MinOptMax>,
pub i71: Option<MinOptMax>,
pub i72: Option<MinOptMax>,
pub i73: Option<MinOptMax>,
pub i74: Option<MinOptMax>,
pub i75: Option<MinOptMax>,
// trt related
pub trt_engine_cache_enable: bool,
pub trt_int8_enable: bool,
@ -63,6 +86,9 @@ pub struct Options {
pub yolo_task: Option<YOLOTask>,
pub yolo_version: Option<YOLOVersion>,
pub yolo_preds: Option<YOLOPreds>,
pub find_contours: bool,
pub sam_kind: Option<SamKind>,
pub use_low_res_mask: Option<bool>,
}
impl Default for Options {
@ -96,6 +122,30 @@ impl Default for Options {
i33: None,
i34: None,
i35: None,
i40: None,
i41: None,
i42: None,
i43: None,
i44: None,
i45: None,
i50: None,
i51: None,
i52: None,
i53: None,
i54: None,
i55: None,
i60: None,
i61: None,
i62: None,
i63: None,
i64_: None,
i65: None,
i70: None,
i71: None,
i72: None,
i73: None,
i74: None,
i75: None,
trt_engine_cache_enable: true,
trt_int8_enable: false,
trt_fp16_enable: false,
@ -116,6 +166,9 @@ impl Default for Options {
yolo_task: None,
yolo_version: None,
yolo_preds: None,
find_contours: false,
sam_kind: None,
use_low_res_mask: None,
}
}
}
@ -171,6 +224,21 @@ impl Options {
self
}
pub fn with_find_contours(mut self, x: bool) -> Self {
self.find_contours = x;
self
}
pub fn with_sam_kind(mut self, x: SamKind) -> Self {
self.sam_kind = Some(x);
self
}
pub fn use_low_res_mask(mut self, x: bool) -> Self {
self.use_low_res_mask = Some(x);
self
}
pub fn with_names(mut self, names: &[&str]) -> Self {
self.names = Some(names.iter().map(|x| x.to_string()).collect::<Vec<String>>());
self
@ -360,4 +428,124 @@ impl Options {
self.i35 = Some(x);
self
}
pub fn with_i40(mut self, x: MinOptMax) -> Self {
self.i40 = Some(x);
self
}
pub fn with_i41(mut self, x: MinOptMax) -> Self {
self.i41 = Some(x);
self
}
pub fn with_i42(mut self, x: MinOptMax) -> Self {
self.i42 = Some(x);
self
}
pub fn with_i43(mut self, x: MinOptMax) -> Self {
self.i43 = Some(x);
self
}
pub fn with_i44(mut self, x: MinOptMax) -> Self {
self.i44 = Some(x);
self
}
pub fn with_i45(mut self, x: MinOptMax) -> Self {
self.i45 = Some(x);
self
}
pub fn with_i50(mut self, x: MinOptMax) -> Self {
self.i50 = Some(x);
self
}
pub fn with_i51(mut self, x: MinOptMax) -> Self {
self.i51 = Some(x);
self
}
pub fn with_i52(mut self, x: MinOptMax) -> Self {
self.i52 = Some(x);
self
}
pub fn with_i53(mut self, x: MinOptMax) -> Self {
self.i53 = Some(x);
self
}
pub fn with_i54(mut self, x: MinOptMax) -> Self {
self.i54 = Some(x);
self
}
pub fn with_i55(mut self, x: MinOptMax) -> Self {
self.i55 = Some(x);
self
}
pub fn with_i60(mut self, x: MinOptMax) -> Self {
self.i60 = Some(x);
self
}
pub fn with_i61(mut self, x: MinOptMax) -> Self {
self.i61 = Some(x);
self
}
pub fn with_i62(mut self, x: MinOptMax) -> Self {
self.i62 = Some(x);
self
}
pub fn with_i63(mut self, x: MinOptMax) -> Self {
self.i63 = Some(x);
self
}
pub fn with_i64(mut self, x: MinOptMax) -> Self {
self.i64_ = Some(x);
self
}
pub fn with_i65(mut self, x: MinOptMax) -> Self {
self.i65 = Some(x);
self
}
pub fn with_i70(mut self, x: MinOptMax) -> Self {
self.i70 = Some(x);
self
}
pub fn with_i71(mut self, x: MinOptMax) -> Self {
self.i71 = Some(x);
self
}
pub fn with_i72(mut self, x: MinOptMax) -> Self {
self.i72 = Some(x);
self
}
pub fn with_i73(mut self, x: MinOptMax) -> Self {
self.i73 = Some(x);
self
}
pub fn with_i74(mut self, x: MinOptMax) -> Self {
self.i74 = Some(x);
self
}
pub fn with_i75(mut self, x: MinOptMax) -> Self {
self.i75 = Some(x);
self
}
}

View File

@ -14,6 +14,12 @@ impl From<Array<f32, IxDyn>> for X {
}
}
impl From<Vec<f32>> for X {
fn from(x: Vec<f32>) -> Self {
Self(Array::from_vec(x).into_dyn().into_owned())
}
}
impl std::ops::Deref for X {
type Target = Array<f32, IxDyn>;
@ -28,7 +34,23 @@ impl X {
}
pub fn apply(ops: &[Ops]) -> Result<Self> {
Ops::apply(ops)
let mut y = Self::default();
for op in ops {
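// Resize/Letterbox build a fresh tensor from the source images;
// the remaining ops transform the running result `y`.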
y = match op {
Ops::Resize(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?,
Ops::Letterbox(xs, h, w, filter, bg, resize_by, center) => {
Self::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)?
}
Ops::Normalize(min_, max_) => y.normalize(*min_, *max_)?,
Ops::Standardize(mean, std, d) => y.standardize(mean, std, *d)?,
Ops::Permute(shape) => y.permute(shape)?,
Ops::InsertAxis(d) => y.insert_axis(*d)?,
Ops::Nhwc2nchw => y.nhwc2nchw()?,
Ops::Nchw2nhwc => y.nchw2nhwc()?,
_ => todo!(),
}
}
Ok(y)
}
pub fn permute(mut self, shape: &[usize]) -> Result<Self> {

View File

@ -1,91 +1,150 @@
//! A Rust library integrated with ONNXRuntime, providing a collection of Computer Vison and Vision-Language models.
//! A Rust library integrated with ONNXRuntime, providing a collection of **Computer Vision** and **Vision-Language** models.
//!
//! # Supported Models
//!
//! - [YOLOv5](https://github.com/ultralytics/yolov5): Object Detection, Instance Segmentation, Classification
//! - [YOLOv6](https://github.com/meituan/YOLOv6): Object Detection
//! - [YOLOv7](https://github.com/WongKinYiu/yolov7): Object Detection
//! - [YOLOv8](https://github.com/ultralytics/ultralytics): Object Detection, Instance Segmentation, Classification, Oriented Object Detection, Keypoint Detection
//! - [YOLOv9](https://github.com/WongKinYiu/yolov9): Object Detection
//! - [YOLOv10](https://github.com/THU-MIG/yolov10): Object Detection
//! - [RT-DETR](https://arxiv.org/abs/2304.08069): Object Detection
//! - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM): Instance Segmentation
//! - [SAM](https://github.com/facebookresearch/segment-anything): Segment Anything
//! - [MobileSAM](https://github.com/ChaoningZhang/MobileSAM): Segment Anything
//! - [EdgeSAM](https://github.com/chongzhou96/EdgeSAM): Segment Anything
//! - [SAM-HQ](https://github.com/SysCV/sam-hq): Segment Anything
//! - [YOLO-World](https://github.com/AILab-CVC/YOLO-World): Object Detection
//! - [DINOv2](https://github.com/facebookresearch/dinov2): Vision-Self-Supervised
//! - [CLIP](https://github.com/openai/CLIP): Vision-Language
//! - [BLIP](https://github.com/salesforce/BLIP): Vision-Language
//! - [DB](https://arxiv.org/abs/1911.08947): Text Detection
//! - [SVTR](https://arxiv.org/abs/2205.00159): Text Recognition
//! - [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo): Keypoint Detection
//! - [YOLOPv2](https://arxiv.org/abs/2208.11434): Panoptic Driving Perception
//! - [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything): Monocular Depth Estimation
//! - [MODNet](https://github.com/ZHKKKe/MODNet): Image Matting
//!
//!
//!
//! # Supported models
//! | Model | Task / Type |
//! | :---------------------------------------------------------------: | :-------------------------: |
//! | [YOLOv5](https://github.com/ultralytics/yolov5) | Object Detection<br />Instance Segmentation<br />Classification |
//! | [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection |
//! | [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection |
//! | [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Classification<br />Oriented Object Detection<br />Keypoint Detection |
//! | [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection |
//! | [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection |
//! | [RT-DETR](https://arxiv.org/abs/2304.08069) | Object Detection |
//! | [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation |
//! | [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection |
//! | [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised |
//! | [CLIP](https://github.com/openai/CLIP) | Vision-Language |
//! | [BLIP](https://github.com/salesforce/BLIP) | Vision-Language |
//! | [DB](https://arxiv.org/abs/1911.08947) | Text Detection |
//! | [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition |
//! | [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection |
//! | [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception |
//! | [Depth-Anything<br />(v1, v2)](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation |
//! | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting |
//! # Examples
//!
//! [All Demos Here](https://github.com/jamjamjon/usls/tree/main/examples)
//!
//! # Using Provided Models for Inference
//!
//! #### 1. Build Model
//! Using provided [`models`] with [`Options`]
//!
//! ```rust, no_run
//! use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision};
//!
//! let options = Options::default()
//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
//! .with_model("xxxx.onnx")?
//! .with_trt(0)
//! .with_fp16(true)
//! .with_i00((1, 1, 4).into())
//! .with_i02((224, 640, 800).into())
//! .with_i03((224, 640, 800).into())
//! .with_confs(&[0.4, 0.15]) // class_0: 0.4, others: 0.15
//! .with_profile(false);
//! let mut model = YOLO::new(options)?;
//! ```
//! - Use `CUDA`, `TensorRT`, or `CoreML`
//!
//! ```rust, no_run
//! let options = Options::default()
//! .with_cuda(0) // using CUDA by default
//! // .with_trt(0)
//! // .with_coreml(0)
//! // .with_cpu();
//! ```
//!
//! - Dynamic Input Shapes
//!
//! ```rust, no_run
//! let options = Options::default()
//! .with_i00((1, 2, 4).into()) // dynamic batch
//! .with_i02((416, 640, 800).into()) // dynamic height
//! .with_i03((416, 640, 800).into()); // dynamic width
//! ```
//!
//! - Set Confidence Thresholds for Each Category
//!
//! ```rust, no_run
//! let options = Options::default()
//! .with_confs(&[0.4, 0.15]); // class_0: 0.4, others: 0.15
//! ```
//!
//! - Set Class Names
//!
//! ```rust, no_run
//! let options = Options::default()
//! .with_names(&coco::NAMES_80);
//! ```
//!
//! More options can be found in the [`Options`] documentation.
//!
//! #### 2. Load Images
//!
//! Ensure that the input images are in RGB format.
//!
//! - Using [`image::ImageReader`] or [`DataLoader`] to Load One Image
//!
//! ```rust, no_run
//! let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
//! // or
//! let x = image::ImageReader::open("myimage.png")?.decode()?;
//! ```
//!
//! // Load images with batch_size = 4
//! - Using [`DataLoader`] to Load a Batch of Images
//!
//! ```rust, no_run
//! let dl = DataLoader::default()
//! .with_batch(4)
//! .load("./assets")?;
//! ```
//!
//! #### 3. (Optional) Annotate Results with [`Annotator`]
//!
//! ```rust, no_run
//! let annotator = Annotator::default();
//! ```
//!
//! - Set Saveout Name
//!
//! ```rust, no_run
//! let annotator = Annotator::default()
//! .with_bboxes_thickness(4)
//! .with_saveout("YOLOs");
//! ```
//!
//! - Set Bboxes Line Width
//!
//! ```rust, no_run
//! let annotator = Annotator::default()
//! .with_bboxes_thickness(4);
//! ```
//!
//! - Disable Mask Plotting
//!
//! ```rust, no_run
//! let annotator = Annotator::default()
//! .without_masks(true);
//! ```
//!
//! More options can be found in the [`Annotator`] documentation.
//!
//!
//! #### 4. Run and Annotate
//!
//! ```rust, no_run
//! for (xs, _paths) in dl {
//! let ys = model.run(&xs)?;
//! annotator.annotate(&xs, &ys);
//! }
//! ```
//!
//! #### 5. Get Results
//!
//! The inference outputs of provided models will be saved to a [`Vec<Y>`].
//!
//! - For Example, Get Detection Bboxes with `y.bboxes()`
//!
//! ```rust, no_run
//! let ys = model.run(&xs)?;
//! for y in ys {
//! // bboxes
@ -99,18 +158,17 @@
//! bbox.ymax(),
//! bbox.confidence(),
//! bbox.id(),
//! );
//! }
//! }
//! }
//! ```
//!
//! # Also, You Can Implement Your Own Model with [`OrtEngine`] and [`Options`]
//!
//! [`OrtEngine`] provides ONNX model loading, metadata parsing, dry_run, inference, and other functions, supporting EPs such as CUDA, TensorRT, CoreML, etc. You can use it as the ONNXRuntime engine for building models.
//!
//! Refer to [Demo: Depth-Anything](https://github.com/jamjamjon/usls/blob/main/src/models/depth_anything.rs) for more details.
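//!
//! A minimal sketch of the idea (the model path and the pre-processing pipeline below are placeholders; post-processing is model-specific and omitted):
//!
//! ```rust, no_run
//! use usls::{DataLoader, Ops, Options, OrtEngine, X};
//!
//! // Build the engine and warm it up.
//! let options = Options::default().with_model("your-model.onnx")?;
//! let mut engine = OrtEngine::new(&options)?;
//! engine.dry_run()?;
//!
//! // Preprocess images into an input tensor, then run.
//! let images = vec![DataLoader::try_read("./assets/bus.jpg")?];
//! let xs = X::apply(&[
//!     Ops::Resize(&images, 224, 224, "Bilinear"),
//!     Ops::Normalize(0., 255.),
//!     Ops::Nhwc2nchw,
//! ])?;
//! let ys = engine.run(vec![xs])?;
//! ```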
mod core;
pub mod models;

View File

@ -27,7 +27,7 @@ impl Blip {
visual.height().to_owned(),
visual.width().to_owned(),
);
let tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
let tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
let tokenizer = TokenizerStream::new(tokenizer);
visual.dry_run()?;
textual.dry_run()?;

View File

@ -28,7 +28,7 @@ impl Clip {
visual.inputs_minoptmax()[0][2].to_owned(),
visual.inputs_minoptmax()[0][3].to_owned(),
);
let mut tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
let mut tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::Fixed(context_length),
direction: PaddingDirection::Right,

View File

@ -119,31 +119,31 @@ impl DB {
continue;
}
let mask = Polygon::default().with_points_imageproc(&contour.points);
let delta = mask.area() * ratio.round() as f64 * self.unclip_ratio as f64
/ mask.perimeter();
let polygon = Polygon::default().with_points_imageproc(&contour.points);
let delta = polygon.area() * ratio.round() as f64 * self.unclip_ratio as f64
/ polygon.perimeter();
// TODO: optimize
let mask = mask
let polygon = polygon
.unclip(delta, image_width as f64, image_height as f64)
.resample(50)
// .simplify(6e-4)
.convex_hull();
if let Some(bbox) = mask.bbox() {
if let Some(bbox) = polygon.bbox() {
if bbox.height() < self.min_height || bbox.width() < self.min_width {
continue;
}
let confidence = mask.area() as f32 / bbox.area();
let confidence = polygon.area() as f32 / bbox.area();
if confidence < self.confs[0] {
continue;
}
y_bbox.push(bbox.with_confidence(confidence).with_id(0));
if let Some(mbr) = mask.mbr() {
if let Some(mbr) = polygon.mbr() {
y_mbrs.push(mbr.with_confidence(confidence).with_id(0));
}
y_polygons.push(polygon.with_id(0));
} else {
continue;
}

View File

@ -8,6 +8,7 @@ mod dinov2;
mod modnet;
mod rtdetr;
mod rtmo;
mod sam;
mod svtr;
mod yolo;
mod yolo_;
@ -21,10 +22,8 @@ pub use dinov2::Dinov2;
pub use modnet::MODNet;
pub use rtdetr::RTDETR;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};
pub use svtr::SVTR;
pub use yolo::YOLO;
pub use yolo_::*;
pub use yolop::YOLOPv2;

291
src/models/sam.rs Normal file
View File

@ -0,0 +1,291 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis};
use rand::prelude::*;
use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
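
/// Supported SAM variants. The variant determines which inputs the
/// decoder expects and how its outputs are parsed.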
#[derive(Debug, Clone, clap::ValueEnum)]
pub enum SamKind {
Sam,
MobileSam,
SamHq,
EdgeSam,
}
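
/// Point and box prompts, stored as flat `(x, y)` coordinates plus SAM
/// label codes: 1. = foreground point, 0. = background point,
/// 2./3. = top-left/bottom-right box corners.
///
/// A usage sketch (the coordinates are illustrative, given in
/// original-image space):
///
/// ```rust, no_run
/// let prompt = SamPrompt::default()
///     .with_positive_point(500., 375.)    // foreground point
///     .with_bbox(425., 600., 700., 875.); // box prompt
/// ```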
#[derive(Debug, Default, Clone)]
pub struct SamPrompt {
points: Vec<f32>,
labels: Vec<f32>,
}
impl SamPrompt {
pub fn everything() -> Self {
todo!()
}
pub fn with_positive_point(mut self, x: f32, y: f32) -> Self {
self.points.extend_from_slice(&[x, y]);
self.labels.push(1.);
self
}
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
self.points.extend_from_slice(&[x, y]);
self.labels.push(0.);
self
}
pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self {
self.points.extend_from_slice(&[x, y, x2, y2]);
self.labels.extend_from_slice(&[2., 3.]);
self
}
pub fn point_coords(&self, r: f32) -> Result<X> {
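// Shape: (1, N, 2). Prompt points are given in original-image space and
// scaled by the letterbox ratio `r` into model-input space.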
let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())?
.into_dyn()
.into_owned();
Ok(X::from(point_coords * r))
}
pub fn point_labels(&self) -> Result<X> {
let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())?
.into_dyn()
.into_owned();
Ok(X::from(point_labels))
}
pub fn num_points(&self) -> usize {
self.points.len() / 2
}
}
#[derive(Debug)]
pub struct SAM {
encoder: OrtEngine,
decoder: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
pub conf: DynConf,
find_contours: bool,
kind: SamKind,
use_low_res_mask: bool,
}
impl SAM {
pub fn new(options_encoder: Options, options_decoder: Options) -> Result<Self> {
let mut encoder = OrtEngine::new(&options_encoder)?;
let mut decoder = OrtEngine::new(&options_decoder)?;
let (batch, height, width) = (
encoder.inputs_minoptmax()[0][0].to_owned(),
encoder.inputs_minoptmax()[0][2].to_owned(),
encoder.inputs_minoptmax()[0][3].to_owned(),
);
let conf = DynConf::new(&options_decoder.confs, 1);
let kind = match options_decoder.sam_kind {
Some(x) => x,
None => anyhow::bail!("Error: no clear `SamKind` specified."),
};
let find_contours = options_decoder.find_contours;
let use_low_res_mask = match kind {
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
options_decoder.use_low_res_mask.unwrap_or(false)
}
SamKind::EdgeSam => true,
};
encoder.dry_run()?;
decoder.dry_run()?;
Ok(Self {
encoder,
decoder,
batch,
height,
width,
conf,
kind,
find_contours,
use_low_res_mask,
})
}
pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = self.encode(xs)?;
self.decode(ys, xs, prompts)
}
pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Vec<X>> {
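// Letterbox to the model input size, standardize with the ImageNet
// mean/std (scaled by 255), then convert NHWC -> NCHW.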
let xs_ = X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"Bilinear",
0,
"auto",
false,
),
Ops::Standardize(&[123.675, 116.28, 103.53], &[58.395, 57.12, 57.375], 3),
Ops::Nhwc2nchw,
])?;
self.encoder.run(vec![xs_])
}
pub fn decode(
&mut self,
xs: Vec<X>,
xs0: &[DynamicImage],
prompts: &[SamPrompt],
) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() {
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
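// Letterbox scale factor from the original image to the model input;
// the same ratio maps prompt coordinates in `point_coords()`.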
let ratio =
(self.width() as f32 / image_width).min(self.height() as f32 / image_height);
let args = match self.kind {
SamKind::Sam | SamKind::MobileSam => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
X::zeros(&[1]), // has_mask_input
X::from(vec![image_height, image_width]), // orig_im_size
]
}
SamKind::SamHq => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned())
.insert_axis(0)?
.insert_axis(0)?, // interm_embeddings (SAM-HQ intermediate encoder features)
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
X::zeros(&[1]), // has_mask_input
X::from(vec![image_height, image_width]), // orig_im_size
]
}
SamKind::EdgeSam => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
]
}
};
let ys_ = self.decoder.run(args)?;
let mut y_masks: Vec<Mask> = Vec::new();
let mut y_polygons: Vec<Polygon> = Vec::new();
// masks & confs
let (masks, confs) = match self.kind {
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
if !self.use_low_res_mask {
(&ys_[0], &ys_[1])
} else {
(&ys_[2], &ys_[1])
}
}
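// EdgeSAM decoders may emit (scores, masks) in either order;
// disambiguate by rank: masks are 4-D, IoU scores are 2-D.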
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
(2, 4) => (&ys_[1], &ys_[0]),
(4, 2) => (&ys_[0], &ys_[1]),
_ => anyhow::bail!("Can not parse the outputs of decoder."),
},
};
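// Keep only the candidate mask with the highest predicted IoU,
// and drop it if that score is below the confidence threshold.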
for (mask, iou) in masks.axis_iter(Axis(0)).zip(confs.axis_iter(Axis(0))) {
let (i, conf) = match iou
.to_owned()
.into_raw_vec()
.into_iter()
.enumerate()
.max_by(|a, b| a.1.total_cmp(&b.1))
{
Some((i, c)) => (i, c),
None => continue,
};
if conf < self.conf[0] {
continue;
}
let mask = mask.slice(s![i, .., ..]);
let (h, w) = mask.dim();
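// Low-res logits are resized back to the original image size;
// full-res masks only need binarizing.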
let luma = if self.use_low_res_mask {
Ops::resize_lumaf32_vec(
&mask.to_owned().into_raw_vec(),
w as _,
h as _,
image_width as _,
image_height as _,
true,
"Bilinear",
)?
} else {
mask.mapv(|x| if x > 0. { 255u8 } else { 0u8 })
.into_raw_vec()
};
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(image_width as _, image_height as _, luma) {
None => continue,
Some(x) => x,
};
// Assign a random id (SAM is class-agnostic; the id presumably just
// varies annotation colors)
let mut rng = thread_rng();
let id = rng.gen_range(0..20);
// contours
if self.find_contours {
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&luma, 0);
for c in contours.iter() {
let polygon = Polygon::default().with_points_imageproc(&c.points);
y_polygons.push(polygon.with_confidence(conf).with_id(id));
}
}
y_masks.push(Mask::default().with_mask(luma).with_id(id));
}
let mut y = Y::default();
if !y_masks.is_empty() {
y = y.with_masks(&y_masks);
}
if !y_polygons.is_empty() {
y = y.with_polygons(&y_polygons);
}
ys.push(y);
}
Ok(ys)
}
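
// SAM's mask decoder works with low-res masks at 1/4 of the encoder input resolution.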
pub fn width_low_res(&self) -> usize {
self.width() as usize / 4
}
pub fn height_low_res(&self) -> usize {
self.height() as usize / 4
}
pub fn batch(&self) -> isize {
self.batch.opt
}
pub fn width(&self) -> isize {
self.width.opt
}
pub fn height(&self) -> isize {
self.height.opt
}
}
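
// A minimal end-to-end sketch (the options, image list, and prompt
// coordinates below are illustrative):
//
// let mut sam = SAM::new(options_encoder, options_decoder)?;
// let prompt = SamPrompt::default().with_positive_point(500., 375.);
// let ys = sam.run(&images, &[prompt])?; // one prompt set per image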

View File

@ -24,6 +24,7 @@ pub struct YOLO {
names_kpt: Option<Vec<String>>,
task: YOLOTask,
layout: YOLOPreds,
find_contours: bool,
version: Option<YOLOVersion>,
}
@ -153,6 +154,7 @@ impl Vision for YOLO {
names_kpt,
layout,
version,
find_contours: options.find_contours,
})
}
@ -417,7 +419,6 @@ impl Vision for YOLO {
.into_par_iter()
.filter_map(|bbox| {
let coefs = coefs.slice(s![bbox.id_born(), ..]).to_vec();
let proto = protos.as_ref()?.slice(s![idx, .., .., ..]);
let (nm, mh, mw) = proto.dim();
@ -461,10 +462,9 @@ impl Vision for YOLO {
}
// Find contours (only when `find_contours` is enabled, since contour extraction adds post-processing cost)
let polygons = if self.find_contours {
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&mask, 0);
contours
.into_par_iter()
.map(|x| {
@ -473,7 +473,13 @@ impl Vision for YOLO {
.with_points_imageproc(&x.points)
.with_name(bbox.name().cloned())
})
.max_by(|x, y| x.area().total_cmp(&y.area()))?
} else {
Polygon::default()
};
Some((
polygons,
Mask::default()
.with_mask(mask)
.with_id(bbox.id())
@ -482,7 +488,12 @@ impl Vision for YOLO {
})
.collect::<(Vec<_>, Vec<_>)>();
if !y_polygons.is_empty() {
y = y.with_polygons(&y_polygons);
}
if !y_masks.is_empty() {
y = y.with_masks(&y_masks);
}
}
}

View File

@ -64,6 +64,11 @@ impl Polygon {
self
}
pub fn with_confidence(mut self, x: f32) -> Self {
self.confidence = x;
self
}
pub fn id(&self) -> isize {
self.id
}