* Add docs

* Add mediapipe-selfie-segmenter model

* Update README.md

* Update RTMO model
This commit is contained in:
Jamjamjon
2025-06-08 18:15:54 +08:00
committed by GitHub
parent 0e8d4f832a
commit 70aeae9e01
77 changed files with 2325 additions and 1414 deletions

View File

@ -1,7 +1,7 @@
[package]
name = "usls"
edition = "2021"
version = "0.1.0-beta.4"
version = "0.1.0-rc.1"
rust-version = "1.82"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@ -10,6 +10,7 @@ license = "MIT"
readme = "README.md"
exclude = ["assets/*", "examples/*", "runs/*", "benches/*", "tests/*"]
[dependencies]
anyhow = { version = "1" }
aksr = { version = "0.0.3" }
@ -47,8 +48,6 @@ tokenizers = { version = "0.21.1" }
paste = "1.0.15"
base64ct = "=1.7.3"
[build-dependencies]
prost-build = "0.13.5"
[dev-dependencies]
argh = "0.1.13"

357
README.md
View File

@ -1,291 +1,146 @@
<h2 align="center">usls</h2>
<p align="center">
<!-- Rust MSRV -->
<a href='https://crates.io/crates/usls'>
<img src='https://img.shields.io/crates/msrv/usls-yellow?' alt='Rust MSRV'>
</a>
<!-- ONNXRuntime MSRV -->
<a href='https://github.com/microsoft/onnxruntime/releases'>
<img src='https://img.shields.io/badge/onnxruntime-%3E%3D%201.19.0-3399FF' alt='ONNXRuntime MSRV'>
</a>
<!-- CUDA MSRV -->
<a href='https://developer.nvidia.com/cuda-toolkit-archive'>
<img src='https://img.shields.io/badge/CUDA-%3E%3D%2012.0-green' alt='CUDA MSRV'>
</a>
<!-- cuDNN MSRV -->
<a href='https://developer.nvidia.com/cudnn-downloads'>
<img src='https://img.shields.io/badge/cuDNN-%3E%3D%209.0-green4' alt='cuDNN MSRV'>
</a>
<!-- TensorRT MSRV -->
<a href='https://developer.nvidia.com/tensorrt'>
<img src='https://img.shields.io/badge/TensorRT-%3E%3D%2012.0-0ABF53' alt='TensorRT MSRV'>
</a>
</p>
<p align="center">
<!-- Examples Link -->
<a href="./examples">
<img src="https://img.shields.io/badge/Examples-1A86FD?&logo=anki" alt="Examples">
</a>
<!-- Docs.rs Link -->
<a href='https://docs.rs/usls'>
<img src='https://img.shields.io/badge/Docs-usls-yellow?&logo=docs.rs&color=FFA200' alt='Documentation'>
</a>
</p>
<p align="center">
<!-- CI Badge -->
<a href="https://github.com/jamjamjon/usls/actions/workflows/rust-ci.yml">
<a href="https://github.com/jamjamjon/usls/actions/workflows/rust-ci.yml">
<img src="https://github.com/jamjamjon/usls/actions/workflows/rust-ci.yml/badge.svg" alt="Rust CI">
</a>
<a href='https://crates.io/crates/usls'>
<img src='https://img.shields.io/crates/v/usls.svg' alt='Crates.io Version'>
</a>
<!-- Crates.io Downloads -->
<a href="https://crates.io/crates/usls">
<img alt="Crates.io Downloads" src="https://img.shields.io/crates/d/usls?&color=946CE6">
<a href='https://github.com/microsoft/onnxruntime/releases'>
<img src='https://img.shields.io/badge/onnxruntime-%3E%3D%201.22.0-3399FF' alt='ONNXRuntime MSRV'>
</a>
<a href='https://crates.io/crates/usls'>
<img src='https://img.shields.io/crates/msrv/usls-yellow?' alt='Rust MSRV'>
</a>
</p>
<p align="center">
<strong>⭐️ Star if helpful! ⭐️</strong>
</p>
**usls** is an evolving Rust library focused on inference for advanced **vision** and **vision-language** models, along with practical vision utilities.
**usls** is a cross-platform Rust library powered by ONNX Runtime for efficient inference of SOTA vision and multi-modal models (typically under 1B parameters).
- **SOTA Model Inference:** Supports a wide range of state-of-the-art vision and multi-modal models (typically with fewer than 1B parameters).
- **Multi-backend Acceleration:** Supports CPU, CUDA, TensorRT, and CoreML.
- **Easy Data Handling:** Easily read images, video streams, and folders with iterator support.
- **Rich Result Types:** Built-in containers for common vision outputs like bounding boxes (Hbb, Obb), polygons, masks, etc.
- **Annotation & Visualization:** Draw and display inference results directly, similar to OpenCV's `imshow()`.
## 📚 Documentation
- [API Documentation](https://docs.rs/usls/latest/usls/)
- [Examples](./examples)
## 🧩 Supported Models
## 🚀 Quick Start
```bash
# CPU
cargo run -r --example yolo # YOLOv8-n detect by default

# NVIDIA CUDA
cargo run -r -F cuda --example yolo -- --device cuda:0

# NVIDIA TensorRT
cargo run -r -F tensorrt --example yolo -- --device tensorrt:0

# Apple Silicon CoreML
cargo run -r -F coreml --example yolo -- --device coreml

# Intel OpenVINO
cargo run -r -F openvino -F ort-load-dynamic --example yolo -- --device openvino:CPU

# And other EPs...
```

- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics), [YOLOv12](https://github.com/sunsmarterjie/yolov12)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1-v2](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242), [Moondream2](https://github.com/vikhyat/moondream/tree/main)
- **OCR-Related Models**: [FAST](https://github.com/czczup/FAST), [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947), [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)

<details>
<summary>Full list of supported models (click to expand)</summary>
| Model | Task / Description | Example | CoreML | CUDA<br />FP32 | CUDA<br />FP16 | TensorRT<br />FP32 | TensorRT<br />FP16 |
| -------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------- | ------ | -------------- | -------------- | ------------------ | ------------------ |
| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) | ✅ | ✅ | ✅ | | |
| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) | ✅ | ✅ | ✅ | | |
| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) | ✅ | ✅ | ✅ | | |
| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) | ✅ | ✅ | ✅ | | |
| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) | ✅ | ✅ | ✅ | | |
| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification<br />Object Detection<br />Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv8<br />YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Image Classification<br />Oriented Object Detection<br />Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | | |
| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) | ✅ | ✅ | ✅ | | |
| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) | ✅ | ✅ | ✅ | | |
| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) | ✅ | ✅ | ✅ | | |
| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) | ✅ | ✅ | ✅ | | |
| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) | ✅ | ✅ | ✅ | | |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | |
| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | |
| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | |
| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | |
| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | ✅ | | |
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | ✅ | | |
| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Table Recognition | [demo](examples/slanet) | ✅ | ✅ | ✅ | | |
| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) | ✅ | ✅ | ✅ | | |
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DepthAnything v1<br />DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ✅ | ❌ | ❌ |
| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | ✅ | | |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | ✅ | ✅ | ✅ | | |
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
</details>
## 🛠️ Installation
To get started, you'll need:
### 1. Protocol Buffers Compiler (`protoc`)
Required for building the project. [Official installation guide](https://protobuf.dev/installation/)
```shell
# Linux (apt)
sudo apt install -y protobuf-compiler
# macOS (Homebrew)
brew install protobuf
# Windows (Winget)
winget install protobuf
# Verify installation
protoc --version # Should be 3.x or higher
```
### 2. Rust Toolchain
```shell
# Install Rust and Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
### 3. Add usls to Your Project
## ⚙️ Installation
Add the following to your `Cargo.toml`:
```toml
[dependencies]
# Recommended: Use GitHub version
usls = { git = "https://github.com/jamjamjon/usls" }
usls = { git = "https://github.com/jamjamjon/usls", features = [ "cuda" ] }
# Alternative: Use crates.io version
usls = "latest-version"
```
> **Note:** The GitHub version is recommended as it contains the latest updates.
## ⚡ Cargo Features
- **ONNXRuntime-related features (enabled by default)** provide model inference and model zoo support:
  - **`ort-download-binaries`** (**default**): Automatically downloads prebuilt `ONNXRuntime` binaries for supported platforms. Provides core model loading and inference capabilities using the `CPU` execution provider.
  - **`ort-load-dynamic`**: Dynamic linking. You'll need to compile `ONNXRuntime` from [source](https://github.com/microsoft/onnxruntime) or download a [precompiled package](https://github.com/microsoft/onnxruntime/releases), and then link it manually. [See the guide here](https://ort.pyke.io/setup/linking#dynamic-linking).
- **`cuda`**: Enables the NVIDIA `CUDA` provider. Requires `CUDA` toolkit and `cuDNN` installed.
- **`trt`**: Enables the NVIDIA `TensorRT` provider. Requires `TensorRT` libraries installed.
- **`mps`**: Enables the Apple `CoreML` provider for macOS.

## ⚡ Supported Models

<details>
<summary>Click to expand</summary>
| Model | Task / Description | Example |
| ----- | ----------------- | ------- |
| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) |
| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) |
| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) |
| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) |
| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) |
| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) |
| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification<br />Object Detection<br />Instance Segmentation | [demo](examples/yolo) |
| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) |
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) |
| [YOLOv8<br />YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Image Classification<br />Oriented Object Detection<br />Keypoint Detection | [demo](examples/yolo) |
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) |
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) |
| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) |
| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) |
| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) |
| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) |
| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) |
| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) |
| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) |
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) |
| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) |
| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) |
| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) |
| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) |
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) |
| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) |
| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) |
| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) |
| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) |
| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) |
| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) |
| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) |
| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) |
| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) |
| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) |
| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Table Recognition | [demo](examples/slanet) |
| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) |
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) |
| [DepthAnything v1<br />DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) |
| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) |
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) |
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) |
| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) |
| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) |
| [MediaPipe: Selfie-segmentation](https://ai.google.dev/edge/mediapipe/solutions/vision/image_segmenter) | Image Segmentation | [demo](examples/mediapipe-selfie-segmentation) |
- **If you only need basic features** (such as image/video reading, result visualization, etc.), you can disable the default features to minimize dependencies:
```toml
usls = { git = "https://github.com/jamjamjon/usls", default-features = false }
```
- **`video`**: Enables video stream reading and writing. (Note: powered by [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb); check their repositories for potential issues.)
## ✨ Example
- Model Inference
```shell
cargo run -r --example yolo # CPU
cargo run -r -F cuda --example yolo -- --device cuda:0 # GPU
```
- Reading Images
```rust
// Read a single image
let image = DataLoader::try_read_one("./assets/bus.jpg")?;
// Read multiple images
let images = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/cat.png"])?;
// Read all images in a folder
let images = DataLoader::try_read_folder("./assets")?;
// Read images matching a pattern (glob)
let images = DataLoader::try_read_pattern("./assets/*.Jpg")?;
// Load images and iterate
let dl = DataLoader::new("./assets")?.with_batch(2).build()?;
for images in dl.iter() {
// Code here
}
```
- Reading Video
```rust
let dl = DataLoader::new("http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4")?
.with_batch(1)
.with_nf_skip(2)
.with_progress_bar(true)
.build()?;
for images in dl.iter() {
// Code here
}
```
- Annotate
```rust
let annotator = Annotator::default();
let image = DataLoader::try_read_one("./assets/bus.jpg")?;
// hbb
let hbb = Hbb::default()
.with_xyxy(669.5233, 395.4491, 809.0367, 878.81226)
.with_id(0)
.with_name("person")
.with_confidence(0.87094545);
let _ = annotator.annotate(&image, &hbb)?;
// keypoints
let keypoints: Vec<Keypoint> = vec![
Keypoint::default()
.with_xy(139.35767, 443.43655)
.with_id(0)
.with_name("nose")
.with_confidence(0.9739332),
Keypoint::default()
.with_xy(147.38545, 434.34055)
.with_id(1)
.with_name("left_eye")
.with_confidence(0.9098319),
Keypoint::default()
.with_xy(128.5701, 434.07516)
.with_id(2)
.with_name("right_eye")
.with_confidence(0.9320564),
];
let _ = annotator.annotate(&image, &keypoints)?;
```
</details>
- Visualizing Inference Results and Exporting Video
```rust
let dl = DataLoader::new(args.source.as_str())?.build()?;
let mut viewer = Viewer::default().with_window_scale(0.5);

for images in &dl {
    // Check if the window exists and is open
    if viewer.is_window_exist() && !viewer.is_window_open() {
        break;
    }

    // Show image in window
    viewer.imshow(&images[0])?;

    // Handle key events and delay
    if let Some(key) = viewer.wait_key(1) {
        if key == usls::Key::Escape {
            break;
        }
    }

    // Your custom code here

    // Write video frame (requires video feature)
    // if args.save_video {
    //     viewer.write_video_frame(&images[0])?;
    // }
}
```

**All examples are located in the [examples](./examples/) directory.**

## 📦 Cargo Features
- **`ort-download-binaries`** (**default**): Automatically downloads prebuilt ONNXRuntime binaries for supported platforms
- **`ort-load-dynamic`**: Dynamic linking to ONNXRuntime libraries ([Guide](https://ort.pyke.io/setup/linking#dynamic-linking))
- **`video`**: Enable video stream reading and writing (via [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb))
- **`cuda`**: NVIDIA CUDA GPU acceleration support
- **`tensorrt`**: NVIDIA TensorRT optimization for inference acceleration
- **`coreml`**: Apple CoreML acceleration for macOS/iOS devices
- **`openvino`**: Intel OpenVINO toolkit for CPU/GPU/VPU acceleration
- **`onednn`**: Intel oneDNN (formerly MKL-DNN) for CPU optimization
- **`directml`**: Microsoft DirectML for Windows GPU acceleration
- **`xnnpack`**: Google XNNPACK for mobile and edge device optimization
- **`rocm`**: AMD ROCm platform for GPU acceleration
- **`cann`**: Huawei CANN (Compute Architecture for Neural Networks) support
- **`rknpu`**: Rockchip NPU acceleration
- **`acl`**: Arm Compute Library for Arm processors
- **`nnapi`**: Android Neural Networks API support
- **`armnn`**: Arm NN inference engine
- **`tvm`**: Apache TVM tensor compiler stack
- **`qnn`**: Qualcomm Neural Network SDK
- **`migraphx`**: AMD MIGraphX for GPU acceleration
- **`vitis`**: Xilinx Vitis AI for FPGA acceleration
- **`azure`**: Azure Machine Learning integration
## ❓ FAQ
See issues or open a new discussion.
See [issues](https://github.com/jamjamjon/usls/issues) or open a new discussion.
## 🤝 Contributing

View File

@ -1,7 +0,0 @@
use std::io::Result;
fn main() -> Result<()> {
prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;
Ok(())
}

View File

@ -0,0 +1,10 @@
## Quick Start
```shell
cargo run -r --example mediapipe-selfie-segmentation -- --dtype f16
```
## Results
![](https://github.com/jamjamjon/assets/releases/download/mediapipe/demo-selfie-segmentaion.jpg)

View File

@ -0,0 +1,49 @@
use usls::{models::MediaPipeSegmenter, Annotator, Config, DataLoader};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
#[argh(option, default = "String::from(\"auto\")")]
dtype: String,
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
}
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// build model
let config = Config::mediapipe_selfie_segmentater()
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = MediaPipeSegmenter::new(config)?;
// load image
let xs = DataLoader::try_read_n(&["images/selfie-segmenter.png"])?;
// run
let ys = model.forward(&xs)?;
// annotate
let annotator =
Annotator::default().with_mask_style(usls::Style::mask().with_mask_cutout(true));
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None))
.display(),
))?;
}
Ok(())
}

View File

@ -1,21 +1,48 @@
use anyhow::Result;
use usls::{models::RTMO, Annotator, Config, DataLoader, Style, SKELETON_COCO_19};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
#[argh(option, default = "String::from(\"fp16\")")]
dtype: String,
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
/// scale: t, s, m, l
#[argh(option, default = "String::from(\"t\")")]
scale: String,
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// build model
let mut model = RTMO::new(Config::rtmo_s().commit()?)?;
let config = match args.scale.as_str() {
"t" => Config::rtmo_t(),
"s" => Config::rtmo_s(),
"m" => Config::rtmo_m(),
"l" => Config::rtmo_l(),
_ => unreachable!(),
}
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = RTMO::new(config)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// run
let ys = model.forward(&xs)?;
println!("ys: {:?}", ys);
// println!("ys: {:?}", ys);
// annotate
let annotator = Annotator::default()
@ -37,5 +64,8 @@ fn main() -> Result<()> {
))?;
}
// summary
model.summary();
Ok(())
}

View File

@ -1,12 +1,11 @@
use aksr::Builder;
use crate::{
impl_ort_config_methods, impl_processor_config_methods,
models::{SamKind, YOLOPredsFormat},
ORTConfig, ProcessorConfig, Scale, Task, Version,
};
/// Config for building models and inference
/// Configuration for model inference including engines, processors, and task settings.
#[derive(Builder, Debug, Clone)]
pub struct Config {
// Basics

View File

@ -550,9 +550,19 @@ trait DataLoaderIterator {
}
}
/// An iterator implementation for `DataLoader` that enables batch processing of images.
///
/// This struct is created by the `into_iter` method on `DataLoader`.
/// It provides functionality for:
/// - Receiving batches of images through a channel
/// - Tracking progress with an optional progress bar
/// - Processing images in configurable batch sizes
pub struct DataLoaderIntoIterator {
/// Channel receiver for getting batches of images
receiver: mpsc::Receiver<Vec<Image>>,
/// Optional progress bar for tracking iteration progress
progress_bar: Option<ProgressBar>,
/// Number of images to process in each batch
batch_size: u64,
}
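For orientation, a minimal, hedged sketch of driving this iterator (the folder path, batch size, and progress-bar toggle are placeholders; the builder calls mirror the README examples in this commit):
```rust
use usls::DataLoader;

fn iterate_batches() -> anyhow::Result<()> {
    // Build a loader over a folder (placeholder path), yielding 2 images per batch.
    let dl = DataLoader::new("./assets")?
        .with_batch(2)
        .with_progress_bar(true)
        .build()?;

    // Owning iteration (`DataLoaderIntoIterator`): each item is a `Vec<Image>` batch
    // received over the loader's internal channel; `dl.iter()` gives the borrowing
    // variant (`DataLoaderIter`) instead.
    for batch in dl {
        println!("got a batch of {} image(s)", batch.len());
    }
    Ok(())
}
```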
@ -593,6 +603,15 @@ impl IntoIterator for DataLoader {
}
}
/// A borrowing iterator for `DataLoader` that enables batch processing of images.
///
/// This iterator is created by the `iter()` method on `DataLoader`, allowing iteration
/// over batches of images without taking ownership of the `DataLoader`.
///
/// # Fields
/// - `receiver`: A reference to the channel receiver that provides batches of images
/// - `progress_bar`: An optional reference to a progress bar for tracking iteration progress
/// - `batch_size`: The number of images to process in each batch
pub struct DataLoaderIter<'a> {
receiver: &'a mpsc::Receiver<Vec<Image>>,
progress_bar: Option<&'a ProgressBar>,

View File

@ -1,4 +1,5 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
/// Device types for model execution.
pub enum Device {
Cpu(usize),
Cuda(usize),

View File

@ -1,4 +1,5 @@
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
/// Data type enumeration for tensor elements.
pub enum DType {
#[default]
Auto,

View File

@ -46,25 +46,39 @@ impl From<TensorElementType> for DType {
}
/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions.
/// ONNX Runtime tensor attributes containing names, data types, and dimensions.
#[derive(Builder, Debug, Clone, Default)]
/// ONNX Runtime tensor attributes containing metadata.
pub struct OrtTensorAttr {
/// Tensor names.
pub names: Vec<String>,
/// Tensor data types.
pub dtypes: Vec<TensorElementType>,
/// Tensor dimensions for each tensor.
pub dimss: Vec<Vec<usize>>,
}
/// ONNX I/O structure containing input/output attributes and session.
#[derive(Debug)]
pub struct OnnxIo {
/// Input tensor attributes.
pub inputs: OrtTensorAttr,
/// Output tensor attributes.
pub outputs: OrtTensorAttr,
/// ONNX Runtime session.
pub session: Session,
/// ONNX model protocol buffer.
pub proto: onnx::ModelProto,
}
/// ONNX Runtime inference engine with configuration and session management.
#[derive(Debug, Builder)]
pub struct Engine {
/// Model file path.
pub file: String,
/// Model specification string.
pub spec: String,
/// Execution device.
pub device: Device,
#[args(inc)]
pub iiixs: Vec<Iiix>,
@ -72,9 +86,13 @@ pub struct Engine {
pub params: Option<usize>,
#[args(aka = "memory")]
pub wbmems: Option<usize>,
/// Input min-opt-max configurations.
pub inputs_minoptmax: Vec<Vec<MinOptMax>>,
/// ONNX I/O structure.
pub onnx: Option<OnnxIo>,
/// Timing statistics.
pub ts: Ts,
/// Number of dry runs for warmup.
pub num_dry_run: usize,
// global
@ -158,7 +176,7 @@ impl Default for Engine {
// cann
cann_graph_inference: true,
cann_dump_graphs: false,
cann_dump_om_model: false,
cann_dump_om_model: true,
// nnapi
nnapi_cpu_only: false,
nnapi_disable_cpu: false,

View File

@ -3,8 +3,11 @@ use crate::MinOptMax;
/// A struct for input composed of the i-th input, the ii-th dimension, and the value.
#[derive(Clone, Debug, Default)]
pub struct Iiix {
/// Input index.
pub i: usize,
/// Dimension index.
pub ii: usize,
/// Min-Opt-Max value specification.
pub x: MinOptMax,
}

View File

@ -9,6 +9,7 @@ use std::path::{Path, PathBuf};
use crate::{build_resizer_filter, Hub, Location, MediaType, X};
/// Information about image transformation including source and destination dimensions.
#[derive(Builder, Debug, Clone, Default)]
pub struct ImageTransformInfo {
pub width_src: u32,
@ -19,6 +20,7 @@ pub struct ImageTransformInfo {
pub width_scale: f32,
}
/// Image resize modes for different scaling strategies.
#[derive(Debug, Clone, Default)]
pub enum ResizeMode {
/// StretchToFit
@ -30,6 +32,7 @@ pub enum ResizeMode {
Letterbox,
}
/// Image wrapper with metadata and transformation capabilities.
#[derive(Builder, Clone)]
pub struct Image {
image: RgbImage,
@ -376,8 +379,13 @@ impl Image {
}
}
/// Extension trait for converting between vectors of different image types.
/// Provides methods to convert between `Vec<Image>` and `Vec<DynamicImage>`.
pub trait ImageVecExt {
/// Converts the vector into a vector of `DynamicImage`s.
fn into_dyns(self) -> Vec<DynamicImage>;
/// Converts the vector into a vector of `Image`s.
fn into_images(self) -> Vec<Image>;
}
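A small, hedged usage sketch of the extension trait above (the path is a placeholder; `ImageVecExt` is assumed to be re-exported at the crate root):
```rust
use usls::{DataLoader, ImageVecExt};

fn round_trip() -> anyhow::Result<()> {
    let images = DataLoader::try_read_n(&["./assets/bus.jpg"])?; // Vec<Image>
    let dyns = images.into_dyns();  // -> Vec<image::DynamicImage>
    let _back = dyns.into_images(); // -> Vec<Image>
    Ok(())
}
```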

View File

@ -1,9 +1,12 @@
use anyhow::Result;
use rand::distr::{weighted::WeightedIndex, Distribution};
/// Logits sampler for text generation with temperature and nucleus sampling.
#[derive(Debug, Clone)]
pub struct LogitsSampler {
/// Temperature parameter for controlling randomness in sampling.
temperature: f32,
/// Top-p parameter for nucleus sampling.
p: f32,
}
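As a reminder of what the two fields control, here is a self-contained, illustrative sketch of temperature scaling followed by top-p (nucleus) truncation; it is not the crate's implementation:
```rust
/// Illustrative only: softmax with temperature, then keep the smallest set of
/// tokens whose cumulative probability reaches `p` (nucleus sampling).
fn nucleus_candidates(logits: &[f32], temperature: f32, p: f32) -> Vec<(usize, f32)> {
    let t = temperature.max(1e-6);
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / t).exp()).collect();
    let sum: f32 = exps.iter().sum();

    // Probabilities sorted from most to least likely.
    let mut probs: Vec<(usize, f32)> = exps.iter().map(|&e| e / sum).enumerate().collect();
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Truncate once the cumulative mass reaches `p`; the caller would then
    // renormalize and sample from this reduced set.
    let mut cum = 0.0;
    let mut kept = Vec::new();
    for (idx, prob) in probs {
        kept.push((idx, prob));
        cum += prob;
        if cum >= p {
            break;
        }
    }
    kept
}
```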

View File

@ -11,6 +11,7 @@ pub(crate) const STREAM_PROTOCOLS: &[&str] = &[
"rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://",
];
/// Media location type indicating local or remote source.
#[derive(Debug, Clone, Default, Copy)]
pub enum Location {
#[default]
@ -18,6 +19,7 @@ pub enum Location {
Remote,
}
/// Stream type for media content.
#[derive(Debug, Clone, Copy, Default)]
pub enum StreamType {
#[default]
@ -25,6 +27,7 @@ pub enum StreamType {
Live,
}
/// Media type classification for different content formats.
#[derive(Debug, Clone, Copy, Default)]
pub enum MediaType {
#[default]

61
src/core/mod.rs Normal file
View File

@ -0,0 +1,61 @@
#[macro_use]
mod ort_config;
#[macro_use]
mod processor_config;
mod config;
mod dataloader;
mod device;
mod dir;
mod dtype;
mod dynconf;
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
mod engine;
mod hub;
mod iiix;
mod image;
mod logits_sampler;
mod media;
mod min_opt_max;
mod names;
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
#[allow(clippy::all)]
pub(crate) mod onnx;
mod ops;
mod processor;
mod retry;
mod scale;
mod task;
mod traits;
mod ts;
mod utils;
mod version;
mod x;
mod xs;
pub use config::*;
pub use dataloader::*;
pub use device::Device;
pub use dir::*;
pub use dtype::DType;
pub use dynconf::DynConf;
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
pub use engine::*;
pub use hub::*;
pub(crate) use iiix::Iiix;
pub use image::*;
pub use logits_sampler::LogitsSampler;
pub use media::*;
pub use min_opt_max::MinOptMax;
pub use names::*;
pub use ops::*;
pub use ort_config::ORTConfig;
pub use processor::*;
pub use processor_config::ProcessorConfig;
pub use scale::Scale;
pub use task::Task;
pub use traits::*;
pub use ts::Ts;
pub use utils::*;
pub use version::Version;
pub use x::X;
pub use xs::Xs;

View File

@ -1,3 +1,5 @@
/// A comprehensive list of 4585 object categories used in the YOLOE model.
/// A comprehensive list of 4585 object categories used in the YOLOE object detection model.
pub static NAMES_YOLOE_4585: [&str; 4585] = [
"3D CG rendering",
"3D glasses",
@ -4586,6 +4588,13 @@ pub static NAMES_YOLOE_4585: [&str; 4585] = [
"zoo",
];
/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.5 with 16 categories.
///
/// DOTA is a large-scale dataset for object detection in aerial images. This version includes:
/// - Transportation objects (planes, ships, vehicles)
/// - Sports facilities (courts, fields)
/// - Infrastructure (bridges, harbors, storage tanks)
/// - Other aerial view objects
pub const NAMES_DOTA_V1_5_16: [&str; 16] = [
"plane",
"ship",
@ -4604,6 +4613,13 @@ pub const NAMES_DOTA_V1_5_16: [&str; 16] = [
"swimming pool",
"container crane",
];
/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.0 with 15 categories.
///
/// Similar to DOTA v1.5 but excludes the "container crane" category. Includes:
/// - Transportation objects
/// - Sports facilities
/// - Infrastructure
/// - Other aerial view objects
pub const NAMES_DOTA_V1_15: [&str; 15] = [
"plane",
"ship",
@ -4622,6 +4638,13 @@ pub const NAMES_DOTA_V1_15: [&str; 15] = [
"swimming pool",
];
/// Labels for document layout analysis using YOLO with 10 categories.
///
/// These labels are used to identify different components in document layout analysis:
/// - Text elements (title, plain text)
/// - Visual elements (figures and their captions)
/// - Tabular elements (tables, captions, footnotes)
/// - Mathematical elements (formulas and captions)
pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [
"title",
"plain text",
@ -4634,6 +4657,14 @@ pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [
"isolate_formula",
"formula_caption",
];
/// Labels for PicoDet document layout analysis with 17 categories.
///
/// A comprehensive set of labels for detailed document structure analysis, including:
/// - Textual elements (titles, paragraphs, abstracts)
/// - Visual elements (images, figures)
/// - Structural elements (headers, footers)
/// - Reference elements (citations, footnotes)
/// - Special elements (algorithms, seals)
pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [
"paragraph_title",
"image",
@ -4653,8 +4684,26 @@ pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [
"footer",
"seal",
];
/// Simplified PicoDet document layout labels with 3 basic categories.
///
/// A minimal set of labels for basic document layout analysis:
/// - Images
/// - Tables
/// - Seals (official stamps or marks)
pub const NAMES_PICODET_LAYOUT_3: [&str; 3] = ["image", "table", "seal"];
/// Core PicoDet document layout labels with 5 essential categories.
///
/// A balanced set of labels for common document layout analysis tasks:
/// - Textual content (Text, Title, List)
/// - Visual content (Figure)
/// - Structured content (Table)
pub const NAMES_PICODET_LAYOUT_5: [&str; 5] = ["Text", "Title", "List", "Table", "Figure"];
/// COCO dataset keypoint labels for human pose estimation with 17 points.
///
/// Standard keypoints for human body pose detection including:
/// - Facial features (eyes, nose, ears)
/// - Upper body joints (shoulders, elbows, wrists)
/// - Lower body joints (hips, knees, ankles)
pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [
"nose",
"left_eye",
@ -4675,6 +4724,15 @@ pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [
"right_ankle",
];
/// COCO dataset object detection labels with 80 categories.
///
/// A widely-used set of object categories for general object detection, including:
/// - Living beings (person, animals)
/// - Vehicles and transportation
/// - Common objects and furniture
/// - Food and kitchen items
/// - Sports and recreational equipment
/// - Electronics and appliances
pub const NAMES_COCO_80: [&str; 80] = [
"person",
"bicycle",
@ -4758,6 +4816,14 @@ pub const NAMES_COCO_80: [&str; 80] = [
"toothbrush",
];
/// Extended COCO dataset labels with 91 categories including background and unused slots.
///
/// An extended version of COCO labels that includes:
/// - Background class (index 0)
/// - All standard COCO categories
/// - Reserved "unused" slots for dataset compatibility
///
/// Note: Indices are preserved for compatibility with pre-trained models, hence the "unused" entries.
pub const NAMES_COCO_91: [&str; 91] = [
"background", // 0
"person", // 1
@ -4852,6 +4918,13 @@ pub const NAMES_COCO_91: [&str; 91] = [
"toothbrush", // 90
];
/// Human body parts segmentation labels with 28 categories.
///
/// Detailed categorization of human body parts for segmentation tasks:
/// - Basic body parts (face, neck, hair, torso)
/// - Left/Right symmetrical parts (hands, arms, legs, feet)
/// - Clothing regions (apparel, upper/lower clothing)
/// - Facial features (lips, teeth, tongue)
pub const NAMES_BODY_PARTS_28: [&str; 28] = [
"Background",
"Apparel",
@ -4883,6 +4956,16 @@ pub const NAMES_BODY_PARTS_28: [&str; 28] = [
"Tongue",
];
/// ImageNet ILSVRC 1000-class classification labels.
///
/// The standard ImageNet 1K classification dataset includes:
/// - Animals (mammals, birds, fish, insects)
/// - Objects (tools, vehicles, furniture)
/// - Foods and plants
/// - Natural scenes
/// - Man-made structures
///
/// Note: Labels include both common names and scientific nomenclature where applicable.
pub const NAMES_IMAGENET_1K: [&str; 1000] = [
"tench, Tinca tinca",
"goldfish, Carassius auratus",

1210
src/core/onnx.rs Normal file

File diff suppressed because it is too large

View File

@ -11,19 +11,33 @@ use ndarray::{concatenate, s, Array, Array3, ArrayView1, Axis, IntoDimension, Ix
use rayon::prelude::*;
/// Image and tensor operations for preprocessing and postprocessing.
pub enum Ops<'a> {
/// Resize images to exact dimensions.
FitExact(&'a [DynamicImage], u32, u32, &'a str),
/// Apply letterbox padding to maintain aspect ratio.
Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool),
/// Normalize values to a specific range.
Normalize(f32, f32),
/// Standardize using mean and standard deviation.
Standardize(&'a [f32], &'a [f32], usize),
/// Permute tensor dimensions.
Permute(&'a [usize]),
/// Insert a new axis at specified position.
InsertAxis(usize),
/// Convert from NHWC to NCHW format.
Nhwc2nchw,
/// Convert from NCHW to NHWC format.
Nchw2nhwc,
/// Apply L2 normalization.
Norm,
/// Apply sigmoid activation function.
Sigmoid,
/// Broadcast tensor to larger dimensions.
Broadcast,
/// Reshape tensor to specified shape.
ToShape,
/// Repeat tensor elements.
Repeat,
}
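For reference, the two value-scaling variants above correspond to the usual elementwise formulas; a minimal illustrative sketch (interpreting `Normalize`'s two parameters as a min/max range is an assumption):
```rust
// Illustrative only: elementwise forms of the scaling ops listed above.
fn normalize(x: f32, min: f32, max: f32) -> f32 {
    (x - min) / (max - min) // map values into [0, 1] over the given range
}

fn standardize(x: f32, mean: f32, std: f32) -> f32 {
    (x - mean) / std // zero-mean, unit-variance, typically applied per channel
}
```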

View File

@ -3,6 +3,7 @@ use anyhow::Result;
use crate::{try_fetch_file_stem, DType, Device, Hub, Iiix, MinOptMax};
/// ONNX Runtime configuration with device and optimization settings.
#[derive(Builder, Debug, Clone)]
pub struct ORTConfig {
pub file: String,
@ -71,7 +72,7 @@ impl Default for ORTConfig {
tensorrt_timing_cache: false,
cann_graph_inference: true,
cann_dump_graphs: false,
cann_dump_om_model: false,
cann_dump_om_model: true,
onednn_arena_allocator: true,
nnapi_cpu_only: false,
nnapi_disable_cpu: false,
@ -159,7 +160,6 @@ impl ORTConfig {
}
}
#[macro_export]
macro_rules! impl_ort_config_methods {
($ty:ty, $field:ident) => {
impl $ty {

View File

@ -7,6 +7,7 @@ use tokenizers::{Encoding, Tokenizer};
use crate::{Hub, Image, ImageTransformInfo, LogitsSampler, ProcessorConfig, ResizeMode, X};
/// Image and text processing pipeline with tokenization and transformation capabilities.
#[derive(Builder, Debug, Clone)]
pub struct Processor {
pub image_width: u32,

View File

@ -4,30 +4,51 @@ use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use crate::{Hub, ResizeMode};
/// Configuration for image and text processing pipelines.
#[derive(Builder, Debug, Clone)]
pub struct ProcessorConfig {
// Vision
/// Target image width for resizing.
pub image_width: Option<u32>,
/// Target image height for resizing.
pub image_height: Option<u32>,
/// Image resizing mode.
pub resize_mode: ResizeMode,
/// Image resize filter algorithm.
pub resize_filter: Option<&'static str>,
/// Padding value for image borders.
pub padding_value: u8,
/// Whether to normalize image values.
pub normalize: bool,
/// Standard deviation values for normalization.
pub image_std: Vec<f32>,
/// Mean values for normalization.
pub image_mean: Vec<f32>,
/// Whether to use NCHW format (channels first).
pub nchw: bool,
/// Whether to use unsigned integer format.
pub unsigned: bool,
// Text
/// Maximum sequence length for tokenization.
pub model_max_length: Option<u64>,
/// Path to tokenizer file.
pub tokenizer_file: Option<String>,
/// Path to model configuration file.
pub config_file: Option<String>,
/// Path to special tokens mapping file.
pub special_tokens_map_file: Option<String>,
/// Path to tokenizer configuration file.
pub tokenizer_config_file: Option<String>,
/// Path to generation configuration file.
pub generation_config_file: Option<String>,
/// Path to vocabulary file.
pub vocab_file: Option<String>,
/// Path to vocabulary text file.
pub vocab_txt: Option<String>,
/// Temperature parameter for text generation.
pub temperature: f32,
/// Top-p parameter for nucleus sampling.
pub topp: f32,
}
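These fields are normally set through the `with_*` forwarders generated by `impl_processor_config_methods!` on `Config`; a hedged sketch in the style of the MediaPipe preset added in this commit (the mean/std values are placeholders):
```rust
fn main() {
    // Sketch only: customizing the vision-side processing settings through Config.
    let _config = usls::Config::default()
        .with_image_mean(&[0.485, 0.456, 0.406]) // placeholder mean values
        .with_image_std(&[0.229, 0.224, 0.225])  // placeholder std values
        .with_normalize(true);
}
```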
@ -150,7 +171,6 @@ impl ProcessorConfig {
}
}
#[macro_export]
macro_rules! impl_processor_config_methods {
($ty:ty, $field:ident) => {
impl $ty {

View File

@ -1,5 +1,6 @@
use std::str::FromStr;
/// Model scale variants for different model sizes.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum Scale {
N,

View File

@ -1,15 +1,22 @@
use crate::{Hbb, Obb};
/// Trait for objects that have a confidence score.
pub trait HasScore {
/// Returns the confidence score.
fn score(&self) -> f32;
}
/// Trait for objects that can calculate Intersection over Union (IoU).
pub trait HasIoU {
/// Calculates IoU with another object.
fn iou(&self, other: &Self) -> f32;
}
/// Trait for Non-Maximum Suppression operations.
pub trait NmsOps {
/// Applies NMS in-place with the given IoU threshold.
fn apply_nms_inplace(&mut self, iou_threshold: f32);
/// Applies NMS and returns the filtered result.
fn apply_nms(self, iou_threshold: f32) -> Self;
}
@ -68,11 +75,17 @@ impl HasIoU for Obb {
}
}
/// Trait for geometric regions with area and intersection calculations.
pub trait Region {
/// Calculates the area of the region.
fn area(&self) -> f32;
/// Calculates the perimeter of the region.
fn perimeter(&self) -> f32;
/// Calculates the intersection area with another region.
fn intersect(&self, other: &Self) -> f32;
/// Calculates the union area with another region.
fn union(&self, other: &Self) -> f32;
/// Calculates Intersection over Union (IoU) with another region.
fn iou(&self, other: &Self) -> f32 {
self.intersect(other) / self.union(other)
}
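To make the intent of these traits concrete, here is a standalone, illustrative sketch of greedy NMS expressed against `HasScore` and `HasIoU`; it is not necessarily how the crate implements `NmsOps`:
```rust
use usls::{HasIoU, HasScore};

/// Illustrative only: keep the highest-scoring candidates, dropping any candidate
/// that overlaps an already-kept one by more than `iou_threshold`.
fn greedy_nms<T: HasScore + HasIoU>(mut items: Vec<T>, iou_threshold: f32) -> Vec<T> {
    items.sort_by(|a, b| b.score().partial_cmp(&a.score()).unwrap());
    let mut kept: Vec<T> = Vec::new();
    for item in items {
        if kept.iter().all(|k| k.iou(&item) < iou_threshold) {
            kept.push(item);
        }
    }
    kept
}
```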

View File

@ -19,6 +19,7 @@ macro_rules! elapsed {
}};
}
/// Time series collection for performance measurement and profiling.
#[derive(aksr::Builder, Debug, Default, Clone, PartialEq)]
pub struct Ts {
// { k1: [d1,d1,d1,..], k2: [d2,d2,d2,..], k3: [d3,d3,d3,..], ..}

View File

@ -1,52 +1,21 @@
mod config;
mod device;
mod dtype;
mod dynconf;
mod iiix;
mod logits_sampler;
mod min_opt_max;
mod names;
mod ops;
mod ort_config;
mod processor;
mod processor_config;
mod retry;
mod scale;
mod task;
mod traits;
mod ts;
mod version;
pub use config::*;
pub use device::Device;
pub use dtype::DType;
pub use dynconf::DynConf;
pub(crate) use iiix::Iiix;
pub use logits_sampler::LogitsSampler;
pub use min_opt_max::MinOptMax;
pub use names::*;
pub use ops::*;
pub use ort_config::ORTConfig;
pub use processor::*;
pub use processor_config::ProcessorConfig;
pub use scale::Scale;
pub use task::Task;
pub use traits::*;
pub use ts::Ts;
pub use version::Version;
/// The name of the current crate.
pub const CRATE_NAME: &str = env!("CARGO_PKG_NAME");
/// Standard prefix length for progress bar formatting.
pub const PREFIX_LENGTH: usize = 12;
/// Progress bar style for completion with iteration count.
pub const PROGRESS_BAR_STYLE_FINISH: &str =
"{prefix:>12.green.bold} {msg} for {human_len} iterations in {elapsed}";
/// Progress bar style for completion with multiplier format.
pub const PROGRESS_BAR_STYLE_FINISH_2: &str =
"{prefix:>12.green.bold} {msg} x{human_len} in {elapsed}";
/// Progress bar style for completion with byte size information.
pub const PROGRESS_BAR_STYLE_FINISH_3: &str =
"{prefix:>12.green.bold} {msg} ({binary_total_bytes}) in {elapsed}";
/// Progress bar style for ongoing operations with position indicator.
pub const PROGRESS_BAR_STYLE_CYAN_2: &str =
"{prefix:>12.cyan.bold} {human_pos}/{human_len} |{bar}| {msg}";
pub fn build_resizer_filter(
pub(crate) fn build_resizer_filter(
ty: &str,
) -> anyhow::Result<(fast_image_resize::Resizer, fast_image_resize::ResizeOptions)> {
use fast_image_resize::{FilterType, ResizeAlg, ResizeOptions, Resizer};
@ -66,7 +35,7 @@ pub fn build_resizer_filter(
))
}
pub fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<String> {
pub(crate) fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<String> {
let p = p.as_ref();
let stem = p
.file_stem()
@ -80,8 +49,7 @@ pub fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<St
Ok(stem.to_string())
}
// TODO
pub fn build_progress_bar(
pub(crate) fn build_progress_bar(
n: u64,
prefix: &str,
msg: Option<&str>,
@ -95,11 +63,41 @@ pub fn build_progress_bar(
Ok(pb)
}
/// Formats a byte size into a human-readable string using decimal (base-1000) units.
///
/// # Arguments
/// * `size` - The size in bytes to format
/// * `decimal_places` - Number of decimal places to show in the formatted output
///
/// # Returns
/// A string representing the size with appropriate decimal unit (B, KB, MB, etc.)
///
/// # Example
/// ```ignore
/// let size = 1500000.0;
/// let formatted = human_bytes_decimal(size, 2);
/// assert_eq!(formatted, "1.50 MB");
/// ```
pub fn human_bytes_decimal(size: f64, decimal_places: usize) -> String {
const DECIMAL_UNITS: [&str; 7] = ["B", "KB", "MB", "GB", "TB", "PB", "EB"];
format_bytes_internal(size, 1000.0, &DECIMAL_UNITS, decimal_places)
}
/// Formats a byte size into a human-readable string using binary (base-1024) units.
///
/// # Arguments
/// * `size` - The size in bytes to format
/// * `decimal_places` - Number of decimal places to show in the formatted output
///
/// # Returns
/// A string representing the size with appropriate binary unit (B, KiB, MiB, etc.)
///
/// # Example
/// ```ignore
/// let size = 1024.0 * 1024.0; // 1 MiB in bytes
/// let formatted = human_bytes_binary(size, 2);
/// assert_eq!(formatted, "1.00 MiB");
/// ```
pub fn human_bytes_binary(size: f64, decimal_places: usize) -> String {
const BINARY_UNITS: [&str; 7] = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"];
format_bytes_internal(size, 1024.0, &BINARY_UNITS, decimal_places)
@ -125,6 +123,22 @@ fn format_bytes_internal(
)
}
/// Generates a random alphanumeric string of the specified length.
///
/// # Arguments
/// * `length` - The desired length of the random string
///
/// # Returns
/// A string containing random alphanumeric characters (A-Z, a-z, 0-9).
/// Returns an empty string if length is 0.
///
/// # Example
/// ```ignore
/// let random = generate_random_string(8);
/// assert_eq!(random.len(), 8);
/// // Each character in the string will be alphanumeric
/// assert!(random.chars().all(|c| c.is_ascii_alphanumeric()));
/// ```
pub fn generate_random_string(length: usize) -> String {
use rand::{distr::Alphanumeric, rng, Rng};
if length == 0 {
@ -136,6 +150,28 @@ pub fn generate_random_string(length: usize) -> String {
result
}
/// Generates a timestamp string in the format "YYYYMMDDHHmmSSffffff" with an optional delimiter.
///
/// # Arguments
/// * `delimiter` - Optional string to insert between each component of the timestamp.
/// If None or empty string, components will be joined without delimiter.
///
/// # Returns
/// A string containing the current timestamp with the specified format.
///
/// # Example
/// ```ignore
/// use chrono::TimeZone;
///
/// // Without delimiter
/// let ts = timestamp(None);
/// assert_eq!(ts.len(), 20); // YYYYMMDDHHmmSSffffff
///
/// // With delimiter
/// let ts = timestamp(Some("-"));
/// // Format: YYYY-MM-DD-HH-mm-SS-ffffff
/// assert_eq!(ts.split('-').count(), 7);
/// ```
pub fn timestamp(delimiter: Option<&str>) -> String {
let delimiter = delimiter.unwrap_or("");
let format = format!("%Y{0}%m{0}%d{0}%H{0}%M{0}%S{0}%f", delimiter);

View File

@ -1,6 +1,7 @@
use aksr::Builder;
use std::convert::TryFrom;
/// Version representation with major, minor, and optional patch numbers.
#[derive(Debug, Builder, PartialEq, Eq, Copy, Clone, Hash, Default, PartialOrd, Ord)]
pub struct Version(pub u8, pub u8, pub Option<u8>);
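A tiny illustrative sketch of the tuple layout (the values are arbitrary):
```rust
fn main() {
    let full = usls::Version(1, 22, Some(0)); // "1.22.0"-style version
    let short = usls::Version(0, 1, None);    // "0.1"-style version, no patch
    println!("{:?} {:?}", full, short);
}
```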

View File

@ -6,6 +6,7 @@ use std::ops::{Deref, Index};
use crate::{generate_random_string, X};
/// Collection of named tensors with associated images and texts.
#[derive(Builder, Debug, Default, Clone)]
pub struct Xs {
map: HashMap<String, X>,

View File

@ -1,36 +0,0 @@
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
mod engine;
mod hbb;
mod image;
mod instance_meta;
mod keypoint;
mod mask;
mod obb;
mod polygon;
mod prob;
mod skeleton;
mod text;
mod x;
mod xs;
mod y;
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
#[allow(clippy::all)]
pub(crate) mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
pub use engine::*;
pub use hbb::*;
pub use image::*;
pub use instance_meta::*;
pub use keypoint::*;
pub use mask::*;
pub use obb::*;
pub use polygon::*;
pub use prob::*;
pub use skeleton::*;
pub use text::*;
pub use x::X;
pub use xs::Xs;
pub use y::*;

View File

@ -1,9 +0,0 @@
mod dataloader;
mod dir;
mod hub;
mod media;
pub use dataloader::*;
pub use dir::*;
pub use hub::*;
pub use media::*;

View File

@ -1,14 +1,64 @@
mod inference;
mod io;
//! # usls
//!
//! `usls` is a cross-platform Rust library, powered by ONNX Runtime, for efficient inference of SOTA vision and multi-modal models (typically under 1B parameters).
//!
//! ## 📚 Documentation
//! - [API Documentation](https://docs.rs/usls/latest/usls/)
//! - [Examples](https://github.com/jamjamjon/usls/tree/main/examples)
//! ## 🚀 Quick Start
//!
//! ```bash
//! # CPU
//! cargo run -r --example yolo # YOLOv8 detect by default
//!
//! # NVIDIA CUDA
//! cargo run -r -F cuda --example yolo -- --device cuda:0
//!
//! # NVIDIA TensorRT
//! cargo run -r -F tensorrt --example yolo -- --device tensorrt:0
//!
//! # Apple Silicon CoreML
//! cargo run -r -F coreml --example yolo -- --device coreml
//!
//! # Intel OpenVINO
//! cargo run -r -F openvino -F ort-load-dynamic --example yolo -- --device openvino:CPU
//! ```
//!
//! ## ⚡ Cargo Features
//! - **`ort-download-binaries`** (**default**): Automatically downloads prebuilt ONNXRuntime binaries for supported platforms
//! - **`ort-load-dynamic`**: Dynamic linking to ONNXRuntime libraries ([Guide](https://ort.pyke.io/setup/linking#dynamic-linking))
//! - **`video`**: Enable video stream reading and writing (via [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb))
//! - **`cuda`**: NVIDIA CUDA GPU acceleration support
//! - **`tensorrt`**: NVIDIA TensorRT optimization for inference acceleration
//! - **`coreml`**: Apple CoreML acceleration for macOS/iOS devices
//! - **`openvino`**: Intel OpenVINO toolkit for CPU/GPU/VPU acceleration
//! - **`onednn`**: Intel oneDNN (formerly MKL-DNN) for CPU optimization
//! - **`directml`**: Microsoft DirectML for Windows GPU acceleration
//! - **`xnnpack`**: Google XNNPACK for mobile and edge device optimization
//! - **`rocm`**: AMD ROCm platform for GPU acceleration
//! - **`cann`**: Huawei CANN (Compute Architecture for Neural Networks) support
//! - **`rknpu`**: Rockchip NPU acceleration
//! - **`acl`**: Arm Compute Library for Arm processors
//! - **`nnapi`**: Android Neural Networks API support
//! - **`armnn`**: Arm NN inference engine
//! - **`tvm`**: Apache TVM tensor compiler stack
//! - **`qnn`**: Qualcomm Neural Network SDK
//! - **`migraphx`**: AMD MIGraphX for GPU acceleration
//! - **`vitis`**: Xilinx Vitis AI for FPGA acceleration
//! - **`azure`**: Azure Machine Learning integration
//!
mod core;
/// Model Zoo
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
pub mod models;
mod utils;
#[macro_use]
mod results;
mod viz;
pub use inference::*;
pub use io::*;
pub use core::*;
pub use minifb::Key;
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
pub use models::*;
pub use utils::*;
pub use results::*;
pub use viz::*;
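To complement the quick-start commands in the crate docs above, here is a minimal, non-authoritative sketch of how the high-level API fits together, based on the `Config::yolo_detect()`, `YOLO::new()` and `forward()` signatures visible elsewhere in this diff; the image loader (`Image::try_read`) and the model file name are assumptions and may differ from the actual API.

```rust
use anyhow::Result;
use usls::{Config, Image, YOLO};

fn main() -> Result<()> {
    // Build a task-specific configuration (COCO-80 object detection preset).
    let config = Config::yolo_detect().with_model_file("yolo-detect.onnx"); // hypothetical model file
    // Construct the model from the configuration.
    let mut model = YOLO::new(config)?;
    // Load an input image; `Image::try_read` is assumed here and may differ.
    let image = Image::try_read("assets/bus.jpg")?;
    // Run the full preprocess -> inference -> postprocess pipeline.
    let ys = model.forward(&[image])?;
    println!("got {} result(s)", ys.len());
    Ok(())
}
```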

View File

@ -7,6 +7,7 @@ use crate::{
elapsed, Config, DynConf, Engine, Hbb, Image, Mask, Obb, Ops, Polygon, Processor, Ts, Xs, Y,
};
/// DB (Differentiable Binarization) model for text detection.
#[derive(Debug, Builder)]
pub struct DB {
engine: Engine,

View File

@ -7,6 +7,7 @@ use std::fmt::Write;
use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y};
#[derive(Builder, Debug)]
/// Grounding DINO model for open-vocabulary object detection.
pub struct GroundingDINO {
pub engine: Engine,
height: usize,

View File

@ -0,0 +1,9 @@
# MediaPipe: Selfie segmentation model
## Official Repository
The official guide is available at: [GUIDE](https://ai.google.dev/edge/mediapipe/solutions/vision/image_segmenter)
## Example
Refer to the [example](../../../examples/mediapipe-selfie-segmentation)

View File

@ -0,0 +1,22 @@
/// Model configuration for `MediaPipeSegmenter`
impl crate::Config {
pub fn mediapipe() -> Self {
Self::default()
.with_name("mediapipe")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 256.into())
.with_model_ixx(0, 3, 256.into())
.with_image_mean(&[0.5, 0.5, 0.5])
.with_image_std(&[0.5, 0.5, 0.5])
.with_normalize(true)
}
pub fn mediapipe_selfie_segmentater() -> Self {
Self::mediapipe().with_model_file("selfie-segmenter.onnx")
}
pub fn mediapipe_selfie_segmentater_landscape() -> Self {
Self::mediapipe().with_model_file("selfie-segmenter-landscape.onnx")
}
}

View File

@ -0,0 +1,91 @@
use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;
use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct MediaPipeSegmenter {
engine: Engine,
height: usize,
width: usize,
batch: usize,
ts: Ts,
spec: String,
processor: Processor,
}
impl MediaPipeSegmenter {
pub fn new(config: Config) -> Result<Self> {
let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&256.into()).opt(),
engine.try_width().unwrap_or(&256.into()).opt(),
engine.ts().clone(),
);
let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
Ok(Self {
engine,
height,
width,
batch,
ts,
spec,
processor,
})
}
fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> {
Ok(self.processor.process_images(xs)?.into())
}
fn inference(&mut self, xs: Xs) -> Result<Xs> {
self.engine.run(xs)
}
pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? });
Ok(ys)
}
pub fn summary(&mut self) {
self.ts.summary();
}
fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let (h1, w1) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
);
let luma = luma.mapv(|x| (x * 255.0) as u8);
let luma = Ops::resize_luma8_u8(
&luma.into_raw_vec_and_offset().0,
self.width as _,
self.height as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
}
Ok(ys)
}
}
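As a usage sketch, the new config constructors and `MediaPipeSegmenter` compose like the other models in the zoo. This is a minimal, hedged example assuming an `Image::try_read` loader (not shown in this diff); each returned `Y` carries a grayscale `Mask` resized back to the source resolution by the postprocess step above.

```rust
use anyhow::Result;
use usls::{Config, Image, MediaPipeSegmenter};

fn main() -> Result<()> {
    // Pick the 256x256 square variant; a landscape variant is also provided.
    let config = Config::mediapipe_selfie_segmentater();
    let mut model = MediaPipeSegmenter::new(config)?;

    // `Image::try_read` is assumed here; the actual loader may differ.
    let image = Image::try_read("assets/portrait.jpg")?;
    let ys = model.forward(&[image])?;
    println!("processed {} image(s)", ys.len());
    Ok(())
}
```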

View File

@ -0,0 +1,4 @@
mod config;
mod r#impl;
pub use r#impl::*;

View File

@ -15,6 +15,7 @@ mod fastvit;
mod florence2;
mod grounding_dino;
mod linknet;
mod mediapipe_segmenter;
mod mobileone;
mod modnet;
mod moondream2;
@ -43,6 +44,7 @@ pub use depth_pro::*;
pub use dinov2::*;
pub use florence2::*;
pub use grounding_dino::*;
pub use mediapipe_segmenter::*;
pub use modnet::*;
pub use moondream2::*;
pub use owl::*;

View File

@ -5,6 +5,7 @@ use rayon::prelude::*;
use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y};
/// OWL-ViT v2 model for open-vocabulary object detection.
#[derive(Debug, Builder)]
pub struct OWLv2 {
engine: Engine,

View File

@ -15,6 +15,13 @@ impl crate::Config {
.with_keypoint_confs(&[0.5])
}
pub fn rtmo_t() -> Self {
Self::rtmo()
.with_model_ixx(0, 2, 416.into())
.with_model_ixx(0, 3, 416.into())
.with_model_file("t.onnx")
}
pub fn rtmo_s() -> Self {
Self::rtmo().with_model_file("s.onnx")
}

View File

@ -1,6 +1,7 @@
use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;
use rayon::prelude::*;
use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Keypoint, Processor, Ts, Xs, Y};
@ -68,69 +69,67 @@ impl RTMO {
}
fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
// let (preds_bboxes, preds_kpts) = (&xs["dets"], &xs["keypoints"]);
let (preds_bboxes, preds_kpts) = (&xs[0], &xs[1]);
for (idx, (batch_bboxes, batch_kpts)) in preds_bboxes
let ys: Vec<Y> = xs[0]
.axis_iter(Axis(0))
.zip(preds_kpts.axis_iter(Axis(0)))
.into_par_iter()
.zip(xs[1].axis_iter(Axis(0)).into_par_iter())
.enumerate()
{
let (height_original, width_original) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
);
let ratio = self.processor.images_transform_info[idx].height_scale;
let mut y_bboxes = Vec::new();
let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
for (xyxyc, kpts) in batch_bboxes
.axis_iter(Axis(0))
.zip(batch_kpts.axis_iter(Axis(0)))
{
// bbox
let x1 = xyxyc[0] / ratio;
let y1 = xyxyc[1] / ratio;
let x2 = xyxyc[2] / ratio;
let y2 = xyxyc[3] / ratio;
let confidence = xyxyc[4];
if confidence < self.confs[0] {
continue;
}
y_bboxes.push(
Hbb::default()
.with_xyxy(
x1.max(0.0f32).min(width_original as _),
y1.max(0.0f32).min(height_original as _),
x2,
y2,
)
.with_confidence(confidence)
.with_id(0)
.with_name("Person"),
.map(|(idx, (batch_bboxes, batch_kpts))| {
let (height_original, width_original) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
);
let ratio = self.processor.images_transform_info[idx].height_scale;
// keypoints
let mut kpts_ = Vec::new();
for (i, kpt) in kpts.axis_iter(Axis(0)).enumerate() {
let x = kpt[0] / ratio;
let y = kpt[1] / ratio;
let c = kpt[2];
if c < self.kconfs[i] {
kpts_.push(Keypoint::default());
} else {
kpts_.push(Keypoint::default().with_id(i).with_confidence(c).with_xy(
x.max(0.0f32).min(width_original as _),
y.max(0.0f32).min(height_original as _),
));
let mut y_bboxes = Vec::new();
let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
for (xyxyc, kpts) in batch_bboxes
.axis_iter(Axis(0))
.zip(batch_kpts.axis_iter(Axis(0)))
{
// bbox
let x1 = xyxyc[0] / ratio;
let y1 = xyxyc[1] / ratio;
let x2 = xyxyc[2] / ratio;
let y2 = xyxyc[3] / ratio;
let confidence = xyxyc[4];
if confidence < self.confs[0] {
continue;
}
y_bboxes.push(
Hbb::default()
.with_xyxy(
x1.max(0.0f32).min(width_original as _),
y1.max(0.0f32).min(height_original as _),
x2,
y2,
)
.with_confidence(confidence)
.with_id(0)
.with_name("Person"),
);
// keypoints
let mut kpts_ = Vec::new();
for (i, kpt) in kpts.axis_iter(Axis(0)).enumerate() {
let x = kpt[0] / ratio;
let y = kpt[1] / ratio;
let c = kpt[2];
if c < self.kconfs[i] {
kpts_.push(Keypoint::default());
} else {
kpts_.push(Keypoint::default().with_id(i).with_confidence(c).with_xy(
x.max(0.0f32).min(width_original as _),
y.max(0.0f32).min(height_original as _),
));
}
}
y_kpts.push(kpts_);
}
y_kpts.push(kpts_);
}
ys.push(Y::default().with_hbbs(&y_bboxes).with_keypointss(&y_kpts));
}
Y::default().with_hbbs(&y_bboxes).with_keypointss(&y_kpts)
})
.collect();
Ok(ys)
}
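The refactor above replaces the sequential per-image loop with a data-parallel map over the batch axis. The following standalone sketch (plain `rayon` over vectors, not the actual model code) illustrates the same pattern: zip two batch-aligned outputs, enumerate to recover the batch index, map in parallel, and collect per-image results.

```rust
use rayon::prelude::*;

fn main() {
    // Stand-ins for per-image detection and keypoint outputs.
    let dets: Vec<Vec<f32>> = vec![vec![0.9], vec![0.4, 0.8]];
    let kpts: Vec<Vec<f32>> = vec![vec![0.7], vec![0.6, 0.5]];

    // Zip the two batch-aligned outputs, keep the batch index, and map in parallel.
    let ys: Vec<String> = dets
        .par_iter()
        .zip(kpts.par_iter())
        .enumerate()
        .map(|(idx, (d, k))| format!("image {idx}: {} dets, {} kpts", d.len(), k.len()))
        .collect();

    for y in &ys {
        println!("{y}");
    }
}
```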

View File

@ -2,6 +2,13 @@ use crate::{models::SamKind, Config};
/// Model configuration for `Segment Anything Model`
impl Config {
/// Creates a base SAM configuration with common settings.
///
/// Sets up default parameters for image preprocessing and model architecture:
/// - 1024x1024 input resolution
/// - Adaptive resize mode
/// - Image normalization parameters
/// - Contour finding enabled
pub fn sam() -> Self {
Self::default()
.with_name("sam")
@ -19,6 +26,8 @@ impl Config {
.with_find_contours(true)
}
/// Creates a configuration for SAM v1 base model.
/// Uses the original ViT-B architecture.
pub fn sam_v1_base() -> Self {
Self::sam()
.with_encoder_file("sam-vit-b-encoder.onnx")
@ -29,6 +38,8 @@ impl Config {
// Self::sam().with_decoder_file("sam-vit-b-decoder-singlemask.onnx")
// }
/// Creates a configuration for SAM 2.0 tiny model.
/// Uses a hierarchical architecture with tiny backbone.
pub fn sam2_tiny() -> Self {
Self::sam()
.with_encoder_file("sam2-hiera-tiny-encoder.onnx")
@ -36,6 +47,8 @@ impl Config {
.with_decoder_file("sam2-hiera-tiny-decoder.onnx")
}
/// Creates a configuration for SAM 2.0 small model.
/// Uses a hierarchical architecture with small backbone.
pub fn sam2_small() -> Self {
Self::sam()
.with_encoder_file("sam2-hiera-small-encoder.onnx")
@ -43,6 +56,8 @@ impl Config {
.with_sam_kind(SamKind::Sam2)
}
/// Creates a configuration for SAM 2.0 base-plus model.
/// Uses a hierarchical architecture with enhanced base backbone.
pub fn sam2_base_plus() -> Self {
Self::sam()
.with_encoder_file("sam2-hiera-base-plus-encoder.onnx")
@ -50,6 +65,8 @@ impl Config {
.with_sam_kind(SamKind::Sam2)
}
/// Creates a configuration for MobileSAM tiny model.
/// Lightweight model optimized for mobile devices.
pub fn mobile_sam_tiny() -> Self {
Self::sam()
.with_encoder_file("mobile-sam-vit-t-encoder.onnx")
@ -57,6 +74,8 @@ impl Config {
.with_decoder_file("mobile-sam-vit-t-decoder.onnx")
}
/// Creates a configuration for SAM-HQ tiny model.
/// High-quality variant focused on better mask quality.
pub fn sam_hq_tiny() -> Self {
Self::sam()
.with_encoder_file("sam-hq-vit-t-encoder.onnx")
@ -64,6 +83,8 @@ impl Config {
.with_decoder_file("sam-hq-vit-t-decoder.onnx")
}
/// Creates a configuration for EdgeSAM 3x model.
/// Lightweight variant optimized for edge-device inference.
pub fn edge_sam_3x() -> Self {
Self::sam()
.with_encoder_file("edge-sam-3x-encoder.onnx")

View File

@ -8,12 +8,18 @@ use crate::{
elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y,
};
/// SAM model variants for different use cases.
#[derive(Debug, Clone)]
pub enum SamKind {
/// Original SAM model
Sam,
Sam2, // 2.0
/// SAM 2.0 with hierarchical architecture
Sam2,
/// Mobile optimized SAM
MobileSam,
/// High quality SAM with better segmentation
SamHq,
/// EdgeSAM, a lightweight variant optimized for edge devices
EdgeSam,
}
@ -32,6 +38,10 @@ impl FromStr for SamKind {
}
}
/// Segment Anything Model (SAM) for image segmentation.
///
/// A foundation model for generating high-quality object masks from input prompts such as points or boxes.
/// Supports multiple variants including the original SAM, SAM2, MobileSAM, SAM-HQ and EdgeSAM.
#[derive(Builder, Debug)]
pub struct SAM {
encoder: Engine,
@ -49,6 +59,10 @@ pub struct SAM {
}
impl SAM {
/// Creates a new SAM model instance from the provided configuration.
///
/// Initializes the model based on the specified SAM variant (original SAM, SAM2, MobileSAM etc.)
/// and configures its encoder-decoder architecture.
pub fn new(config: Config) -> Result<Self> {
let encoder = Engine::try_from_config(&config.encoder)?;
let decoder = Engine::try_from_config(&config.decoder)?;
@ -94,6 +108,11 @@ impl SAM {
})
}
/// Runs the complete segmentation pipeline on a batch of images.
///
/// The pipeline consists of:
/// 1. Encoding the images into embeddings
/// 2. Decoding the embeddings with input prompts to generate segmentation masks
pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = elapsed!("encode", self.ts, { self.encode(xs)? });
let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? });
@ -101,11 +120,16 @@ impl SAM {
Ok(ys)
}
/// Encodes input images into image embeddings.
pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> {
let xs_ = self.processor.process_images(xs)?;
self.encoder.run(Xs::from(xs_))
}
/// Generates segmentation masks from image embeddings and input prompts.
///
/// Takes the image embeddings from the encoder and input prompts (points or boxes)
/// to generate binary segmentation masks for the prompted objects.
pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind {
SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])),
@ -285,10 +309,12 @@ impl SAM {
Ok(ys)
}
/// Returns the width of the low-resolution feature maps.
pub fn width_low_res(&self) -> usize {
self.width / 4
}
/// Returns the height of the low-resolution feature maps.
pub fn height_low_res(&self) -> usize {
self.height / 4
}
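Putting the documented encode/decode pipeline together, a hedged usage sketch might look like the following. The point-prompt construction uses the public `coords`/`labels` fields of `SamPrompt` shown later in this diff; `Image::try_read` and the label semantics (`1.0` meaning a foreground point) are assumptions.

```rust
use anyhow::Result;
use usls::{Config, Image, SamPrompt, SAM};

fn main() -> Result<()> {
    // Choose a SAM variant; `sam2_tiny()` is one of the provided presets.
    let config = Config::sam2_tiny();
    let mut model = SAM::new(config)?;

    // `Image::try_read` is assumed; the actual loader may differ.
    let image = Image::try_read("assets/dog.jpg")?;

    // One point prompt per image; label semantics are an assumption.
    let prompt = SamPrompt {
        coords: vec![vec![[320.0, 240.0]]],
        labels: vec![vec![1.0]],
    };

    // forward() = encode the images, then decode masks from the prompts.
    let ys = model.forward(&[image], &[prompt])?;
    println!("got {} result(s)", ys.len());
    Ok(())
}
```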

View File

@ -3,9 +3,12 @@ mod r#impl;
pub use r#impl::*;
/// SAM prompt containing coordinates and labels for segmentation.
#[derive(Debug, Default, Clone)]
pub struct SamPrompt {
/// Point coordinates for prompting.
pub coords: Vec<Vec<[f32; 2]>>,
/// Labels corresponding to the coordinates.
pub labels: Vec<Vec<f32>>,
}

View File

@ -2,24 +2,36 @@ use crate::Config;
/// Model configuration for `SAM2.1`
impl Config {
/// Creates a configuration for SAM 2.1 tiny model.
///
/// The smallest variant of the hierarchical architecture, optimized for speed.
pub fn sam2_1_tiny() -> Self {
Self::sam()
.with_encoder_file("sam2.1-hiera-tiny-encoder.onnx")
.with_decoder_file("sam2.1-hiera-tiny-decoder.onnx")
}
/// Creates a configuration for SAM 2.1 small model.
///
/// A balanced variant offering good performance and efficiency.
pub fn sam2_1_small() -> Self {
Self::sam()
.with_encoder_file("sam2.1-hiera-small-encoder.onnx")
.with_decoder_file("sam2.1-hiera-small-decoder.onnx")
}
/// Creates a configuration for SAM 2.1 base-plus model.
///
/// An enhanced base model with improved segmentation quality.
pub fn sam2_1_base_plus() -> Self {
Self::sam()
.with_encoder_file("sam2.1-hiera-base-plus-encoder.onnx")
.with_decoder_file("sam2.1-hiera-base-plus-decoder.onnx")
}
/// Creates a configuration for SAM 2.1 large model.
///
/// The most powerful variant with highest segmentation quality.
pub fn sam2_1_large() -> Self {
Self::sam()
.with_encoder_file("sam2.1-hiera-large-encoder.onnx")

View File

@ -6,6 +6,10 @@ use crate::{
elapsed, Config, DynConf, Engine, Image, Mask, Ops, Processor, SamPrompt, Ts, Xs, X, Y,
};
/// SAM2 (Segment Anything Model 2.1) for advanced image segmentation.
///
/// A hierarchical vision foundation model with improved efficiency and quality.
/// Features enhanced backbone architecture and optimized prompt handling.
#[derive(Builder, Debug)]
pub struct SAM2 {
encoder: Engine,
@ -20,6 +24,12 @@ pub struct SAM2 {
}
impl SAM2 {
/// Creates a new SAM2 model instance from the provided configuration.
///
/// Initializes the model with:
/// - Encoder-decoder architecture
/// - Image preprocessing settings
/// - Confidence thresholds
pub fn new(config: Config) -> Result<Self> {
let encoder = Engine::try_from_config(&config.encoder)?;
let decoder = Engine::try_from_config(&config.decoder)?;
@ -48,6 +58,11 @@ impl SAM2 {
})
}
/// Runs the complete segmentation pipeline on a batch of images.
///
/// The pipeline consists of:
/// 1. Image encoding into hierarchical features
/// 2. Prompt-guided mask generation
pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = elapsed!("encode", self.ts, { self.encode(xs)? });
let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? });
@ -55,11 +70,16 @@ impl SAM2 {
Ok(ys)
}
/// Encodes input images into hierarchical feature representations.
pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> {
let xs_ = self.processor.process_images(xs)?;
self.encoder.run(Xs::from(xs_))
}
/// Generates segmentation masks from encoded features and prompts.
///
/// Takes hierarchical image features and user prompts (points/boxes)
/// to generate accurate object masks.
pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]);
@ -153,10 +173,12 @@ impl SAM2 {
Ok(ys)
}
/// Returns the width of the low-resolution feature maps.
pub fn width_low_res(&self) -> usize {
self.width / 4
}
/// Returns the height of the low-resolution feature maps.
pub fn height_low_res(&self) -> usize {
self.height / 4
}
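For intuition on the low-resolution feature helpers documented above: with the default 1024x1024 input from the base SAM config, the low-res feature maps are a quarter of the input on each side. A tiny sketch of the relationship (not the model code itself):

```rust
fn main() {
    // Default SAM/SAM2 input resolution from the base config.
    let (height, width) = (1024usize, 1024usize);
    // Low-resolution feature maps are input / 4 on each side.
    let (h_low, w_low) = (height / 4, width / 4);
    assert_eq!((h_low, w_low), (256, 256));
    println!("low-res feature maps: {h_low}x{w_low}");
}
```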

View File

@ -5,6 +5,7 @@ use rayon::prelude::*;
use crate::{elapsed, Config, DynConf, Engine, Image, Processor, Text, Ts, Xs, Y};
/// SVTR model for scene text recognition.
#[derive(Builder, Debug)]
pub struct SVTR {
engine: Engine,

View File

@ -1,7 +1,15 @@
use crate::Scale;
/// Model configuration for `TrOCR`
/// Model configuration for `TrOCR`.
impl crate::Config {
/// Creates a base configuration for TrOCR models with default settings.
///
/// This includes:
/// - Batch size of 1
/// - Image input dimensions of 384x384 with 3 channels
/// - Image normalization with mean and std of [0.5, 0.5, 0.5]
/// - Lanczos3 resize filter
/// - Default tokenizer and model configuration files
pub fn trocr() -> Self {
Self::default()
.with_name("trocr")
@ -18,6 +26,9 @@ impl crate::Config {
.with_tokenizer_config_file("trocr/tokenizer_config.json")
}
/// Creates a configuration for the small TrOCR model variant optimized for printed text.
///
/// Uses the small scale model files and tokenizer configuration.
pub fn trocr_small_printed() -> Self {
Self::trocr()
.with_scale(Scale::S)
@ -27,6 +38,9 @@ impl crate::Config {
.with_tokenizer_file("trocr/tokenizer-small.json")
}
/// Creates a configuration for the base TrOCR model variant optimized for handwritten text.
///
/// Uses the base scale model files and tokenizer configuration.
pub fn trocr_base_handwritten() -> Self {
Self::trocr()
.with_scale(Scale::B)
@ -36,6 +50,9 @@ impl crate::Config {
.with_tokenizer_file("trocr/tokenizer-base.json")
}
/// Creates a configuration for the small TrOCR model variant optimized for handwritten text.
///
/// Modifies the small printed configuration to use handwritten-specific model files.
pub fn trocr_small_handwritten() -> Self {
Self::trocr_small_printed()
.with_visual_file("s-encoder-handwritten.onnx")
@ -43,6 +60,9 @@ impl crate::Config {
.with_textual_decoder_merged_file("s-decoder-merged-handwritten.onnx")
}
/// Creates a configuration for the base TrOCR model variant optimized for printed text.
///
/// Modifies the base handwritten configuration to use printed-specific model files.
pub fn trocr_base_printed() -> Self {
Self::trocr_base_handwritten()
.with_visual_file("b-encoder-printed.onnx")

View File

@ -6,9 +6,12 @@ use std::str::FromStr;
use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y};
/// TrOCR model variants for different text types.
#[derive(Debug, Copy, Clone)]
pub enum TrOCRKind {
/// Model variant optimized for machine-printed text recognition
Printed,
/// Model variant optimized for handwritten text recognition
HandWritten,
}
@ -24,23 +27,59 @@ impl FromStr for TrOCRKind {
}
}
/// TrOCR model for optical character recognition.
///
/// TrOCR is a transformer-based OCR model that combines an image encoder with
/// a text decoder for end-to-end text recognition. It supports both printed and
/// handwritten text recognition through different model variants.
///
/// The model consists of:
/// - An encoder that processes the input image
/// - A decoder that generates text tokens
/// - A merged decoder variant for optimized inference
/// - A processor for image preprocessing and text postprocessing
#[derive(Debug, Builder)]
pub struct TrOCR {
/// Image encoder engine
encoder: Engine,
/// Text decoder engine for token generation
decoder: Engine,
/// Optimized merged decoder engine
decoder_merged: Engine,
/// Maximum length of generated text sequence
max_length: u32,
/// Token ID representing end of sequence
eos_token_id: u32,
/// Token ID used to start text generation
decoder_start_token_id: u32,
/// Timestamp tracking for performance analysis
ts: Ts,
/// Number of key-value pairs in decoder attention
n_kvs: usize,
/// Image and text processor for pre/post processing
processor: Processor,
/// Batch size for inference
batch: usize,
/// Input image height
height: usize,
/// Input image width
width: usize,
}
impl TrOCR {
/// Creates a new TrOCR model instance from the given configuration.
///
/// # Arguments
/// * `config` - The model configuration containing paths to model files and parameters
///
/// # Returns
/// * `Result<Self>` - A new TrOCR instance if initialization succeeds
///
/// # Errors
/// Returns an error if:
/// - Required model files cannot be loaded
/// - Model configuration is invalid
/// - Tokenizer initialization fails
pub fn new(config: Config) -> Result<Self> {
let encoder = Engine::try_from_config(&config.visual)?;
let decoder = Engine::try_from_config(&config.textual_decoder)?;
@ -86,6 +125,16 @@ impl TrOCR {
})
}
/// Encodes the given images into feature vectors using the TrOCR encoder.
///
/// This method processes the images through the image processor and then
/// encodes them using the encoder engine.
///
/// # Arguments
/// * `xs` - A slice of `Image` instances to be encoded.
///
/// # Errors
/// Returns an error if image processing or encoding fails.
pub fn encode(&mut self, xs: &[Image]) -> Result<X> {
let ys = self.processor.process_images(xs)?;
self.batch = xs.len(); // update
@ -93,6 +142,16 @@ impl TrOCR {
Ok(ys[0].to_owned())
}
/// Performs the forward pass of the TrOCR model, from encoding images to decoding text.
///
/// This method encodes the input images, generates token IDs using the decoder,
/// and finally decodes the token IDs into text.
///
/// # Arguments
/// * `xs` - A slice of `Image` instances to be processed.
///
/// # Errors
/// Returns an error if any step in the forward pass fails.
pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
let encoder_hidden_states = elapsed!("encode", self.ts, { self.encode(xs)? });
let generated = elapsed!("generate", self.ts, {
@ -182,6 +241,13 @@ impl TrOCR {
Ok(token_ids)
}
/// Decodes the given token IDs into text using the TrOCR processor.
///
/// # Arguments
/// * `token_ids` - A vector of vectors of token IDs to be decoded.
///
/// # Errors
/// Returns an error if decoding fails.
pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Vec<Y>> {
// decode
let texts = self.processor.decode_tokens_batch(&token_ids, false)?;
@ -195,6 +261,7 @@ impl TrOCR {
Ok(texts)
}
/// Displays a summary of the TrOCR model's configuration and state.
pub fn summary(&self) {
self.ts.summary();
}
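A hedged end-to-end usage sketch for the documented encode → generate → decode pipeline, using the `Config::trocr_small_printed()` preset and the `forward` signature above; image loading via `Image::try_read` is an assumption.

```rust
use anyhow::Result;
use usls::{Config, Image, TrOCR};

fn main() -> Result<()> {
    // Small variant tuned for machine-printed text.
    let config = Config::trocr_small_printed();
    let mut model = TrOCR::new(config)?;

    // `Image::try_read` is assumed; the actual loader may differ.
    let image = Image::try_read("assets/receipt.png")?;

    // forward() encodes the image, generates token IDs, and decodes them to text.
    let ys = model.forward(&[image])?;
    println!("recognized {} image(s)", ys.len());

    // Timing summary for the encode/generate/decode stages.
    model.summary();
    Ok(())
}
```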

View File

@ -5,6 +5,9 @@ use crate::{
};
impl Config {
/// Creates a base YOLO configuration with common settings.
///
/// Sets up default input dimensions (640x640) and image processing parameters.
pub fn yolo() -> Self {
Self::default()
.with_name("yolo")
@ -16,6 +19,12 @@ impl Config {
.with_resize_filter("CatmullRom")
}
/// Creates a configuration for YOLO image classification.
///
/// Configures the model for ImageNet classification with:
/// - 224x224 input size
/// - Exact resize mode with bilinear interpolation
/// - ImageNet 1000 class names
pub fn yolo_classify() -> Self {
Self::yolo()
.with_task(Task::ImageClassification)
@ -26,24 +35,38 @@ impl Config {
.with_class_names(&NAMES_IMAGENET_1K)
}
/// Creates a configuration for YOLO object detection.
///
/// Configures the model for COCO dataset object detection with 80 classes.
pub fn yolo_detect() -> Self {
Self::yolo()
.with_task(Task::ObjectDetection)
.with_class_names(&NAMES_COCO_80)
}
/// Creates a configuration for YOLO pose estimation.
///
/// Configures the model for human keypoint detection with 17 COCO keypoints.
pub fn yolo_pose() -> Self {
Self::yolo()
.with_task(Task::KeypointsDetection)
.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17)
}
/// Creates a configuration for YOLO instance segmentation.
///
/// Configures the model for COCO dataset instance segmentation with 80 classes.
pub fn yolo_segment() -> Self {
Self::yolo()
.with_task(Task::InstanceSegmentation)
.with_class_names(&NAMES_COCO_80)
}
/// Creates a configuration for YOLO oriented object detection.
///
/// Configures the model for detecting rotated objects with:
/// - 1024x1024 input size
/// - DOTA v1 dataset classes
pub fn yolo_obb() -> Self {
Self::yolo()
.with_model_ixx(0, 2, 1024.into())
@ -52,6 +75,11 @@ impl Config {
.with_class_names(&NAMES_DOTA_V1_15)
}
/// Creates a configuration for document layout analysis using YOLOv10.
///
/// Configures the model for detecting document structure elements with:
/// - Variable input size up to 1024x1024
/// - 10 document layout classes
pub fn doclayout_yolo_docstructbench() -> Self {
Self::yolo_detect()
.with_version(10.into())
@ -62,12 +90,16 @@ impl Config {
.with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1
}
// YOLOE models
/// Creates a base YOLOE configuration with 4585 classes.
///
/// Configures the model for instance segmentation with a large class vocabulary.
pub fn yoloe() -> Self {
Self::yolo()
.with_task(Task::InstanceSegmentation)
.with_class_names(&NAMES_YOLOE_4585)
}
/// Creates a configuration for YOLOE-v8s segmentation model.
/// Uses the small variant of YOLOv8 architecture.
pub fn yoloe_v8s_seg_pf() -> Self {
Self::yoloe()
.with_version(8.into())
@ -75,6 +107,8 @@ impl Config {
.with_model_file("yoloe-v8s-seg-pf.onnx")
}
/// Creates a configuration for YOLOE-v8m segmentation model.
/// Uses the medium variant of YOLOv8 architecture.
pub fn yoloe_v8m_seg_pf() -> Self {
Self::yoloe()
.with_version(8.into())
@ -82,6 +116,8 @@ impl Config {
.with_model_file("yoloe-v8m-seg-pf.onnx")
}
/// Creates a configuration for YOLOE-v8l segmentation model.
/// Uses the large variant of YOLOv8 architecture.
pub fn yoloe_v8l_seg_pf() -> Self {
Self::yoloe()
.with_version(8.into())
@ -89,6 +125,8 @@ impl Config {
.with_model_file("yoloe-v8l-seg-pf.onnx")
}
/// Creates a configuration for YOLOE-11s segmentation model.
/// Uses the small variant of YOLOv11 architecture.
pub fn yoloe_11s_seg_pf() -> Self {
Self::yoloe()
.with_version(11.into())

View File
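The config constructors above all share the same builder mechanism: `with_model_ixx(input_index, axis, value)` pins a dimension of a model input (NCHW order for the image input), which is how `yolo_obb()` gets its 1024x1024 resolution and the tiny RTMO preset its 416x416. A hedged sketch of overriding the detection preset's input size under that reading; the model file name is hypothetical.

```rust
use usls::Config;

fn main() {
    // Start from the COCO-80 detection preset and pin a 512x512 input
    // (axis 2 = height, axis 3 = width for an NCHW image input).
    // The interpretation of `with_model_ixx` here is inferred from this diff.
    let _config = Config::yolo_detect()
        .with_model_ixx(0, 2, 512.into())
        .with_model_ixx(0, 3, 512.into())
        .with_model_file("yolo-custom.onnx"); // hypothetical model file
}
```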

@ -12,6 +12,14 @@ use crate::{
Ts, Version, Xs, Y,
};
/// YOLO (You Only Look Once) object detection model.
///
/// A versatile deep learning model that can perform multiple computer vision tasks including:
/// - Object Detection
/// - Instance Segmentation
/// - Keypoint Detection
/// - Image Classification
/// - Oriented Object Detection
#[derive(Debug, Builder)]
pub struct YOLO {
engine: Engine,
@ -45,6 +53,12 @@ impl TryFrom<Config> for YOLO {
}
impl YOLO {
/// Creates a new YOLO model instance from the provided configuration.
///
/// This function initializes the model by:
/// - Loading the ONNX model file
/// - Setting up input dimensions and processing parameters
/// - Configuring task-specific settings and output formats
pub fn new(config: Config) -> Result<Self> {
let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts, spec) = (
@ -284,16 +298,24 @@ impl YOLO {
})
}
/// Pre-processes input images before model inference.
fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> {
let x = self.processor.process_images(xs)?;
Ok(x.into())
}
/// Performs model inference on the pre-processed input.
fn inference(&mut self, xs: Xs) -> Result<Xs> {
self.engine.run(xs)
}
/// Performs the complete inference pipeline on a batch of images.
///
/// The pipeline consists of:
/// 1. Pre-processing the input images
/// 2. Running model inference
/// 3. Post-processing the outputs to generate final predictions
pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
@ -302,6 +324,7 @@ impl YOLO {
Ok(ys)
}
/// Post-processes model outputs to generate final predictions.
fn postprocess(&self, xs: Xs) -> Result<Vec<Y>> {
// let protos = if xs.len() == 2 { Some(&xs[1]) } else { None };
let ys: Vec<Y> = xs[0]
@ -605,6 +628,7 @@ impl YOLO {
// Ok(ys.into())
}
/// Extracts class names from the ONNX model metadata if available.
fn fetch_names_from_onnx(engine: &Engine) -> Option<Vec<String>> {
// fetch class names from onnx metadata
// String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
@ -616,6 +640,7 @@ impl YOLO {
.into()
}
/// Extracts the number of keypoints from the ONNX model metadata if available.
fn fetch_nk_from_onnx(engine: &Engine) -> Option<usize> {
Regex::new(r"(\d+), \d+")
.ok()?
@ -624,6 +649,7 @@ impl YOLO {
.and_then(|m| m.as_str().parse::<usize>().ok())
}
/// Prints a summary of the model configuration and parameters.
pub fn summary(&mut self) {
self.ts.summary();
}
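The class-name metadata format documented above (`{0: 'person', 1: 'bicycle', ...}`) can be parsed with a simple regex. The following standalone sketch only illustrates that string format; it is not the crate's actual implementation.

```rust
use regex::Regex;

fn main() {
    // Example of the ONNX metadata string format described above.
    let names = r#"{0: 'person', 1: 'bicycle', 27: "yellow_lady's_slipper"}"#;

    // Capture either single- or double-quoted class names.
    let re = Regex::new(r#"(\d+):\s*(?:'([^']*)'|"([^"]*)")"#).unwrap();
    for cap in re.captures_iter(names) {
        let id = &cap[1];
        let name = cap.get(2).or_else(|| cap.get(3)).map(|m| m.as_str()).unwrap();
        println!("{id}: {name}");
    }
}
```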

View File

@ -2,6 +2,7 @@ use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr};
use crate::Task;
/// Bounding box coordinate format types.
#[derive(Debug, Clone, PartialEq)]
pub enum BoxType {
Cxcywh,
@ -11,6 +12,7 @@ pub enum BoxType {
XyCxcy,
}
/// Classification output format types.
#[derive(Debug, Clone, PartialEq)]
pub enum ClssType {
Clss,
@ -20,18 +22,21 @@ pub enum ClssType {
ClssConf,
}
/// Keypoint output format types.
#[derive(Debug, Clone, PartialEq)]
pub enum KptsType {
Xys,
Xycs,
}
/// Anchor position in the prediction pipeline.
#[derive(Debug, Clone, PartialEq)]
pub enum AnchorsPosition {
Before,
After,
}
/// YOLO prediction format configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct YOLOPredsFormat {
pub clss: ClssType,

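The `BoxType` variants above encode different box layouts (e.g. `Cxcywh` is center x/y plus width/height, while corner-based layouts store x1/y1/x2/y2 directly). A minimal conversion sketch for the center-based layout, independent of the crate's own decoding code:

```rust
fn main() {
    // Center-based box layout: (cx, cy, w, h), as in the `Cxcywh` variant above.
    let (cx, cy, w, h) = (320.0f32, 240.0, 100.0, 50.0);

    // Convert to corner coordinates (x1, y1, x2, y2).
    let (x1, y1) = (cx - w / 2.0, cy - h / 2.0);
    let (x2, y2) = (cx + w / 2.0, cy + h / 2.0);

    assert_eq!((x1, y1, x2, y2), (270.0, 215.0, 370.0, 265.0));
    println!("xyxy = ({x1}, {y1}, {x2}, {y2})");
}
```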
View File

@ -1,7 +1,8 @@
use aksr::Builder;
use crate::{impl_meta_methods, InstanceMeta, Keypoint, Style};
use crate::{InstanceMeta, Keypoint, Style};
/// Horizontal bounding box with position, size, and metadata.
#[derive(Builder, Clone, Default)]
pub struct Hbb {
x: f32,

View File

@ -1,4 +1,5 @@
#[derive(aksr::Builder, Clone, PartialEq)]
/// Metadata for detection instances including ID, confidence, and name.
pub struct InstanceMeta {
uid: usize,
id: Option<usize>,
@ -78,7 +79,6 @@ impl InstanceMeta {
}
}
#[macro_export]
macro_rules! impl_meta_methods {
() => {
pub fn with_uid(mut self, uid: usize) -> Self {

View File

@ -1,7 +1,7 @@
use aksr::Builder;
use std::ops::{Add, Div, Mul, Sub};
use crate::{impl_meta_methods, InstanceMeta, Style};
use crate::{InstanceMeta, Style};
/// Represents a keypoint in a 2D space with optional metadata.
#[derive(Builder, Default, Clone)]

View File

@ -3,13 +3,16 @@ use anyhow::Result;
use image::GrayImage;
use rayon::prelude::*;
use crate::{impl_meta_methods, InstanceMeta, Polygon, Style};
use crate::{InstanceMeta, Polygon, Style};
/// Mask: Gray Image.
#[derive(Builder, Default, Clone)]
pub struct Mask {
/// The grayscale image representing the mask.
mask: GrayImage,
/// Metadata associated with the mask instance.
meta: InstanceMeta,
/// Optional styling information for visualization.
style: Option<Style>,
}

src/results/mod.rs (new file, 22 lines)
View File

@ -0,0 +1,22 @@
#[macro_use]
mod instance_meta;
mod hbb;
mod keypoint;
mod mask;
mod obb;
mod polygon;
mod prob;
mod skeleton;
mod text;
mod y;
pub use hbb::*;
pub use instance_meta::*;
pub use keypoint::*;
pub use mask::*;
pub use obb::*;
pub use polygon::*;
pub use prob::*;
pub use skeleton::*;
pub use text::*;
pub use y::*;

View File

@ -1,7 +1,8 @@
use aksr::Builder;
use crate::{impl_meta_methods, Hbb, InstanceMeta, Keypoint, Polygon, Style};
use crate::{Hbb, InstanceMeta, Keypoint, Polygon, Style};
/// Oriented bounding box with four vertices and metadata.
#[derive(Builder, Default, Clone, PartialEq)]
pub struct Obb {
vertices: [[f32; 2]; 4], // ordered

View File

@ -4,7 +4,7 @@ use geo::{
Point, Simplify,
};
use crate::{impl_meta_methods, Hbb, InstanceMeta, Mask, Obb, Style};
use crate::{Hbb, InstanceMeta, Mask, Obb, Style};
/// Polygon.
#[derive(Builder, Clone)]

View File

@ -1,7 +1,8 @@
use aksr::Builder;
use crate::{impl_meta_methods, InstanceMeta, Style};
use crate::{InstanceMeta, Style};
/// Probability result with classification metadata.
#[derive(Builder, Clone, PartialEq, Default, Debug)]
pub struct Prob {
meta: InstanceMeta,

View File

@ -1,5 +1,6 @@
use crate::Color;
/// Connection between two keypoints with optional color.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Connection {
pub indices: (usize, usize),
@ -24,6 +25,7 @@ impl From<(usize, usize, Color)> for Connection {
}
}
/// Skeleton structure containing keypoint connections.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Skeleton {
pub connections: Vec<Connection>,
@ -85,6 +87,12 @@ impl<const N: usize> From<([(usize, usize); N], [Color; N])> for Skeleton {
}
}
/// Defines the keypoint connections for the COCO person skeleton with 19 connections.
/// Each tuple (a, b) represents a connection between keypoint indices a and b.
/// The connections define the following body parts:
/// - Upper body: shoulders, elbows, wrists
/// - Torso: shoulders to hips
/// - Lower body: hips, knees, ankles
pub const SKELETON_COCO_19: [(usize, usize); 19] = [
(15, 13),
(13, 11),
@ -107,6 +115,12 @@ pub const SKELETON_COCO_19: [(usize, usize); 19] = [
(4, 6),
];
/// Defines colors for visualizing each connection in the COCO person skeleton.
/// Colors are grouped by body parts:
/// - Blue (0x3399ff): Upper limbs
/// - Pink (0xff33ff): Torso
/// - Orange (0xff8000): Lower limbs
/// - Green (0x00ff00): Head and neck
pub const SKELETON_COLOR_COCO_19: [Color; 19] = [
Color(0x3399ffff),
Color(0x3399ffff),

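Given the `From<([(usize, usize); N], [Color; N])>` impl shown above, the paired constants can be combined into a `Skeleton` in one step. A minimal sketch, assuming these items are re-exported at the crate root as the module listings in this diff suggest:

```rust
use usls::{Skeleton, SKELETON_COCO_19, SKELETON_COLOR_COCO_19};

fn main() {
    // Pair each COCO connection with its visualization color.
    let skeleton: Skeleton = (SKELETON_COCO_19, SKELETON_COLOR_COCO_19).into();
    println!("{} connections", skeleton.connections.len());
}
```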
View File

@ -1,7 +1,8 @@
use aksr::Builder;
use crate::{impl_meta_methods, InstanceMeta, Style};
use crate::{InstanceMeta, Style};
/// Text recognition result with content and metadata.
#[derive(Builder, Clone, Default)]
pub struct Text {
text: String,
@ -11,13 +12,6 @@ pub struct Text {
impl std::fmt::Debug for Text {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// f.debug_struct("Text")
// .field("text", &self.text)
// .field("id", &self.meta.id())
// .field("name", &self.meta.name())
// .field("confidence", &self.meta.confidence())
// .finish()
let mut f = f.debug_struct("Text");
f.field("text", &self.text);
if let Some(id) = &self.meta.id() {

View File

@ -1,978 +0,0 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external operator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION_2023_5_5 = 0x0000000000000009;
// IR VERSION 10 published on March 25, 2024
// Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions.
IR_VERSION_2024_3_25 = 0x000000000000000A;
// IR VERSION 11 published on May 12, 2025
// Added FLOAT4E2M1, multi-device protobuf classes.
IR_VERSION = 0x000000000000000B;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
reserved 12, 16 to 19;
reserved "v";
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
float f = 2; // float
int64 i = 3; // int
bytes s = 4; // UTF-8 string
TensorProto t = 5; // tensor value
GraphProto g = 6; // graph
SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
string name = 1; // namespace Value
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 4;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Overload identifier, used only to map this to a model-local function.
string overload = 8;
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 9;
// Configuration of multi-device annotations.
repeated NodeDeviceConfigurationProto device_configurations = 10;
}
// IntIntListEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message IntIntListEntryProto {
int64 key = 1;
repeated int64 value = 2;
};
// Multi-device configuration proto for NodeProto.
message NodeDeviceConfigurationProto {
// This field MUST be present for this version of the IR.
// ID of the configuration. MUST match the name of a DeviceConfigurationProto.
string configuration_id = 1;
// Sharding spec for the node.
repeated ShardingSpecProto sharding_spec = 2;
// Pipeline stage of this node.
int32 pipeline_stage = 3;
}
// ShardingSpecProto: This describes the sharding spec for a specific
// input or output tensor of a node.
message ShardingSpecProto {
// This field MUST be present for this version of the IR.
// Identifies the input or output of the node that is being sharded.
// Required to match a name specified in the node's input or output list of ValueInfoProtos.
// It is called `logical tensor` in subsequent descriptions.
string tensor_name = 1;
// The following is the list of devices across which the logical
// tensor is sharded or replicated.
repeated int64 device = 2;
// Each element v in above field devices may represent either a
// device or a set of devices (when we want the same shard/tensor
// to be replicated across a subset of devices), as indicated by
// the following optional map. If the map contains an entry for v,
// then v represents a device group, and the map indicates the set
// of devices in that group.
repeated IntIntListEntryProto index_to_device_group_map = 3;
// The following is the sharded-shape of the tensor, consisting of
// the sharding-spec for each axis of the tensor.
repeated ShardedDimProto sharded_dim = 4;
}
// ShardedDimProto: This describes the sharding spec for a single
// axis of a sharded tensor.
message ShardedDimProto {
// This field MUST be present for this version of the IR.
// The axis this sharding corresponds to. Must be in the range of
// [-r, r - 1], where r is the rank of the tensor. Negative axis values means
// counting from the back.
int64 axis = 1;
// Describes how the tensor on the provided axis is sharded.
// The common-case is described by a single instance of SimpleShardedDimProto.
// Multiple instances can be used to handle cases where a sharded
// tensor is reshaped, fusing multiple axes into one.
repeated SimpleShardedDimProto simple_sharding = 2;
}
// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
// N is allowed to be symbolic where M is required to be a constant.
message SimpleShardedDimProto {
// Dimension value to be sharded.
oneof dim {
int64 dim_value = 1;
string dim_param = 2;
}
// This field MUST be present for this version of the IR.
// Number of shards to split dim into.
int64 num_shards = 3;
}
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Thus, no initializer would be changed by default.
GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count.
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Evaluating the default training step never
// update any initializers.
GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers appears as keys in "update_binding" are considered
// mutable variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
string domain = 4;
// The version of the graph encoded. See Version enum below.
int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// The (domain, name, overload) tuple must be unique across the function protos in this list.
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard operator sets are given higher priority or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
// Describes different target configurations for a multi-device use case.
// A model MAY describe multiple multi-device configurations for execution.
repeated DeviceConfigurationProto configuration = 26;
};
// DeviceConfigurationProto describes a multi-device configuration for a model.
message DeviceConfigurationProto {
// This field MUST be present for this version of the IR.
// Name of the configuration.
string name = 1;
// This field MUST be present for this version of the IR.
// Number of devices inside this configuration.
int32 num_devices = 2;
// Optional names of the devices. MUST be length of num_devices if provided.
repeated string device = 3;
}
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
string key = 1;
string value = 2;
};
message TensorAnnotation {
string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators that support FP8 are Cast, CastLike, QuantizeLinear, and DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
// 4-bit integer data types
UINT4 = 21; // Unsigned integer in range [0, 15]
INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
// 4-bit floating point data types
FLOAT4E2M1 = 23;
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
int64 begin = 1;
int64 end = 2;
}
Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0, 3.0, 4.0]).
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
// - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
// representation before being written to the buffer.
// - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
// The first element is stored in the 4 least significant bits (LSB),
// and the second element is stored in the 4 most significant bits (MSB).
//
// Consequently:
// - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
// - For 4-bit data types, each `int32_data` stores two elements.
//
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
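//
// Illustrative example (not normative): the INT4 pair (3, -2) packs into the single
// byte 0xE3, since 3 -> 0x3 fills the 4 LSB and -2 -> 0xE (two's complement) fills the 4 MSB.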
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used, in order to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store the tensor value, elements MUST
// be stored in fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
// uint4 and int4 values must be packed two per byte (4 bits x 2): the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
//                          Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0, 3.0, 4.0]).
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values is encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// The values tensor must have a non-empty name, which serves as the name of the SparseTensorProto
// when it is used in the sparse_initializer list.
TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
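// Illustrative example (not normative): for dims = [3, 4] with non-default values at
// positions [0,1] and [2,3], format (a) stores indices [[0,1], [2,3]] (shape [2,2]),
// while format (b) stores the linearized indices [1, 11] (shape [2]).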
TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
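// Illustrative example (not normative): the shape (N, 3, 224, 224) with a symbolic batch
// dimension is encoded as
//   dim { dim_param: "N" } dim { dim_value: 3 } dim { dim_value: 224 } dim { dim_value: 224 }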
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
int32 key_type = 1;
// This field MUST be present for this version of the IR.
TypeProto value_type = 2;
};
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
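// Illustrative example (not normative): {domain: "", version: 21} identifies the default
// ONNX operator set at opset version 21, while a non-empty domain such as "com.microsoft"
// identifies a vendor-specific operator set.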
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
int64 version = 2;
}
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar to op_type in NodeProto.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. For example, if the same operator set, say 'A', is imported by a FunctionProto
// and the ModelProto, the imported versions may differ, but the operator schema
// returned for a given (op_type, domain, version) combination should be the same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string domain = 10;
// The overload identifier of the function.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string overload = 13;
// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
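
As a rough illustration of how this schema might be consumed from Rust, the sketch below decodes a serialized `ModelProto` and prints a few of the fields defined above. It assumes prost-generated bindings compiled into an `onnx` module (e.g., via `prost-build` in a build script); the module path and the use of `anyhow` are assumptions, not part of this file.

```rust
use prost::Message;

// Assumed module containing the prost-generated types for this proto file;
// the actual path depends on how the generated code is included by the build.
pub mod onnx {
    include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}

/// Decodes a serialized ONNX model and prints some of its metadata.
pub fn inspect_model(bytes: &[u8]) -> anyhow::Result<()> {
    let model = onnx::ModelProto::decode(bytes)?;
    println!(
        "producer: {} {} (domain: {}, model_version: {})",
        model.producer_name, model.producer_version, model.domain, model.model_version
    );
    // `graph` is a singular message field in the schema, so it is an Option on the Rust side.
    if let Some(graph) = &model.graph {
        println!(
            "graph '{}': {} nodes, {} inputs, {} outputs, {} initializers",
            graph.name,
            graph.node.len(),
            graph.input.len(),
            graph.output.len(),
            graph.initializer.len()
        );
    }
    // metadata_props holds free-form key/value pairs.
    for kv in &model.metadata_props {
        println!("metadata: {} = {}", kv.key, kv.value);
    }
    Ok(())
}
```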

View File

@ -83,10 +83,14 @@ impl std::str::FromStr for Color {
}
impl Color {
/// Creates a new Color from RGBA components.
/// Each component is an 8-bit value (0-255).
const fn from_rgba(r: u8, g: u8, b: u8, a: u8) -> Self {
Self(((r as u32) << 24) | ((g as u32) << 16) | ((b as u32) << 8) | (a as u32))
}
/// Returns the color components as RGBA tuple.
/// Each component is an 8-bit value (0-255).
pub fn rgba(&self) -> (u8, u8, u8, u8) {
let r = ((self.0 >> 24) & 0xff) as u8;
let g = ((self.0 >> 16) & 0xff) as u8;
@ -95,70 +99,89 @@ impl Color {
(r, g, b, a)
}
/// Returns the RGB components as a tuple, excluding alpha.
/// Each component is an 8-bit value (0-255).
pub fn rgb(&self) -> (u8, u8, u8) {
let (r, g, b, _) = self.rgba();
(r, g, b)
}
/// Returns the BGR components as a tuple.
/// Useful for OpenCV-style color formats.
pub fn bgr(&self) -> (u8, u8, u8) {
let (r, g, b) = self.rgb();
(b, g, r)
}
/// Returns the red component (0-255).
pub fn r(&self) -> u8 {
self.rgba().0
}
/// Returns the green component (0-255).
pub fn g(&self) -> u8 {
self.rgba().1
}
/// Returns the blue component (0-255).
pub fn b(&self) -> u8 {
self.rgba().2
}
/// Returns the alpha component (0-255).
pub fn a(&self) -> u8 {
self.rgba().3
}
/// Returns the color as a hex string in the format "#RRGGBBAA".
pub fn hex(&self) -> String {
format!("#{:08x}", self.0)
}
/// Creates a new color with the specified alpha value while keeping RGB components.
pub fn with_alpha(self, a: u8) -> Self {
let (r, g, b) = self.rgb();
(r, g, b, a).into()
}
/// Creates a black color (RGB: 0,0,0) with full opacity.
pub fn black() -> Color {
[0, 0, 0, 255].into()
}
/// Creates a white color (RGB: 255,255,255) with full opacity.
pub fn white() -> Color {
[255, 255, 255, 255].into()
}
/// Creates a green color (RGB: 0,255,0) with full opacity.
pub fn green() -> Color {
[0, 255, 0, 255].into()
}
/// Creates a red color (RGB: 255,0,0) with full opacity.
pub fn red() -> Color {
[255, 0, 0, 255].into()
}
/// Creates a blue color (RGB: 0,0,255) with full opacity.
pub fn blue() -> Color {
[0, 0, 255, 255].into()
}
/// Creates a color palette from a slice of convertible values.
pub fn create_palette<A: Into<Self> + Copy>(xs: &[A]) -> Vec<Self> {
xs.iter().copied().map(Into::into).collect()
}
/// Attempts to create a color palette from hex color strings.
/// Returns an error if any string is not a valid hex color.
pub fn try_create_palette(xs: &[&str]) -> Result<Vec<Self>> {
xs.iter().map(|x| x.parse()).collect()
}
/// Creates a palette of random colors with the specified size.
pub fn palette_rand(n: usize) -> Vec<Self> {
let mut rng = rand::rng();
let xs: Vec<(u8, u8, u8)> = (0..n)
@ -174,20 +197,24 @@ impl Color {
Self::create_palette(&xs)
}
/// Returns a predefined palette of 20 base colors.
pub fn palette_base_20() -> Vec<Self> {
Self::create_palette(&PALETTE_BASE)
}
/// Returns a cotton candy themed palette of 5 colors.
pub fn palette_cotton_candy_5() -> Result<Vec<Self>> {
Self::try_create_palette(&["#ff595e", "#ffca3a", "#8ac926", "#1982c4", "#6a4c93"])
}
/// Returns a tropical sunrise themed palette of 5 colors.
#[inline(always)]
pub fn palette_tropical_sunrise_5() -> Result<Vec<Self>> {
// https://colorkit.co/palette/e12729-f37324-f8cc1b-72b043-007f4e/
Self::try_create_palette(&["#e12729", "#f37324", "#f8cc1b", "#72b043", "#007f4e"])
}
/// Returns a rainbow themed palette of 10 colors.
pub fn palette_rainbow_10() -> Vec<Self> {
Self::create_palette(&[
0xff595eff, 0xff924cff, 0xffca3aff, 0xc5ca30ff, 0x8ac926ff, 0x52a675ff, 0x1982c4ff,
@ -195,13 +222,16 @@ impl Color {
])
}
/// Returns the COCO dataset color palette with 80 colors.
pub fn palette_coco_80() -> Vec<Self> {
Self::create_palette(&PALETTE_COCO_80)
}
/// Returns the Pascal VOC dataset color palette with 21 colors.
pub fn palette_pascal_voc_21() -> Vec<Self> {
Self::create_palette(&PALETTE_PASCAL_VOC_20)
}
/// Returns the ADE20K dataset color palette with 150 colors.
pub fn palette_ade20k_150() -> Vec<Self> {
Self::create_palette(&PALETTE_ADE20K_150)
}
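
A minimal usage sketch of the documented `Color` helpers, assuming `Color` is re-exported at the crate root and that the hex-parsing and palette helpers return the crate's `anyhow`-based `Result` (both are assumptions, not shown in this diff):

```rust
use usls::Color;

fn main() -> anyhow::Result<()> {
    // Build a small palette from hex strings; invalid hex strings propagate as errors.
    let palette = Color::try_create_palette(&["#e12729", "#f37324", "#f8cc1b"])?;
    for color in &palette {
        let (r, g, b, a) = color.rgba();
        println!("{} -> rgba({r}, {g}, {b}, {a})", color.hex());
    }

    // Derive a semi-transparent variant of a predefined color.
    let overlay = Color::green().with_alpha(128);
    println!("overlay: {}", overlay.hex());
    Ok(())
}
```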

View File

@ -1,5 +1,6 @@
use crate::Color;
/// 256-color colormap variants for data visualization.
#[derive(Clone, Debug, PartialEq)]
pub enum ColorMap256 {
Turbo,

View File

@ -1,5 +1,6 @@
use crate::{Color, Style, StyleColors, TextRenderer};
/// Drawing context containing styles and renderers for visualization.
#[derive(Debug, Clone)]
pub struct DrawContext<'a> {
pub text_renderer: &'a TextRenderer,

View File

@ -7,6 +7,7 @@ mod prob;
use crate::{DrawContext, Style, Y};
/// Defines an interface for drawing objects on an image canvas
pub trait Drawable {
fn get_local_style(&self) -> Option<&Style> {
None

View File

@ -2,6 +2,7 @@ use aksr::Builder;
use crate::{Color, ColorMap256, Skeleton};
/// Style configuration for drawing annotations and visualizations.
#[derive(Debug, Clone, Builder, PartialEq)]
pub struct Style {
visible: bool, // For ALL
@ -131,6 +132,7 @@ impl Style {
}
}
/// Color configuration for different visual elements.
#[derive(Debug, Builder, Default, Clone, PartialEq, Copy)]
pub struct StyleColors {
pub outline: Option<Color>,
@ -139,6 +141,7 @@ pub struct StyleColors {
pub text_bg: Option<Color>,
}
/// Text positioning options relative to visual elements.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TextLoc {
InnerTopLeft,

View File

@ -5,6 +5,7 @@ use image::{Rgba, RgbaImage};
use crate::{Color, Hub};
/// Text rendering engine with font management and styling capabilities.
#[derive(Builder, Clone, Debug)]
pub struct TextRenderer {
#[args(except(setter))]