Add YOLOPv2 & Face-Parsing model (#3)

* Add YOLOP and face parsing model
Jamjamjon
2024-04-14 15:15:59 +08:00
committed by GitHub
parent ead175234c
commit 51b75e9a21
63 changed files with 1168 additions and 885 deletions

120
README.md

@ -1,49 +1,50 @@
# usls
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) and others. Many execution providers are supported, such as `CUDA`, `TensorRT` and `CoreML`.
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) and others.
## Supported Models
| Model | Task / Type | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
| :---------------------------------------------------------------: | :----------------------: |:----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
| **[YOLOv8-detection](https://github.com/ultralytics/ultralytics)** | Object Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| **[YOLOv8-pose](https://github.com/ultralytics/ultralytics)** | Keypoint Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| **[YOLOv8-classification](https://github.com/ultralytics/ultralytics)** | Classification | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| **[YOLOv8-segmentation](https://github.com/ultralytics/ultralytics)** | Instance Segmentation | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| **[YOLOv9](https://github.com/WongKinYiu/yolov9)** | Object Detection | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
| **[RT-DETR](https://arxiv.org/abs/2304.08069)** | Object Detection | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
| **[FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)** | Instance Segmentation | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
| **[YOLO-World](https://github.com/AILab-CVC/YOLO-World)** | Object Detection | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
| **[DINOv2](https://github.com/facebookresearch/dinov2)** | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
| **[CLIP](https://github.com/openai/CLIP)** | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| **[BLIP](https://github.com/salesforce/BLIP)** | Vision-Language | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| [**DB**](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | | ✅ | ✅ |
| [**SVTR**](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | | ✅ | ✅ |
| [**RTMO**](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | | |
| Model | Task / Type | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
| :---------------------------------------------------------------: | :------------------------------------------------------------------------: | :----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
| [YOLOv8-detection](https://github.com/ultralytics/ultralytics) | Object Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv8-pose](https://github.com/ultralytics/ultralytics) | Keypoint Detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv8-classification](https://github.com/ultralytics/ultralytics) | Classification | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv8-segmentation](https://github.com/ultralytics/ultralytics) | Instance Segmentation | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
| [RT-DETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
| [CLIP](https://github.com/openai/CLIP) | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| [BLIP](https://github.com/salesforce/BLIP) | Vision-Language | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
| [DB](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | | ✅ | ✅ |
| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | | ✅ | ✅ |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | | | |
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
## Solution Models
Additionally, this repo provides some ready-to-use solution models.
| Model | Example |
| :--------------------------------------------------------------------------------: | :------------------------------: |
| **text detection<br />(PPOCR-det v3, v4)** | [demo](examples/db) |
| **text recognition<br />(PPOCR-rec v3, v4)** | [demo](examples/svtr) |
| **face-landmark detection** | [demo](examples/yolov8-face) |
| **head detection** | [demo](examples/yolov8-head) |
| **fall detection** | [demo](examples/yolov8-falldown) |
| **trash detection** | [demo](examples/yolov8-plastic-bag) |
| Model | Example | Result |
| :------------------------------------------------------------: | :------------------------------: | :------------------------------: |
| Lane Line Segmentation<br />Drivable Area Segmentation<br />Car Detection | [demo](examples/yolop) | <img src='examples/yolop/demo.png' width="220px" height="140px"> |
| Face Parsing | [demo](examples/face-parsing) | <img src='examples/face-parsing/demo.png' width="220px" height="200px"> |
| Text Detection<br />(PPOCR-det v3, v4) | [demo](examples/db) | <img src='examples/db/demo.jpg' width="250px" height="200px"> |
| Text Recognition<br />(PPOCR-rec v3, v4) | [demo](examples/svtr) | |
| Face-Landmark Detection | [demo](examples/yolov8-face) | <img src='examples/yolov8-face/demo.jpg' width="220px" height="180px"> |
| Head Detection | [demo](examples/yolov8-head) | <img src='examples/yolov8-head/demo.jpg' width="220px" height="180px"> |
| Fall Detection | [demo](examples/yolov8-falldown) | <img src='examples/yolov8-falldown/demo.jpg' width="220px" height="180px"> |
| Trash Detection | [demo](examples/yolov8-plastic-bag) | <img src='examples/yolov8-trash/demo.jpg' width="250px" height="180px"> |
## Demo
```
cargo run -r --example yolov8 # fastsam, yolov9, blip, clip, dinov2, yolo-world...
cargo run -r --example yolov8 # yolov9, blip, clip, dinov2, svtr, db, yolo-world...
```
## Integrate into your own project
#### 1. Install [ort](https://github.com/pykeio/ort)
## Installation
check **[ort guide](https://ort.pyke.io/setup/linking)**
@ -58,13 +59,16 @@ check **[ort guide](https://ort.pyke.io/setup/linking)**
</details>
#### 2. Add `usls` as a dependency to your project's `Cargo.toml`
## Integrate into your own project
#### 1. Add `usls` as a dependency to your project's `Cargo.toml`
```shell
cargo add --git https://github.com/jamjamjon/usls
```
#### 3. Set `Options` and build model
#### 2. Set `Options` and build model
```Rust
let options = Options::default()
```
@ -73,32 +77,29 @@ let mut model = YOLO::new(&options)?;
- If you want to run your model with TensorRT or CoreML
```Rust
let options = Options::default()
    .with_trt(0) // using cuda by default
    // .with_coreml(0)
```
- If your model has dynamic shapes
```Rust
let options = Options::default()
    .with_i00((1, 2, 4).into())       // dynamic batch
    .with_i02((416, 640, 800).into()) // dynamic height
    .with_i03((416, 640, 800).into()) // dynamic width
```
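Here `iXY` addresses dimension `Y` of input `X`: `i00` is dimension 0 of the first input (the batch), while `i02` and `i03` are its height and width. Each `(min, opt, max)` triple sets the allowed range and preferred size for that dimension (it converts into a `MinOptMax`).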
- If you want to set a confidence level for each category
```Rust
let options = Options::default()
    .with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
```
- Go check [Options](src/options.rs) for more model options.
#### 4. Prepare inputs, and then you're ready to go
#### 3. Prepare inputs, and then you're ready to go
- Build `DataLoader` to load images
@ -119,22 +120,9 @@ let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
```Rust
let y = model.run(&x)?;
```
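`run` takes a slice of images, so batched inference is just a matter of pushing more frames into `x` (up to the maximum batch size you configured with `i00`).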
#### 5. Annotate and save results
#### 4. Annotate and save results
```Rust
let annotator = Annotator::default().with_saveout("YOLOv8");
annotator.annotate(&x, &y);
```
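Taken together, the steps above fit in a dozen lines. A minimal sketch — the model and image filenames here are placeholders, so point them at your own ONNX file and input:

```Rust
use usls::{models::YOLO, Annotator, DataLoader, Options};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // build the model (placeholder path — use your own ONNX file)
    let options = Options::default()
        .with_model("yolov8m.onnx")
        .with_confs(&[0.4, 0.15]); // person: 0.4, others: 0.15
    let mut model = YOLO::new(&options)?;

    // load an image, run inference, save the annotated result
    let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
    let y = model.run(&x)?;
    Annotator::default().with_saveout("YOLOv8").annotate(&x, &y);
    Ok(())
}
```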
## Script: convert ONNX model from `float32` to `float16`
```python
import onnx
from pathlib import Path
from onnxconverter_common import float16
model_f32 = "onnx_model.onnx"
model_f16 = float16.convert_float_to_float16(onnx.load(model_f32))
saveout = Path(model_f32).with_name(Path(model_f32).stem + "-f16.onnx")
onnx.save(model_f16, saveout)
```

BIN
assets/car.jpg Normal file (66 KiB)

BIN
assets/nini.png Normal file (408 KiB)

8
convert2f16.py Normal file

@ -0,0 +1,8 @@
import onnx
from pathlib import Path
from onnxconverter_common import float16
model_f32 = "onnx_model.onnx"
model_f16 = float16.convert_float_to_float16(onnx.load(model_f32))
saveout = Path(model_f32).with_name(Path(model_f32).stem + "-f16.onnx")
onnx.save(model_f16, saveout)
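The script expects the source model at `onnx_model.onnx` (edit `model_f32` to taste) and writes `onnx_model-f16.onnx` next to it; it needs the `onnx` and `onnxconverter-common` packages (`pip install onnx onnxconverter-common`).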

examples/blip/README.md

@ -1,41 +1,15 @@
This demo shows how to use [BLIP](https://arxiv.org/abs/2201.12086) to do conditional or unconditional image captioning.
## Quick Start
```shell
cargo run -r --example blip
```
## Or you can do it manually
## BLIP ONNX Model
### 1. Download BLIP ONNX Model
[blip-visual-base](https://github.com/jamjamjon/assets/releases/download/v0.0.1/blip-visual-base.onnx)
[blip-textual-base](https://github.com/jamjamjon/assets/releases/download/v0.0.1/blip-textual-base.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
// visual
let options_visual = Options::default()
    .with_model("VISUAL_MODEL") // <= modify this
    .with_profile(false);

// textual
let options_textual = Options::default()
    .with_model("TEXTUAL_MODEL") // <= modify this
    .with_profile(false);
```
### 3. Then, run
```bash
cargo run -r --example blip
```
- [blip-visual-base](https://github.com/jamjamjon/assets/releases/download/v0.0.1/blip-visual-base.onnx)
- [blip-textual-base](https://github.com/jamjamjon/assets/releases/download/v0.0.1/blip-textual-base.onnx)
## Results

examples/clip/README.md

@ -6,37 +6,10 @@ This demo showcases how to use [CLIP](https://github.com/openai/CLIP) to compute
cargo run -r --example clip
```
## Or you can do it manually
### 1. Download CLIP ONNX Model
[clip-b32-visual](https://github.com/jamjamjon/assets/releases/download/v0.0.1/clip-b32-visual.onnx)
[clip-b32-textual](https://github.com/jamjamjon/assets/releases/download/v0.0.1/clip-b32-textual.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
// visual
let options_visual = Options::default()
    .with_model("VISUAL_MODEL") // <= modify this
    .with_i00((1, 1, 4).into())
    .with_profile(false);

// textual
let options_textual = Options::default()
    .with_model("TEXTUAL_MODEL") // <= modify this
    .with_i00((1, 1, 4).into())
    .with_profile(false);
```
### 3. Then, run
```bash
cargo run -r --example clip
```
## CLIP ONNX Model
- [clip-b32-visual](https://github.com/jamjamjon/assets/releases/download/v0.0.1/clip-b32-visual.onnx)
- [clip-b32-textual](https://github.com/jamjamjon/assets/releases/download/v0.0.1/clip-b32-textual.onnx)
## Results
@ -50,9 +23,4 @@ cargo run -r --example clip
(86.59852%) ./examples/clip/images/doll.jpg => There is a doll with red hair and a clock on a table
[0.07032883, 0.00053773675, 0.0006372929, 0.06066096, 0.0007378078, 0.8659852, 0.0011121632]
```
## TODO
* [ ] TensorRT support for textual model

examples/db/README.md

@ -4,25 +4,10 @@
cargo run -r --example db
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
[ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx)
[ppocr-v4-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-db-dyn.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_PATH") // <= modify this
```
### 3. Run
```bash
cargo run -r --example db
```
- [ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx)
- [ppocr-v4-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-db-dyn.onnx)
### Speed test

examples/db/main.rs

@ -22,9 +22,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    // annotate
    let annotator = Annotator::default()
        .with_polygon_color([255u8, 0u8, 0u8])
        .without_name(true)
        .without_polygons(false)
        .with_mask_alpha(0)
        .without_bboxes(false)
        .with_saveout("DB-Text-Detection");
    annotator.annotate(&x, &y);
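`with_mask_alpha(0)` makes the polygon fill fully transparent, so only the red contours of the detected text regions are drawn.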

examples/face-parsing/README.md Normal file

@ -0,0 +1,93 @@
Using a `YOLOv8-seg` model trained on `CelebAMask-HQ` for face parsing.
## Quick Start
```shell
cargo run -r --example face-parsing
```
## Pretrained Model
- [face-parsing-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/face-parsing-dyn.onnx)
- [face-parsing-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/face-parsing-dyn-f16.onnx)
## Datasets
- [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing)
## YOLO Labels
- [Download Processed YOLO labels](https://github.com/jamjamjon/assets/releases/download/v0.0.1/CelebAMask-HQ-YOLO-Labels.zip)
- Or you can generate them with the Python script below
```Python
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm

# CelebAMask-HQ part names -> class ids
mapping = {
    'background': 0,
    'skin': 1,
    'nose': 2,
    'eye_g': 3,
    'l_eye': 4,
    'r_eye': 5,
    'l_brow': 6,
    'r_brow': 7,
    'l_ear': 8,
    'r_ear': 9,
    'mouth': 10,
    'u_lip': 11,
    'l_lip': 12,
    'hair': 13,
    'hat': 14,
    'ear_r': 15,
    'neck_l': 16,
    'neck': 17,
    'cloth': 18
}

def main():
    # start from a clean labels/ directory
    saveout_dir = Path("labels")
    if not saveout_dir.exists():
        saveout_dir.mkdir()
    else:
        import shutil
        shutil.rmtree(saveout_dir)
        saveout_dir.mkdir()

    image_list = [x for x in Path("CelebAMask-HQ-mask-anno/").rglob("*.png")]
    for image_path in tqdm(image_list, total=len(image_list)):
        image_gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        # mask files are named <image-id>_<part-name>.png
        stem = image_path.stem
        name, cls_ = stem.split("_", 1)
        # OpenCV 4.x: findContours returns (contours, hierarchy)
        segments = cv2.findContours(image_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        # one label file per image; append, since each part is a separate mask file
        saveout = saveout_dir / f"{int(name)}.txt"
        with open(saveout, 'a+') as f:
            for segment in segments:
                line = f"{mapping[cls_]}"
                segment = segment / 512  # annotation masks are 512x512, normalize to [0, 1]
                for seg in segment:
                    xn, yn = seg[0]
                    line += f" {xn} {yn}"
                f.write(line + "\n")

if __name__ == "__main__":
    main()
```
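With the images and these labels arranged in the usual YOLO dataset layout, a segmentation model can then be trained with ultralytics, e.g. `yolo segment train data=celebamask.yaml model=yolov8n-seg.pt`, where `celebamask.yaml` is a dataset config you would write yourself (the name is just an example).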
## Results
![](./demo.png)

BIN
examples/face-parsing/demo.png Normal file (448 KiB)

examples/face-parsing/main.rs Normal file

@ -0,0 +1,33 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // build model
    let options = Options::default()
        .with_model("../models/face-parsing-dyn.onnx")
        .with_i00((1, 1, 4).into())
        .with_i02((416, 640, 800).into())
        .with_i03((416, 640, 800).into())
        // .with_trt(0)
        // .with_fp16(true)
        // .with_dry_run(10)
        .with_confs(&[0.5]);
    let mut model = YOLO::new(&options)?;

    // load image
    let x = vec![DataLoader::try_read("./assets/nini.png")?];

    // run
    let y = model.run(&x)?;

    // annotate
    let annotator = Annotator::default()
        .without_conf(true)
        .without_name(true)
        .without_polygons(false)
        .without_bboxes(true)
        .with_masks_name(false)
        .with_saveout("Face-Parsing");
    annotator.annotate(&x, &y);

    Ok(())
}
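Annotated images land under `runs/Face-Parsing/` as timestamped PNGs (see `Annotator::save`).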

examples/fastsam/README.md

@ -4,10 +4,7 @@
cargo run -r --example fastsam
```
## Or you can do it manually
### 1. Download or export ONNX Model
## Download or export ONNX Model
- **Export**
@ -20,20 +17,6 @@ cargo run -r --example fastsam
[FastSAM-s-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/FastSAM-s-dyn-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("../models/FastSAM-s-dyn-f16.onnx") // <= modify this
    .with_profile(false);
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```bash
cargo run -r --example fastsam
```
## Results

examples/rtdetr/README.md

@ -4,9 +4,7 @@
cargo run -r --example rtdetr
```
## Or you can do it manually
### 1. Download or export ONNX Model
## Download or export ONNX Model
- Export
@ -18,19 +16,6 @@ cargo run -r --example rtdetr
[rtdetr-l-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtdetr-l-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_MODEL") // <= modify this
```
### 3. Then, run
```bash
cargo run -r --example rtdetr
```
## Results
![](./demo.jpg)

examples/rtmo/README.md

@ -4,31 +4,14 @@
cargo run -r --example rtmo
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
[rtmo-s-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn.onnx)
[rtmo-m-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn.onnx)
[rtmo-l-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn.onnx)
[rtmo-s-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn-f16.onnx)
[rtmo-m-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn-f16.onnx)
[rtmo-l-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_MODEL") // <= modify this
```
### 3. Then, run
```bash
cargo run -r --example rtmo
```
- [rtmo-s-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn.onnx)
- [rtmo-m-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn.onnx)
- [rtmo-l-dyn model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn.onnx)
- [rtmo-s-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-s-dyn-f16.onnx)
- [rtmo-m-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-m-dyn-f16.onnx)
- [rtmo-l-dyn-f16 model](https://github.com/jamjamjon/assets/releases/download/v0.0.1/rtmo-l-dyn-f16.onnx)
## Results

examples/svtr/README.md

@ -4,26 +4,12 @@
cargo run -r --example svtr
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
- [ppocr-v4-server-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-server-svtr-ch-dyn.onnx)
- [ppocr-v4-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-svtr-ch-dyn.onnx)
- [ppocr-v3-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-svtr-ch-dyn.onnx)
[ppocr-v4-server-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-server-svtr-ch-dyn.onnx)
[ppocr-v4-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-svtr-ch-dyn.onnx)
[ppocr-v3-svtr-ch-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-svtr-ch-dyn.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_PATH") // <= modify this
```
### 3. Run
```bash
cargo run -r --example svtr
```
### Speed test

examples/yolo-world/README.md

@ -4,22 +4,20 @@
cargo run -r --example yolo-world
```
## Or you can do it manually
## Download or Export ONNX Model
### 1. Download or Export ONNX Model
- **Download**
- Download
[yolov8s-world-v2-shoes](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8s-world-v2-shoes.onnx)
- **Or generate your own `yolo-world` model and then Export**
[yolov8s-world-v2-shoes](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8s-world-v2-shoes.onnx)
- Or generate your own `yolo-world` model and then Export
- Installation
- **Installation**
```shell
pip install -U ultralytics
```
- Generate
- **Generate**
```python
from ultralytics import YOLO
@ -34,25 +32,12 @@ cargo run -r --example yolo-world
model.save("custom_yolov8m-world-v2.pt")
```
- Export
- **Export**
```shell
yolo export model=custom_yolov8m-world-v2.pt format=onnx simplify dynamic
```
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_PATH"); // <= modify this
```
### 3. Then, run
```
cargo run -r --example yolo-world
```
## Results
![](./demo.jpg)

14
examples/yolop/README.md Normal file

@ -0,0 +1,14 @@
## Quick Start
```shell
cargo run -r --example yolop
```
## Pretrained Model
- [yolopv2-dyn-480x800](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolopv2-dyn-480x800.onnx)
- [yolopv2-dyn-736x1280](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolopv2-dyn-736x1280.onnx)
## Results
![](./demo.png)

BIN
examples/yolop/demo.png Normal file (922 KiB)

26
examples/yolop/main.rs Normal file

@ -0,0 +1,26 @@
use usls::{models::YOLOPv2, Annotator, DataLoader, Options};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // build model
    let options = Options::default()
        .with_model("../models/yolopv2-dyn-480x800.onnx")
        .with_i00((1, 1, 8).into())
        // .with_trt(0)
        // .with_fp16(true)
        .with_confs(&[0.3]);
    let mut model = YOLOPv2::new(&options)?;

    // load image
    let x = vec![DataLoader::try_read("./assets/car.jpg")?];

    // run
    let y = model.run(&x)?;

    // annotate
    let annotator = Annotator::default()
        .with_masks_name(false)
        .with_saveout("YOLOPv2");
    annotator.annotate(&x, &y);

    Ok(())
}
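The returned `Ys` carry both the detected vehicles (bounding boxes) and the drivable-area and lane-line masks, so the single `annotate` call above draws all three.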

examples/yolov8-face/README.md

@ -4,26 +4,9 @@
cargo run -r --example yolov8-face
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
[yolov8-face-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-face-dyn-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("ONNX_PATH") // <= modify this
    .with_profile(false);
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```bash
cargo run -r --example yolov8-face
```
- [yolov8-face-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-face-dyn-f16.onnx)
## Results

examples/yolov8-falldown/README.md

@ -4,26 +4,10 @@
cargo run -r --example yolov8-falldown
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
- [yolov8-falldown-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-falldown-f16.onnx)
[yolov8-falldown-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-falldown-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("ONNX_PATH") // <= modify this
    .with_profile(false);
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```bash
cargo run -r --example yolov8-falldown
```
## Results

examples/yolov8-head/README.md

@ -4,26 +4,10 @@
cargo run -r --example yolov8-head
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
- [yolov8-head-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-head-f16.onnx)
[yolov8-head-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-head-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("ONNX_PATH") // <= modify this
    .with_profile(false);
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```bash
cargo run -r --example yolov8-head
```
## Results

examples/yolov8-trash/README.md

@ -6,26 +6,10 @@ Model for detecting plastic bags.
cargo run -r --example yolov8-trash
```
## Or you can do it manually
## ONNX Model
### 1. Download ONNX Model
- [yolov8-plastic-bag-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-plastic-bag-f16.onnx)
[yolov8-plastic-bag-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolov8-plastic-bag-f16.onnx)
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("ONNX_PATH") // <= modify this
    .with_profile(false);
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```bash
cargo run -r --example yolov8-trash
```
## Results

examples/yolov8/README.md

@ -1,20 +1,10 @@
## Features
- Supports `Classification`, `Segmentation`, `Detection`, and `Pose (Keypoints) Detection` tasks.
- Supports `FP16` & `FP32` ONNX models.
- Supports the `CoreML`, `CUDA`, and `TensorRT` execution providers to accelerate computation.
- Supports dynamic input shapes (`batch`, `width`, `height`).
- Supports dynamic confidence (`DynConf`) for each class in the detection task.
## Quick Start
```shell
cargo run -r --example yolov8
```
## Or you can do it manually
### 1. Export `YOLOv8` ONNX Models
## Export `YOLOv8` ONNX Models
```bash
pip install -U ultralytics
@ -32,26 +22,11 @@ yolo export model=yolov8m-pose.pt format=onnx simplify
yolo export model=yolov8m-seg.pt format=onnx simplify
```
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
    .with_model("ONNX_PATH") // <= modify this
    .with_confs(&[0.4, 0.15]); // person: 0.4, others: 0.15
let mut model = YOLO::new(&options)?;
```
### 3. Then, run
```
cargo run -r --example yolov8
```
## Result
| Task | Annotated image |
| :-------------------: | --------------------- |
| Instance Segmentation | ![img](./demo-seg.jpg) |
| Instance Segmentation | ![img](./demo-seg.png) |
| Classification | ![img](./demo-cls.jpg) |
| Detection | ![img](./demo-det.jpg) |
| Pose | ![img](./demo-pose.jpg) |
| Detection | ![img](./demo-det.png) |
| Pose | ![img](./demo-pose.png) |

BIN
examples/yolov8 demo images: three binaries replaced — `.jpg` files (234 KiB, 239 KiB, 237 KiB) swapped for `.png` files (1.8 MiB, 1.8 MiB, 1.6 MiB), matching the README's new demo-seg/demo-det/demo-pose links.

examples/yolov8/main.rs

@ -1,16 +1,18 @@
use usls::{models::YOLO, Annotator, DataLoader, Options, COCO_SKELETON_17};
use usls::{
    models::YOLO, Annotator, DataLoader, Options, COCO_KEYPOINT_NAMES_17, COCO_SKELETON_17,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // build model
    let options = Options::default()
        // .with_model("../models/yolov8m-seg-dyn-f16.onnx")
        .with_model("../models/yolov8m-cls.onnx")
        .with_model("../models/yolov8m-dyn-f16.onnx")
        // .with_trt(0) // cuda by default
        // .with_fp16(true)
        .with_i00((1, 1, 4).into())
        .with_i02((224, 224, 800).into())
        .with_i03((224, 224, 800).into())
        .with_i02((224, 640, 800).into())
        .with_i03((224, 640, 800).into())
        .with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
        .with_names2(&COCO_KEYPOINT_NAMES_17)
        .with_profile(false)
        .with_dry_run(3);
    let mut model = YOLO::new(&options)?;
@ -25,6 +27,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
        .with_skeletons(&COCO_SKELETON_17)
        .without_conf(false)
        .without_name(false)
        .with_keypoints_name(false)
        .with_keypoints_conf(false)
        .with_masks_name(false)
        .without_masks(false)
        .without_polygons(false)
        .without_bboxes(false)
@ -33,7 +38,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    // run & annotate
    for (xs, _paths) in dl {
        let ys = model.run(&xs)?;
        println!("{:?}", ys);
        annotator.annotate(&xs, &ys);
    }

examples/yolov9/README.md

@ -4,9 +4,7 @@
cargo run -r --example yolov9
```
## Or you can do it manually
### 1. Download or Export ONNX Model
## Download or Export ONNX Model
- **Download**
@ -26,19 +24,6 @@ cargo run -r --example yolov9
python export.py --weights yolov9-c.pt --include onnx --simplify --dynamic
```
### 2. Specify the ONNX model path in `main.rs`
```Rust
let options = Options::default()
.with_model("ONNX_PATH") // <= modify this
```
### 3. Run
```
cargo run -r --example yolov9
```
## Results
![](./demo.jpg)

src/annotator.rs

@ -1,339 +0,0 @@
use crate::{
auto_load, string_now, Bbox, Embedding, Keypoint, Polygon, Ys, CHECK_MARK, CROSS_MARK,
};
use ab_glyph::{FontVec, PxScale};
use anyhow::Result;
use image::{DynamicImage, GrayImage, ImageBuffer, Rgb, RgbImage};
#[derive(Debug)]
pub struct Annotator {
font: ab_glyph::FontVec,
scale_: f32, // Cope with ab_glyph & imageproc=0.24.0
skeletons: Option<Vec<(usize, usize)>>,
polygon_color: Rgb<u8>,
saveout: Option<String>,
without_conf: bool,
without_name: bool,
without_bboxes: bool,
without_masks: bool,
without_polygons: bool,
without_keypoints: bool,
}
impl Default for Annotator {
fn default() -> Self {
Self {
font: Self::load_font(None).unwrap(),
scale_: 6.666667,
skeletons: None,
polygon_color: Rgb([255, 255, 255]),
saveout: None,
without_conf: false,
without_name: false,
without_bboxes: false,
without_masks: false,
without_polygons: false,
without_keypoints: false,
}
}
}
impl Annotator {
pub fn without_conf(mut self, x: bool) -> Self {
self.without_conf = x;
self
}
pub fn without_name(mut self, x: bool) -> Self {
self.without_name = x;
self
}
pub fn without_bboxes(mut self, x: bool) -> Self {
self.without_bboxes = x;
self
}
pub fn without_masks(mut self, x: bool) -> Self {
self.without_masks = x;
self
}
pub fn without_polygons(mut self, x: bool) -> Self {
self.without_polygons = x;
self
}
pub fn without_keypoints(mut self, x: bool) -> Self {
self.without_keypoints = x;
self
}
pub fn with_saveout(mut self, saveout: &str) -> Self {
self.saveout = Some(saveout.to_string());
self
}
pub fn with_polygon_color(mut self, rgb: [u8; 3]) -> Self {
self.polygon_color = Rgb(rgb);
self
}
pub fn with_skeletons(mut self, skeletons: &[(usize, usize)]) -> Self {
self.skeletons = Some(skeletons.to_vec());
self
}
pub fn with_font(mut self, path: &str) -> Self {
self.font = Self::load_font(Some(path)).unwrap();
self
}
pub fn save(&self, image: &RgbImage, saveout: &str) {
let mut saveout = std::path::PathBuf::from("runs").join(saveout);
if !saveout.exists() {
std::fs::create_dir_all(&saveout).unwrap();
}
saveout.push(string_now("-"));
let saveout = format!("{}.jpg", saveout.to_str().unwrap());
match image.save(&saveout) {
Err(err) => println!("{} Saving failed: {:?}", CROSS_MARK, err),
Ok(_) => println!("{} Annotated image saved at: {}", CHECK_MARK, saveout),
}
}
pub fn annotate(&self, imgs: &[DynamicImage], ys: &[Ys]) {
for (img, y) in imgs.iter().zip(ys.iter()) {
let mut img_rgb = img.to_rgb8();
// masks
if !self.without_masks {
if let Some(masks) = &y.masks {
self.plot_masks(&mut img_rgb, masks)
}
}
// polygons
if !self.without_polygons {
if let Some(polygons) = &y.polygons {
self.plot_polygons(&mut img_rgb, polygons)
}
}
// bboxes
if !self.without_bboxes {
if let Some(bboxes) = &y.bboxes {
self.plot_bboxes(&mut img_rgb, bboxes)
}
}
// keypoints
if !self.without_keypoints {
if let Some(keypoints) = &y.keypoints {
self.plot_keypoints(&mut img_rgb, keypoints)
}
}
// probs
if let Some(probs) = &y.probs {
self.plot_probs(&mut img_rgb, probs)
}
if let Some(saveout) = &self.saveout {
self.save(&img_rgb, saveout);
}
}
}
pub fn plot_masks(&self, img: &mut RgbImage, masks: &[Vec<u8>]) {
for mask in masks.iter() {
let mask_nd: GrayImage =
ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec())
.expect("can not crate image from ndarray");
for _x in 0..img.width() {
for _y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
if mask_p.0[0] > 0 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(img, _x, _y);
img_p.0[0] /= 2;
img_p.0[1] = 255 - (255 - img_p.0[1]) / 3;
img_p.0[2] /= 2;
imageproc::drawing::Canvas::draw_pixel(img, _x, _y, img_p)
}
}
}
}
}
pub fn plot_bboxes(&self, img: &mut RgbImage, bboxes: &[Bbox]) {
for bbox in bboxes.iter() {
imageproc::drawing::draw_hollow_rect_mut(
img,
imageproc::rect::Rect::at(bbox.xmin().round() as i32, bbox.ymin().round() as i32)
.of_size(bbox.width().round() as u32, bbox.height().round() as u32),
image::Rgb(self.get_color(bbox.id()).into()),
);
let mut legend = String::new();
if !self.without_name {
legend.push_str(&bbox.name().unwrap_or(&bbox.id().to_string()).to_string());
}
if !self.without_conf {
if !self.without_name {
legend.push_str(&format!(": {:.4}", bbox.confidence()));
} else {
legend.push_str(&format!("{:.4}", bbox.confidence()));
}
}
let scale_dy = img.width().max(img.height()) as f32 / 40.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) = imageproc::drawing::text_size(scale, &self.font, &legend); // u32
let text_h = text_h + text_h / 3;
let top = if bbox.ymin() > text_h as f32 {
(bbox.ymin().round() as u32 - text_h) as i32
} else {
(text_h - bbox.ymin().round() as u32) as i32
};
// text
if !legend.is_empty() {
imageproc::drawing::draw_filled_rect_mut(
img,
imageproc::rect::Rect::at(bbox.xmin() as i32, top).of_size(text_w, text_h),
image::Rgb(self.get_color(bbox.id()).into()),
);
imageproc::drawing::draw_text_mut(
img,
image::Rgb([0, 0, 0]),
bbox.xmin() as i32,
top - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
}
pub fn plot_polygons(&self, img: &mut RgbImage, polygons: &[Polygon]) {
for polygon in polygons.iter() {
// option: draw polygon
let polygon = polygon
.points
.iter()
.map(|p| imageproc::point::Point::new(p.x, p.y))
.collect::<Vec<_>>();
imageproc::drawing::draw_hollow_polygon_mut(img, &polygon, self.polygon_color);
// option: draw circle
// polygon.points.iter().for_each(|point| {
// imageproc::drawing::draw_filled_circle_mut(
// img,
// (point.x as i32, point.y as i32),
// 1,
// // image::Rgb([255, 255, 255]),
// self.polygon_color,
// );
// });
}
}
pub fn plot_probs(&self, img: &mut RgbImage, probs: &Embedding) {
let topk = 5usize;
let (x, mut y) = (img.width() as i32 / 20, img.height() as i32 / 20);
for k in probs.topk(topk).iter() {
let legend = format!("{}: {:.4}", k.2.as_ref().unwrap_or(&k.0.to_string()), k.1);
let scale_dy = img.width().max(img.height()) as f32 / 30.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) = imageproc::drawing::text_size(scale, &self.font, &legend);
let text_h = text_h + text_h / 3;
y += text_h as i32;
imageproc::drawing::draw_filled_rect_mut(
img,
imageproc::rect::Rect::at(x, y).of_size(text_w, text_h),
image::Rgb(self.get_color(k.0).into()),
);
imageproc::drawing::draw_text_mut(
img,
image::Rgb((0, 0, 0).into()),
x,
y - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
pub fn plot_keypoints(&self, img: &mut RgbImage, keypoints: &[Vec<Keypoint>]) {
let radius = 3;
for kpts in keypoints.iter() {
for (i, kpt) in kpts.iter().enumerate() {
if kpt.confidence() == 0.0 {
continue;
}
// draw point
imageproc::drawing::draw_filled_circle_mut(
img,
(kpt.x() as i32, kpt.y() as i32),
radius,
image::Rgb(self.get_color(i + 10).into()),
);
}
// draw skeleton
if let Some(skeletons) = &self.skeletons {
for &(i, ii) in skeletons.iter() {
let kpt1 = &kpts[i];
let kpt2 = &kpts[ii];
if kpt1.confidence() == 0.0 || kpt2.confidence() == 0.0 {
continue;
}
imageproc::drawing::draw_line_segment_mut(
img,
(kpt1.x(), kpt1.y()),
(kpt2.x(), kpt2.y()),
image::Rgb([255, 51, 255]),
);
}
}
}
}
fn load_font(path: Option<&str>) -> Result<FontVec> {
let path_font = match path {
None => auto_load("Arial.ttf")?,
Some(p) => p.into(),
};
let buffer = std::fs::read(path_font)?;
Ok(FontVec::try_from_vec(buffer.to_owned()).unwrap())
}
pub fn get_color(&self, n: usize) -> (u8, u8, u8) {
Self::color_palette()[n % Self::color_palette().len()]
}
fn color_palette() -> Vec<(u8, u8, u8)> {
vec![
(0, 255, 0),
(255, 128, 0),
(0, 0, 255),
(255, 153, 51),
(255, 0, 0),
(255, 51, 255),
(102, 178, 255),
(51, 153, 255),
(255, 51, 51),
(153, 255, 153),
(102, 255, 102),
(153, 204, 255),
(255, 153, 153),
(255, 178, 102),
(230, 230, 0),
(255, 153, 255),
(255, 102, 255),
(255, 102, 102),
(51, 255, 51),
(255, 255, 255),
]
}
}

429
src/core/annotator.rs Normal file

@ -0,0 +1,429 @@
use crate::{auto_load, string_now, Bbox, Embedding, Keypoint, Mask, Ys, CHECK_MARK, CROSS_MARK};
use ab_glyph::{FontVec, PxScale};
use anyhow::Result;
use image::{DynamicImage, Rgba, RgbaImage};
#[derive(Debug)]
pub struct Annotator {
font: ab_glyph::FontVec,
scale_: f32, // Cope with ab_glyph & imageproc=0.24.0
skeletons: Option<Vec<(usize, usize)>>,
saveout: Option<String>,
mask_alpha: u8,
polygon_color: Rgba<u8>,
without_conf: bool,
without_name: bool,
with_keypoints_conf: bool,
with_keypoints_name: bool,
with_masks_name: bool,
without_bboxes: bool,
without_masks: bool,
without_polygons: bool,
without_keypoints: bool,
keypoint_radius: usize,
}
impl Default for Annotator {
fn default() -> Self {
Self {
font: Self::load_font(None).unwrap(),
scale_: 6.666667,
mask_alpha: 179,
polygon_color: Rgba([255, 255, 255, 255]),
skeletons: None,
saveout: None,
without_conf: false,
without_name: false,
with_keypoints_conf: false,
with_keypoints_name: false,
with_masks_name: false,
without_bboxes: false,
without_masks: false,
without_polygons: false,
without_keypoints: false,
keypoint_radius: 3,
}
}
}
impl Annotator {
pub fn with_keypoint_radius(mut self, x: usize) -> Self {
self.keypoint_radius = x;
self
}
pub fn without_conf(mut self, x: bool) -> Self {
self.without_conf = x;
self
}
pub fn without_name(mut self, x: bool) -> Self {
self.without_name = x;
self
}
pub fn with_keypoints_conf(mut self, x: bool) -> Self {
self.with_keypoints_conf = x;
self
}
pub fn with_keypoints_name(mut self, x: bool) -> Self {
self.with_keypoints_name = x;
self
}
pub fn with_masks_name(mut self, x: bool) -> Self {
self.with_masks_name = x;
self
}
pub fn without_bboxes(mut self, x: bool) -> Self {
self.without_bboxes = x;
self
}
pub fn without_masks(mut self, x: bool) -> Self {
self.without_masks = x;
self
}
pub fn without_polygons(mut self, x: bool) -> Self {
self.without_polygons = x;
self
}
pub fn with_mask_alpha(mut self, x: u8) -> Self {
self.mask_alpha = x;
self
}
pub fn with_polygon_color(mut self, rgba: [u8; 4]) -> Self {
self.polygon_color = Rgba(rgba);
self
}
pub fn without_keypoints(mut self, x: bool) -> Self {
self.without_keypoints = x;
self
}
pub fn with_saveout(mut self, saveout: &str) -> Self {
self.saveout = Some(saveout.to_string());
self
}
pub fn with_skeletons(mut self, skeletons: &[(usize, usize)]) -> Self {
self.skeletons = Some(skeletons.to_vec());
self
}
pub fn with_font(mut self, path: &str) -> Self {
self.font = Self::load_font(Some(path)).unwrap();
self
}
pub fn save(&self, image: &RgbaImage, saveout: &str) {
let mut saveout = std::path::PathBuf::from("runs").join(saveout);
if !saveout.exists() {
std::fs::create_dir_all(&saveout).unwrap();
}
saveout.push(string_now("-"));
let saveout = format!("{}.png", saveout.to_str().unwrap());
match image.save(&saveout) {
Err(err) => println!("{} Saving failed: {:?}", CROSS_MARK, err),
Ok(_) => println!("{} Annotated image saved to: {}", CHECK_MARK, saveout),
}
}
pub fn annotate(&self, imgs: &[DynamicImage], ys: &[Ys]) {
for (img, y) in imgs.iter().zip(ys.iter()) {
let mut img_rgb = img.to_rgba8();
// masks
if !self.without_polygons {
if let Some(xs) = &y.masks {
self.plot_polygons(&mut img_rgb, xs)
}
}
// bboxes
if !self.without_bboxes {
if let Some(xs) = &y.bboxes {
self.plot_bboxes(&mut img_rgb, xs)
}
}
// keypoints
if !self.without_keypoints {
if let Some(xs) = &y.keypoints {
self.plot_keypoints(&mut img_rgb, xs)
}
}
// probs
if let Some(xs) = &y.probs {
self.plot_probs(&mut img_rgb, xs)
}
if let Some(saveout) = &self.saveout {
self.save(&img_rgb, saveout);
}
}
}
pub fn plot_bboxes(&self, img: &mut RgbaImage, bboxes: &[Bbox]) {
for bbox in bboxes.iter() {
imageproc::drawing::draw_hollow_rect_mut(
img,
imageproc::rect::Rect::at(bbox.xmin().round() as i32, bbox.ymin().round() as i32)
.of_size(bbox.width().round() as u32, bbox.height().round() as u32),
image::Rgba(self.get_color(bbox.id()).into()),
);
let mut legend = String::new();
if !self.without_name {
legend.push_str(&bbox.name().unwrap_or(&bbox.id().to_string()).to_string());
}
if !self.without_conf {
if !self.without_name {
legend.push_str(&format!(": {:.4}", bbox.confidence()));
} else {
legend.push_str(&format!("{:.4}", bbox.confidence()));
}
}
if !legend.is_empty() {
let scale_dy = img.width().max(img.height()) as f32 / 40.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) = imageproc::drawing::text_size(scale, &self.font, &legend); // u32
let text_h = text_h + text_h / 3;
let top = if bbox.ymin() > text_h as f32 {
(bbox.ymin().round() as u32 - text_h) as i32
} else {
(text_h - bbox.ymin().round() as u32) as i32
};
let mut left = bbox.xmin() as i32;
if left + text_w as i32 > img.width() as i32 {
left = img.width() as i32 - text_w as i32;
}
imageproc::drawing::draw_filled_rect_mut(
img,
imageproc::rect::Rect::at(left, top).of_size(text_w, text_h),
image::Rgba(self.get_color(bbox.id()).into()),
);
imageproc::drawing::draw_text_mut(
img,
image::Rgba([0, 0, 0, 255]),
left,
top - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
}
pub fn plot_polygons(&self, img: &mut RgbaImage, masks: &[Mask]) {
let mut canvas = img.clone();
for mask in masks.iter() {
// mask
let mut polygon_i32 = mask
.polygon
.points
.iter()
.map(|p| imageproc::point::Point::new(p.x as i32, p.y as i32))
.collect::<Vec<_>>();
if polygon_i32.first() == polygon_i32.last() {
polygon_i32.pop();
}
let mut mask_color = self.get_color(mask.id);
mask_color.3 = self.mask_alpha;
imageproc::drawing::draw_polygon_mut(
&mut convas,
&polygon_i32,
Rgba(mask_color.into()),
);
// contour
let polygon_f32 = mask
.polygon
.points
.iter()
.map(|p| imageproc::point::Point::new(p.x, p.y))
.collect::<Vec<_>>();
imageproc::drawing::draw_hollow_polygon_mut(img, &polygon_f32, self.polygon_color);
// text
let mut legend = String::new();
if self.with_masks_name {
legend.push_str(&mask.name().unwrap_or(&mask.id().to_string()).to_string());
}
if !legend.is_empty() {
let scale_dy = img.width().max(img.height()) as f32 / 60.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) = imageproc::drawing::text_size(scale, &self.font, &legend); // u32
let text_h = text_h + text_h / 3;
let bbox = mask.polygon.find_min_rect();
let top = (bbox.cy().round() as u32 - text_h) as i32;
let mut left = (bbox.cx() as i32 - text_w as i32 / 2).max(0);
if left + text_w as i32 > img.width() as i32 {
left = img.width() as i32 - text_w as i32;
}
imageproc::drawing::draw_filled_rect_mut(
&mut convas,
imageproc::rect::Rect::at(left, top).of_size(text_w, text_h),
image::Rgba(self.get_color(mask.id()).into()),
);
imageproc::drawing::draw_text_mut(
&mut convas,
image::Rgba([0, 0, 0, 255]),
left,
top - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
image::imageops::overlay(img, &canvas, 0, 0);
}
pub fn plot_probs(&self, img: &mut RgbaImage, probs: &Embedding) {
let topk = 5usize;
let (x, mut y) = (img.width() as i32 / 20, img.height() as i32 / 20);
for k in probs.topk(topk).iter() {
let legend = format!("{}: {:.4}", k.2.as_ref().unwrap_or(&k.0.to_string()), k.1);
let scale_dy = img.width().max(img.height()) as f32 / 30.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) = imageproc::drawing::text_size(scale, &self.font, &legend);
let text_h = text_h + text_h / 3;
y += text_h as i32;
imageproc::drawing::draw_filled_rect_mut(
img,
imageproc::rect::Rect::at(x, y).of_size(text_w, text_h),
image::Rgba(self.get_color(k.0).into()),
);
imageproc::drawing::draw_text_mut(
img,
image::Rgba([0, 0, 0, 255]),
x,
y - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
pub fn plot_keypoints(&self, img: &mut RgbaImage, keypoints: &[Vec<Keypoint>]) {
for kpts in keypoints.iter() {
for (i, kpt) in kpts.iter().enumerate() {
if kpt.confidence() == 0.0 {
continue;
}
imageproc::drawing::draw_filled_circle_mut(
img,
(kpt.x() as i32, kpt.y() as i32),
self.keypoint_radius as i32,
image::Rgba(self.get_color(i + 10).into()),
);
let mut legend = String::new();
if self.with_keypoints_name {
legend.push_str(&kpt.name().unwrap_or(&kpt.id().to_string()).to_string());
}
if self.with_keypoints_conf {
if self.with_keypoints_name {
legend.push_str(&format!(": {:.4}", kpt.confidence()));
} else {
legend.push_str(&format!("{:.4}", kpt.confidence()));
}
}
if !legend.is_empty() {
let scale_dy = img.width().max(img.height()) as f32 / 80.0;
let scale = PxScale::from(scale_dy);
let (text_w, text_h) =
imageproc::drawing::text_size(scale, &self.font, &legend); // u32
let text_h = text_h + text_h / 3;
let top = if kpt.y() > text_h as f32 {
(kpt.y().round() as u32 - text_h - self.keypoint_radius as u32) as i32
} else {
(text_h - self.keypoint_radius as u32 - kpt.y().round() as u32) as i32
};
let mut left =
(kpt.x() as i32 - self.keypoint_radius as i32 - text_w as i32 / 2).max(0);
if left + text_w as i32 > img.width() as i32 {
left = img.width() as i32 - text_w as i32;
}
imageproc::drawing::draw_filled_rect_mut(
img,
imageproc::rect::Rect::at(left, top).of_size(text_w, text_h),
image::Rgba(self.get_color(kpt.id() as usize).into()),
);
imageproc::drawing::draw_text_mut(
img,
image::Rgba([0, 0, 0, 255]),
left,
top - (scale_dy / self.scale_).floor() as i32 + 2,
scale,
&self.font,
&legend,
);
}
}
// draw skeleton
if let Some(skeletons) = &self.skeletons {
for &(i, ii) in skeletons.iter() {
let kpt1 = &kpts[i];
let kpt2 = &kpts[ii];
if kpt1.confidence() == 0.0 || kpt2.confidence() == 0.0 {
continue;
}
imageproc::drawing::draw_line_segment_mut(
img,
(kpt1.x(), kpt1.y()),
(kpt2.x(), kpt2.y()),
image::Rgba([255, 51, 255, 255]),
);
}
}
}
}
fn load_font(path: Option<&str>) -> Result<FontVec> {
let path_font = match path {
None => auto_load("Arial.ttf")?,
Some(p) => p.into(),
};
let buffer = std::fs::read(path_font)?;
Ok(FontVec::try_from_vec(buffer.to_owned()).unwrap())
}
pub fn get_color(&self, n: usize) -> (u8, u8, u8, u8) {
Self::color_palette()[n % Self::color_palette().len()]
}
fn color_palette() -> Vec<(u8, u8, u8, u8)> {
vec![
(0, 255, 0, 255),
(255, 128, 0, 255),
(0, 0, 255, 255),
(255, 153, 51, 255),
(255, 0, 0, 255),
(255, 51, 255, 255),
(102, 178, 255, 255),
(51, 153, 255, 255),
(255, 51, 51, 255),
(153, 255, 153, 255),
(102, 255, 102, 255),
(153, 204, 255, 255),
(255, 153, 153, 255),
(255, 178, 102, 255),
(230, 230, 0, 255),
(255, 153, 255, 255),
(255, 102, 255, 255),
(255, 102, 102, 255),
(51, 255, 51, 255),
(255, 255, 255, 255),
]
}
}

src/core/bbox.rs

@ -59,6 +59,7 @@ impl Bbox {
pub fn id(&self) -> usize {
self.id
}
pub fn name(&self) -> Option<&String> {
self.name.as_ref()
}

src/core/engine.rs

@ -268,10 +268,17 @@ impl OrtEngine {
// output
let mut ys_ = Vec::new();
let t_post = std::time::Instant::now();
for ((_, y), dtype) in ys.iter().zip(self.odtypes.iter()) {
for (dtype, name) in self.odtypes.iter().zip(self.onames.iter()) {
let y = &ys[name.as_str()];
let y_ = match &dtype {
TensorElementType::Float32 => y.extract_tensor::<f32>()?.view().to_owned(),
TensorElementType::Float16 => y.extract_tensor::<f16>()?.view().mapv(f16::to_f32),
TensorElementType::Int64 => y
.extract_tensor::<i64>()?
.view()
.to_owned()
.mapv(|x| x as f32),
_ => todo!(),
};
ys_.push(y_);
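Fetching each output tensor by name keeps it aligned with the corresponding entry in `odtypes` regardless of iteration order, and the new `Int64` arm converts integer outputs (e.g. class-id tensors) to `f32` so downstream code sees a single element type.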

63
src/core/keypoint.rs Normal file

@ -0,0 +1,63 @@
use crate::Point;

#[derive(PartialEq, Clone)]
pub struct Keypoint {
    pub point: Point,
    confidence: f32,
    id: isize,
    name: Option<String>,
}

impl Default for Keypoint {
    fn default() -> Self {
        Self {
            id: -1,
            confidence: 0.0,
            point: Point::default(),
            name: None,
        }
    }
}

impl std::fmt::Debug for Keypoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Keypoint")
            .field("x", &self.point.x)
            .field("y", &self.point.y)
            .field("confidence", &self.confidence)
            .field("id", &self.id)
            .field("name", &self.name)
            .finish()
    }
}

impl Keypoint {
    pub fn new(point: Point, confidence: f32, id: isize, name: Option<String>) -> Self {
        Self {
            point,
            confidence,
            id,
            name,
        }
    }

    pub fn x(&self) -> f32 {
        self.point.x
    }

    pub fn y(&self) -> f32 {
        self.point.y
    }

    pub fn confidence(&self) -> f32 {
        self.confidence
    }

    pub fn id(&self) -> isize {
        self.id
    }

    pub fn name(&self) -> Option<&String> {
        self.name.as_ref()
    }
}
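`id` defaults to `-1` so an unset keypoint is distinguishable from class `0`; the annotator also skips drawing any keypoint whose confidence is exactly `0.0`.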

28
src/core/mask.rs Normal file

@ -0,0 +1,28 @@
use crate::Polygon;

#[derive(Default, Clone, PartialEq)]
pub struct Mask {
    pub polygon: Polygon,
    pub id: usize,
    pub name: Option<String>,
}

impl std::fmt::Debug for Mask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Mask")
            .field("polygons(num_points)", &self.polygon.points.len())
            .field("id", &self.id)
            .field("name", &self.name)
            .finish()
    }
}

impl Mask {
    pub fn id(&self) -> usize {
        self.id
    }

    pub fn name(&self) -> Option<&String> {
        self.name.as_ref()
    }
}
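A `Mask` is stored as its boundary `Polygon` plus a class id and optional name — enough for the annotator to fill it, outline it, and place its label at the center of the polygon's minimal bounding rect.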

45
src/core/mod.rs Normal file

@ -0,0 +1,45 @@
mod annotator;
mod bbox;
mod dataloader;
mod device;
mod dynconf;
mod embedding;
mod engine;
mod keypoint;
mod logits_sampler;
mod mask;
mod metric;
mod min_opt_max;
pub mod ops;
mod options;
mod point;
mod polygon;
mod rect;
mod rotated_rect;
mod tokenizer_stream;
mod utils;
mod ys;
pub use annotator::Annotator;
pub use bbox::Bbox;
pub use dataloader::DataLoader;
pub use device::Device;
pub use dynconf::DynConf;
pub use embedding::Embedding;
pub use engine::OrtEngine;
pub use keypoint::Keypoint;
pub use logits_sampler::LogitsSampler;
pub use mask::Mask;
pub use metric::Metric;
pub use min_opt_max::MinOptMax;
pub use options::Options;
pub use point::Point;
pub use polygon::Polygon;
pub use rect::Rect;
pub use rotated_rect::RotatedRect;
pub use tokenizer_stream::TokenizerStream;
pub use utils::{
auto_load, config_dir, download, string_now, COCO_KEYPOINT_NAMES_17, COCO_NAMES_80,
COCO_SKELETON_17,
};
pub use ys::Ys;

src/core/ops.rs

@ -1,5 +1,6 @@
use crate::{Mask, Polygon};
use anyhow::Result;
use image::{DynamicImage, GenericImageView};
use image::{DynamicImage, GenericImageView, GrayImage, ImageBuffer};
use ndarray::{Array, Axis, Ix2, IxDyn};
pub fn standardize(xs: Array<f32, IxDyn>, mean: &[f32], std: &[f32]) -> Array<f32, IxDyn> {
@ -106,3 +107,40 @@ pub fn resize_with_fixed_height(
}
Ok(ys)
}
pub fn build_dyn_image_from_raw(v: Vec<f32>, height: u32, width: u32) -> DynamicImage {
    let v: ImageBuffer<image::Luma<_>, Vec<f32>> =
        ImageBuffer::from_raw(width, height, v).expect("Failed to create image from ndarray");
    image::DynamicImage::from(v)
}

pub fn descale_mask(mask: DynamicImage, w0: f32, h0: f32, w1: f32, h1: f32) -> DynamicImage {
    // 0 -> 1
    let (_, w, h) = scale_wh(w1, h1, w0, h0);
    let mut mask = mask.to_owned();
    let mask = mask.crop(0, 0, w as u32, h as u32);
    mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
}

pub fn get_masks_from_image(
    mask: GrayImage,
    thresh: u8,
    id: usize,
    name: Option<String>,
) -> Vec<Mask> {
    // let mask = mask.into_luma8();
    let contours: Vec<imageproc::contours::Contour<i32>> =
        imageproc::contours::find_contours_with_threshold(&mask, thresh);
    let mut masks: Vec<Mask> = Vec::new();
    contours.iter().for_each(|contour| {
        // contour.border_type == imageproc::contours::BorderType::Outer &&
        if contour.points.len() > 2 {
            masks.push(Mask {
                polygon: Polygon::from_contour(contour),
                id,
                name: name.to_owned(),
            });
        }
    });
    masks
}

src/core/options.rs

@ -46,8 +46,9 @@ pub struct Options {
pub apply_nms: bool,
pub tokenizer: Option<String>,
pub vocab: Option<String>,
pub names: Option<Vec<String>>, // class names
pub anchors_first: bool, // output format: [bs, anchors/na, pos+nc+nm]
pub names: Option<Vec<String>>, // class names
pub names2: Option<Vec<String>>, // could be keypoints names
pub anchors_first: bool, // output format: [bs, anchors/na, pos+nc+nm]
pub min_width: Option<f32>,
pub min_height: Option<f32>,
pub unclip_ratio: f32, // DB
@ -97,6 +98,7 @@ impl Default for Options {
tokenizer: None,
vocab: None,
names: None,
names2: None,
anchors_first: false,
min_width: None,
min_height: None,
@ -151,6 +153,11 @@ impl Options {
self
}
pub fn with_names2(mut self, names: &[&str]) -> Self {
self.names2 = Some(names.iter().map(|x| x.to_string()).collect::<Vec<String>>());
self
}
pub fn with_vocab(mut self, vocab: &str) -> Self {
self.vocab = Some(auto_load(vocab).unwrap());
self

src/core/polygon.rs

@ -12,12 +12,6 @@ impl From<Vec<Point>> for Polygon {
}
impl Polygon {
// pub fn new(points: &[Point]) -> Self {
// Self {
// points: points.to_vec(),
// }
// }
pub fn new() -> Self {
Self::default()
}
@ -62,6 +56,11 @@ impl Polygon {
area.abs() / 2.0
}
pub fn center(&self) -> Point {
let rect = self.find_min_rect();
rect.center()
}
pub fn find_min_rect(&self) -> Rect {
let (mut min_x, mut min_y, mut max_x, mut max_y) = (f32::MAX, f32::MAX, f32::MIN, f32::MIN);
for point in self.points.iter() {

src/core/rect.rs

@ -89,11 +89,11 @@ impl Rect {
}
pub fn cx(&self) -> f32 {
self.bottom_right.x - self.top_left.x
(self.bottom_right.x + self.top_left.x) / 2.0
}
pub fn cy(&self) -> f32 {
self.bottom_right.y - self.top_left.y
(self.bottom_right.y + self.top_left.y) / 2.0
}
pub fn tl(&self) -> Point {

src/core/utils.rs

@ -67,6 +67,7 @@ pub fn download<P: AsRef<Path> + std::fmt::Debug>(
}
assert_eq!(downloaded_bytes as u64, ntotal);
pb.finish();
println!();
Ok(())
}
@ -110,6 +111,25 @@ pub const COCO_SKELETON_17: [(usize, usize); 16] = [
(13, 15),
(14, 16),
];
pub const COCO_KEYPOINT_NAMES_17: [&str; 17] = [
"nose",
"left_eye",
"right_eye",
"left_ear",
"right_ear",
"left_shoulder",
"right_shoulder",
"left_elbow",
"right_elbow",
"left_wrist",
"right_wrist",
"left_hip",
"right_hip",
"left_knee",
"right_knee",
"left_ankle",
"right_ankle",
];
pub const COCO_NAMES_80: [&str; 80] = [
"person",

src/core/ys.rs

@ -1,4 +1,4 @@
use crate::{Bbox, Embedding, Keypoint, Polygon};
use crate::{Bbox, Embedding, Keypoint, Mask};
#[derive(Clone, PartialEq, Default)]
pub struct Ys {
@ -6,8 +6,7 @@ pub struct Ys {
pub probs: Option<Embedding>,
pub bboxes: Option<Vec<Bbox>>,
pub keypoints: Option<Vec<Vec<Keypoint>>>,
pub masks: Option<Vec<Vec<u8>>>,
pub polygons: Option<Vec<Polygon>>,
pub masks: Option<Vec<Mask>>,
}
impl std::fmt::Debug for Ys {
@ -16,14 +15,7 @@ impl std::fmt::Debug for Ys {
.field("Probabilities", &self.probs)
.field("BoundingBoxes", &self.bboxes)
.field("Keypoints", &self.keypoints)
.field(
"Masks",
&format_args!("{:?}", self.masks().map(|masks| masks.len())),
)
.field(
"Polygons",
&format_args!("{:?}", self.polygons().map(|polygons| polygons.len())),
)
.field("Masks", &self.masks)
.finish()
}
}
@ -44,16 +36,11 @@ impl Ys {
self
}
pub fn with_masks(mut self, masks: &[Vec<u8>]) -> Self {
pub fn with_masks(mut self, masks: &[Mask]) -> Self {
self.masks = Some(masks.to_vec());
self
}
pub fn with_polygons(mut self, polygons: &[Polygon]) -> Self {
self.polygons = Some(polygons.to_vec());
self
}
pub fn probs(&self) -> Option<&Embedding> {
self.probs.as_ref()
}
@ -62,15 +49,31 @@ impl Ys {
self.keypoints.as_ref()
}
pub fn masks(&self) -> Option<&Vec<Vec<u8>>> {
pub fn masks(&self) -> Option<&Vec<Mask>> {
self.masks.as_ref()
}
pub fn polygons(&self) -> Option<&Vec<Polygon>> {
self.polygons.as_ref()
}
pub fn bboxes(&self) -> Option<&Vec<Bbox>> {
self.bboxes.as_ref()
}
pub fn non_max_suppression(xs: &mut Vec<Bbox>, iou_threshold: f32) {
    // sort by confidence, descending
    xs.sort_by(|b1, b2| b2.confidence().partial_cmp(&b1.confidence()).unwrap());
    let mut current_index = 0;
    for index in 0..xs.len() {
        // drop this box if it overlaps an already-kept box too much
        let mut drop = false;
        for prev_index in 0..current_index {
            let iou = xs[prev_index].iou(&xs[index]);
            if iou > iou_threshold {
                drop = true;
                break;
            }
        }
        if !drop {
            xs.swap(current_index, index);
            current_index += 1;
        }
    }
    xs.truncate(current_index);
}
}

src/keypoint.rs

@ -1,35 +0,0 @@
use crate::Point;
#[derive(PartialEq, Clone, Default)]
pub struct Keypoint {
pub point: Point,
confidence: f32,
}
impl std::fmt::Debug for Keypoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Keypoint")
.field("x", &self.point.x)
.field("y", &self.point.y)
.field("confidence", &self.confidence)
.finish()
}
}
impl Keypoint {
pub fn new(point: Point, confidence: f32) -> Self {
Self { point, confidence }
}
pub fn x(&self) -> f32 {
self.point.x
}
pub fn y(&self) -> f32 {
self.point.y
}
pub fn confidence(&self) -> f32 {
self.confidence
}
}

src/lib.rs
@ -1,44 +1,6 @@
mod annotator;
mod bbox;
mod dataloader;
mod device;
mod dynconf;
mod embedding;
mod engine;
mod keypoint;
mod logits_sampler;
mod metric;
mod min_opt_max;
mod core;
pub mod models;
pub mod ops;
mod options;
mod point;
mod polygon;
mod rect;
mod rotated_rect;
mod tokenizer_stream;
mod utils;
mod ys;
pub use annotator::Annotator;
pub use bbox::Bbox;
pub use dataloader::DataLoader;
pub use device::Device;
pub use dynconf::DynConf;
pub use embedding::Embedding;
pub use engine::OrtEngine;
pub use keypoint::Keypoint;
pub use logits_sampler::LogitsSampler;
pub use metric::Metric;
pub use min_opt_max::MinOptMax;
pub use options::Options;
pub use point::Point;
pub use polygon::Polygon;
pub use rect::Rect;
pub use rotated_rect::RotatedRect;
pub use tokenizer_stream::TokenizerStream;
pub use utils::{auto_load, config_dir, download, string_now, COCO_NAMES_80, COCO_SKELETON_17};
pub use ys::Ys;
pub use core::*;
const GITHUB_ASSETS: &str = "https://github.com/jamjamjon/assets/releases/download/v0.0.1";
const CHECK_MARK: &str = "";
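The refactor collapses the former top-level modules into a single `mod core;` re-exported wholesale with `pub use core::*;`, which keeps public paths such as `usls::Bbox` unchanged. A sketch of what `src/core/mod.rs` presumably contains — the exact file list under `src/core/` is an assumption, inferred from the removed `mod`/`pub use` lines above:

```rust
// hypothetical src/core/mod.rs — module names inferred from the removed
// top-level declarations, not shown in this diff
mod bbox;
mod keypoint;
mod mask;
mod point;
mod polygon;
mod rect;

pub use bbox::Bbox;
pub use keypoint::Keypoint;
pub use mask::Mask;
pub use point::Point;
pub use polygon::Polygon;
pub use rect::Rect;
```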

src/models/db.rs
@ -1,4 +1,4 @@
use crate::{ops, Bbox, DynConf, MinOptMax, Options, OrtEngine, Polygon, Ys};
use crate::{ops, Bbox, DynConf, Mask, MinOptMax, Options, OrtEngine, Polygon, Ys};
use anyhow::Result;
use image::{DynamicImage, ImageBuffer};
use ndarray::{Array, Axis, IxDyn};
@ -94,7 +94,7 @@ impl DB {
imageproc::contours::find_contours_with_threshold(&mask_im, 1);
// loop
let mut y_polygons: Vec<Polygon> = Vec::new();
let mut y_masks: Vec<Mask> = Vec::new();
for contour in contours.iter() {
if contour.points.len() <= 1 {
continue;
@ -115,15 +115,14 @@ impl DB {
if confidence < self.confs[0] {
continue;
}
let bbox = Bbox::new(rect, 0, confidence, None);
y_bbox.push(bbox);
y_polygons.push(polygon);
y_bbox.push(Bbox::new(rect, 0, confidence, None));
y_masks.push(Mask {
polygon,
id: 0,
name: None,
});
}
ys.push(
Ys::default()
.with_bboxes(&y_bbox)
.with_polygons(&y_polygons),
);
ys.push(Ys::default().with_bboxes(&y_bbox).with_masks(&y_masks));
}
Ok(ys)
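`Mask` replaces the old parallel `masks: Vec<Vec<u8>>` / `polygons: Vec<Polygon>` pair by bundling a contour polygon with the class it belongs to. Its definition is not part of this diff; inferred from the construction sites here and in `ops::get_masks_from_image`, it presumably looks roughly like:

```rust
// assumed shape of the new core type; field types are inferred from the
// construction sites in this commit, not from its (unshown) definition.
// Polygon is the crate type from polygon.rs above.
#[derive(Clone, Debug, PartialEq)]
pub struct Mask {
    pub polygon: Polygon,
    pub id: isize,
    pub name: Option<String>,
}
```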

src/models/mod.rs
@ -6,6 +6,7 @@ mod rtdetr;
mod rtmo;
mod svtr;
mod yolo;
mod yolop;
pub use blip::Blip;
pub use clip::Clip;
@ -15,3 +16,4 @@ pub use rtdetr::RTDETR;
pub use rtmo::RTMO;
pub use svtr::SVTR;
pub use yolo::YOLO;
pub use yolop::YOLOPv2;

src/models/rtdetr.rs
@ -111,14 +111,7 @@ impl RTDETR {
);
y_bboxes.push(y_bbox)
}
let y = Ys {
probs: None,
bboxes: Some(y_bboxes),
keypoints: None,
masks: None,
polygons: None,
};
ys.push(y);
ys.push(Ys::default().with_bboxes(&y_bboxes));
}
Ok(ys)
}

src/models/rtmo.rs
@ -109,6 +109,8 @@ impl RTMO {
)
.into(),
c,
i as isize,
None, // Name
));
}
}
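`Keypoint::new` now takes two extra arguments. Judging from this call site and the one in `yolo.rs` below, the post-refactor type presumably looks like the following sketch; the real definition moved under `src/core` and is not shown in this diff:

```rust
// minimal Point so the sketch is self-contained; the crate has its own
#[derive(Clone, Copy, Default, PartialEq)]
pub struct Point { pub x: f32, pub y: f32 }

#[derive(Clone, Default, PartialEq)]
pub struct Keypoint {
    pub point: Point,
    confidence: f32,
    id: isize,               // new: keypoint index, e.g. 13 for left_knee
    name: Option<String>,    // new: optional human-readable label
}

impl Keypoint {
    // signature inferred from: Keypoint::new(point, confidence, i as isize, name)
    pub fn new(point: Point, confidence: f32, id: isize, name: Option<String>) -> Self {
        Self { point, confidence, id, name }
    }
}
```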

src/models/yolo.rs
@ -1,12 +1,11 @@
use anyhow::Result;
use clap::ValueEnum;
use image::{DynamicImage, ImageBuffer};
use image::DynamicImage;
use ndarray::{s, Array, Axis, IxDyn};
use regex::Regex;
use crate::{
ops, Bbox, DynConf, Embedding, Keypoint, MinOptMax, Options, OrtEngine, Point, Polygon, Rect,
Ys,
ops, Bbox, DynConf, Embedding, Keypoint, Mask, MinOptMax, Options, OrtEngine, Point, Rect, Ys,
};
const CXYWH_OFFSET: usize = 4;
@ -35,6 +34,7 @@ pub struct YOLO {
kconfs: DynConf,
iou: f32,
names: Option<Vec<String>>,
names_kpt: Option<Vec<String>>,
apply_nms: bool,
anchors_first: bool,
}
@ -83,6 +83,8 @@ impl YOLO {
},
};
let names_kpt = options.names2.to_owned().or(None);
// try from model metadata
let nk = engine
.try_fetch("kpt_shape")
@ -115,6 +117,7 @@ impl YOLO {
batch,
task,
names,
names_kpt,
anchors_first: options.anchors_first,
})
}
@ -226,6 +229,8 @@ impl YOLO {
ky.max(0.0f32).min(height_original),
),
kconf,
i as isize,
self.names_kpt.as_ref().map(|names| names[i].to_owned()),
));
}
}
@ -247,10 +252,7 @@ impl YOLO {
// decode
let mut y_bboxes: Vec<Bbox> = Vec::new();
let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
let mut y_masks: Vec<Vec<u8>> = Vec::new();
let mut y_polygons: Vec<Polygon> = Vec::new();
let mut y_masks: Vec<Mask> = Vec::new();
for elem in data.into_iter() {
if let Some(kpts) = elem.1 {
y_kpts.push(kpts)
@ -267,23 +269,23 @@ impl YOLO {
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec())
.expect("Faild to create image from ndarray");
let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
// rescale masks
let (_, w_mask, h_mask) =
ops::scale_wh(width_original, height_original, nw as f32, nh as f32);
let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
let mask_original = mask_cropped.resize_exact(
width_original as u32,
height_original as u32,
image::imageops::FilterType::Triangle,
let mask_im = ops::build_dyn_image_from_raw(
mask.into_raw_vec(),
nw as u32,
nh as u32,
);
// crop-mask with bbox
let mut mask_object_cropped = mask_original.into_luma8(); // gray image
// rescale masks
let mask_original = ops::descale_mask(
mask_im,
nw as f32,
nh as f32,
width_original,
height_original,
);
// crop mask with bbox
let mut mask_original = mask_original.into_luma8();
for y in 0..height_original as usize {
for x in 0..width_original as usize {
if x < elem.0.xmin() as usize
@ -291,33 +293,19 @@ impl YOLO {
|| y < elem.0.ymin() as usize
|| y > elem.0.ymax() as usize
{
mask_object_cropped.put_pixel(
x as u32,
y as u32,
image::Luma([0u8]),
);
mask_original.put_pixel(x as u32, y as u32, image::Luma([0u8]));
}
}
}
// mask -> contours
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(
&mask_object_cropped,
1,
);
// contours -> polygons
contours.iter().for_each(|contour| {
if let imageproc::contours::BorderType::Outer = contour.border_type {
if contour.points.len() > 1 {
y_polygons.push(Polygon::from_contour(contour));
}
}
});
// save each mask
y_masks.push(mask_object_cropped.into_raw());
// get masks from image
let masks = ops::get_masks_from_image(
mask_original,
1,
elem.0.id(),
elem.0.name().cloned(),
);
y_masks.extend(masks);
}
y_bboxes.push(elem.0);
}
@ -327,8 +315,7 @@ impl YOLO {
Ys::default()
.with_bboxes(&y_bboxes)
.with_keypoints(&y_kpts)
.with_masks(&y_masks)
.with_polygons(&y_polygons),
.with_masks(&y_masks),
);
}
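The inlined mask rescaling is factored out into `ops::descale_mask`, and contour extraction into `ops::get_masks_from_image` (which presumably wraps the outer contours of the binary mask into `Mask` values, as the removed block did with `Polygon::from_contour`). Based on the code it replaces, `descale_mask` undoes the letterbox by cropping the valid region from the origin and resizing back to the source resolution; a sketch under that assumption:

```rust
use image::DynamicImage;

// presumed behavior of ops::scale_wh, matching its uses in this commit:
// returns (ratio, scaled_w, scaled_h) for fitting (w0, h0) into (w1, h1)
fn scale_wh(w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
    let r = (w1 / w0).min(h1 / h0);
    (r, (w0 * r).round(), (h0 * r).round())
}

// sketch of ops::descale_mask, reconstructed from the inline code it replaces:
// crop the letterboxed mask down to the valid region (padding sits at the
// right/bottom, so the crop starts at the origin), then resize to the source size
fn descale_mask(mut mask: DynamicImage, nw: f32, nh: f32, w0: f32, h0: f32) -> DynamicImage {
    let (_, w_valid, h_valid) = scale_wh(w0, h0, nw, nh);
    let cropped = mask.crop(0, 0, w_valid as u32, h_valid as u32);
    cropped.resize_exact(w0 as u32, h0 as u32, image::imageops::FilterType::Triangle)
}
```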

src/models/yolop.rs (new file, 162 lines)
@ -0,0 +1,162 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis, IxDyn};
use crate::{ops, Bbox, DynConf, MinOptMax, Options, OrtEngine, Rect, Ys};
#[derive(Debug)]
pub struct YOLOPv2 {
engine: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
confs: DynConf,
iou: f32,
}
impl YOLOPv2 {
pub fn new(options: &Options) -> Result<Self> {
let engine = OrtEngine::new(options)?;
let (batch, height, width) = (
engine.batch().to_owned(),
engine.height().to_owned(),
engine.width().to_owned(),
);
let nc = 80;
let confs = DynConf::new(&options.kconfs, nc);
engine.dry_run()?;
Ok(Self {
engine,
confs,
height,
width,
batch,
iou: options.iou,
})
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Ys>> {
let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32, 114.0)?;
let xs_ = ops::normalize(xs_, 0.0, 255.0);
let ys = self.engine.run(&[xs_])?;
let ys = self.postprocess(ys, xs)?;
Ok(ys)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Ys>> {
let (xs_da, xs_ll, xs_det) = (&xs[0], &xs[1], &xs[2]);
let mut ys: Vec<Ys> = Vec::new();
for (idx, ((x_det, x_ll), x_da)) in xs_det
.axis_iter(Axis(0))
.zip(xs_ll.axis_iter(Axis(0)))
.zip(xs_da.axis_iter(Axis(0)))
.enumerate()
{
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
let (ratio, _, _) = ops::scale_wh(
image_width,
image_height,
self.width() as f32,
self.height() as f32,
);
// Vehicle
let mut ys_bbox = Vec::new();
for x in x_det.axis_iter(Axis(0)) {
let bbox = x.slice(s![0..4]);
let clss = x.slice(s![5..]).to_owned();
let conf = x[4];
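// fuse objectness with the class scores (score_i = obj_conf * cls_i), then take the argmax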
let clss = conf * clss;
let (id, conf) = clss
.into_iter()
.enumerate()
.reduce(|max, x| if x.1 > max.1 { x } else { max })
.unwrap();
if conf < self.confs[id] {
continue;
}
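// undo the letterbox ratio, then convert cxcywh to top-left xywh, clamped to the image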
let cx = bbox[0] / ratio;
let cy = bbox[1] / ratio;
let w = bbox[2] / ratio;
let h = bbox[3] / ratio;
let x = cx - w / 2.;
let y = cy - h / 2.;
ys_bbox.push(Bbox::new(
Rect::from_xywh(
x.max(0.0f32).min(image_width),
y.max(0.0f32).min(image_height),
w,
h,
),
id,
conf,
None,
));
}
Ys::non_max_suppression(&mut ys_bbox, self.iou);
// Drivable area
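// the segmentation head has two channels (background, drivable); keeping (ch1 - ch0) > 0 is a two-class argmax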
let x_da_0 = x_da.slice(s![0, .., ..]).to_owned();
let x_da_1 = x_da.slice(s![1, .., ..]).to_owned();
let x_da = x_da_1 - x_da_0;
let x_da = x_da
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = x_da
.into_raw_vec()
.iter()
.map(|x| if x < &0.0 { 0.0 } else { 1.0 })
.collect::<Vec<_>>();
let mask_da =
ops::build_dyn_image_from_raw(v, self.height() as u32, self.width() as u32);
let mask_da = ops::descale_mask(
mask_da,
self.width() as f32,
self.height() as f32,
image_width,
image_height,
);
let mask_da = mask_da.into_luma8();
let mut y_masks =
ops::get_masks_from_image(mask_da, 1, 0, Some("Drivable area".to_string()));
// Lane line
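// single-channel head: binarize at 0.5, then descale and vectorize like the drivable-area mask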
let x_ll = x_ll
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = x_ll
.into_raw_vec()
.iter()
.map(|x| if x < &0.5 { 0.0 } else { 1.0 })
.collect::<Vec<_>>();
let mask_ll =
ops::build_dyn_image_from_raw(v, self.height() as u32, self.width() as u32);
let mask_ll = ops::descale_mask(
mask_ll,
self.width() as f32,
self.height() as f32,
image_width,
image_height,
);
let mask_ll = mask_ll.into_luma8();
let masks = ops::get_masks_from_image(mask_ll, 1, 5, Some("Lane line".to_string()));
y_masks.extend(masks);
ys.push(Ys::default().with_bboxes(&ys_bbox).with_masks(&y_masks));
}
Ok(ys)
}
pub fn batch(&self) -> isize {
self.batch.opt
}
pub fn width(&self) -> isize {
self.width.opt
}
pub fn height(&self) -> isize {
self.height.opt
}
}
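For reference, a hedged usage sketch of the new model; the `Options` builder method shown (`with_model`) and the model filename are assumptions, not part of this diff:

```rust
use usls::{models::YOLOPv2, Options};

fn main() -> anyhow::Result<()> {
    // hypothetical builder call and filename — check the crate's Options for the real API
    let options = Options::default().with_model("yolopv2.onnx");
    let mut model = YOLOPv2::new(&options)?;

    // run() takes a batch of DynamicImage and yields one Ys per image:
    // vehicle bboxes plus "Drivable area" (id 0) and "Lane line" (id 5) masks
    let imgs = vec![image::open("driving_scene.jpg")?];
    let ys = model.run(&imgs)?;
    println!("{:?}", ys);
    Ok(())
}
```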