From 70aeae9e010e300e1e85d54547f4b6fdddd4ffbf Mon Sep 17 00:00:00 2001 From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com> Date: Sun, 8 Jun 2025 18:15:54 +0800 Subject: [PATCH] update (#109) * Add docs * Add mediapipe-selfie-segmenter model * Update README.md * Update RTMO model --- Cargo.toml | 5 +- README.md | 357 ++--- build.rs | 7 - .../mediapipe-selfie-segmentation/README.md | 10 + .../mediapipe-selfie-segmentation/main.rs | 49 + examples/rtmo/main.rs | 34 +- src/{utils => core}/config.rs | 3 +- src/{io => core}/dataloader.rs | 19 + src/{utils => core}/device.rs | 1 + src/{io => core}/dir.rs | 0 src/{utils => core}/dtype.rs | 1 + src/{utils => core}/dynconf.rs | 0 src/{inference => core}/engine.rs | 20 +- src/{io => core}/hub.rs | 0 src/{utils => core}/iiix.rs | 3 + src/{inference => core}/image.rs | 8 + src/{utils => core}/logits_sampler.rs | 3 + src/{io => core}/media.rs | 3 + src/{utils => core}/min_opt_max.rs | 0 src/core/mod.rs | 61 + src/{utils => core}/names.rs | 83 ++ src/core/onnx.rs | 1210 +++++++++++++++++ src/{utils => core}/ops.rs | 14 + src/{utils => core}/ort_config.rs | 4 +- src/{utils => core}/processor.rs | 1 + src/{utils => core}/processor_config.rs | 22 +- src/{utils => core}/retry.rs | 0 src/{utils => core}/scale.rs | 1 + src/{utils => core}/task.rs | 0 src/{utils => core}/traits.rs | 13 + src/{utils => core}/ts.rs | 1 + src/{utils/mod.rs => core/utils.rs} | 118 +- src/{utils => core}/version.rs | 1 + src/{inference => core}/x.rs | 0 src/{inference => core}/xs.rs | 1 + src/inference/mod.rs | 36 - src/io/mod.rs | 9 - src/lib.rs | 62 +- src/models/db/impl.rs | 1 + src/models/grounding_dino/impl.rs | 1 + src/models/mediapipe_segmenter/README.md | 9 + src/models/mediapipe_segmenter/config.rs | 22 + src/models/mediapipe_segmenter/impl.rs | 91 ++ src/models/mediapipe_segmenter/mod.rs | 4 + src/models/mod.rs | 2 + src/models/owl/impl.rs | 1 + src/models/rtmo/config.rs | 7 + src/models/rtmo/impl.rs | 113 +- src/models/sam/config.rs | 21 + src/models/sam/impl.rs | 28 +- src/models/sam/mod.rs | 3 + src/models/sam2/config.rs | 12 + src/models/sam2/impl.rs | 22 + src/models/svtr/impl.rs | 1 + src/models/trocr/config.rs | 22 +- src/models/trocr/impl.rs | 67 + src/models/yolo/config.rs | 40 +- src/models/yolo/impl.rs | 26 + src/models/yolo/preds.rs | 5 + src/{inference => results}/hbb.rs | 3 +- src/{inference => results}/instance_meta.rs | 2 +- src/{inference => results}/keypoint.rs | 2 +- src/{inference => results}/mask.rs | 5 +- src/results/mod.rs | 22 + src/{inference => results}/obb.rs | 3 +- src/{inference => results}/polygon.rs | 2 +- src/{inference => results}/prob.rs | 3 +- src/{inference => results}/skeleton.rs | 14 + src/{inference => results}/text.rs | 10 +- src/{inference => results}/y.rs | 0 src/utils/onnx.proto3 | 978 ------------- src/viz/color.rs | 30 + src/viz/colormap256.rs | 1 + src/viz/draw_ctx.rs | 1 + src/viz/drawable/mod.rs | 1 + src/viz/styles.rs | 3 + src/viz/text_renderer.rs | 1 + 77 files changed, 2325 insertions(+), 1414 deletions(-) delete mode 100644 build.rs create mode 100644 examples/mediapipe-selfie-segmentation/README.md create mode 100644 examples/mediapipe-selfie-segmentation/main.rs rename src/{utils => core}/config.rs (99%) rename src/{io => core}/dataloader.rs (95%) rename src/{utils => core}/device.rs (99%) rename src/{io => core}/dir.rs (100%) rename src/{utils => core}/dtype.rs (98%) rename src/{utils => core}/dynconf.rs (100%) rename src/{inference => core}/engine.rs (98%) rename src/{io => core}/hub.rs (100%) rename 
src/{utils => core}/iiix.rs (80%) rename src/{inference => core}/image.rs (95%) rename src/{utils => core}/logits_sampler.rs (93%) rename src/{io => core}/media.rs (96%) rename src/{utils => core}/min_opt_max.rs (100%) create mode 100644 src/core/mod.rs rename src/{utils => core}/names.rs (96%) create mode 100644 src/core/onnx.rs rename src/{utils => core}/ops.rs (96%) rename src/{utils => core}/ort_config.rs (99%) rename src/{utils => core}/processor.rs (99%) rename src/{utils => core}/processor_config.rs (90%) rename src/{utils => core}/retry.rs (100%) rename src/{utils => core}/scale.rs (98%) rename src/{utils => core}/task.rs (100%) rename src/{utils => core}/traits.rs (70%) rename src/{utils => core}/ts.rs (99%) rename src/{utils/mod.rs => core/utils.rs} (55%) rename src/{utils => core}/version.rs (95%) rename src/{inference => core}/x.rs (100%) rename src/{inference => core}/xs.rs (97%) delete mode 100644 src/inference/mod.rs delete mode 100644 src/io/mod.rs create mode 100644 src/models/mediapipe_segmenter/README.md create mode 100644 src/models/mediapipe_segmenter/config.rs create mode 100644 src/models/mediapipe_segmenter/impl.rs create mode 100644 src/models/mediapipe_segmenter/mod.rs rename src/{inference => results}/hbb.rs (98%) rename src/{inference => results}/instance_meta.rs (97%) rename src/{inference => results}/keypoint.rs (99%) rename src/{inference => results}/mask.rs (93%) create mode 100644 src/results/mod.rs rename src/{inference => results}/obb.rs (97%) rename src/{inference => results}/polygon.rs (99%) rename src/{inference => results}/prob.rs (92%) rename src/{inference => results}/skeleton.rs (80%) rename src/{inference => results}/text.rs (76%) rename src/{inference => results}/y.rs (100%) delete mode 100644 src/utils/onnx.proto3 diff --git a/Cargo.toml b/Cargo.toml index 8dc658b..dc42d58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "usls" edition = "2021" -version = "0.1.0-beta.4" +version = "0.1.0-rc.1" rust-version = "1.82" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" @@ -10,6 +10,7 @@ license = "MIT" readme = "README.md" exclude = ["assets/*", "examples/*", "runs/*", "benches/*", "tests/*"] + [dependencies] anyhow = { version = "1" } aksr = { version = "0.0.3" } @@ -47,8 +48,6 @@ tokenizers = { version = "0.21.1" } paste = "1.0.15" base64ct = "=1.7.3" -[build-dependencies] -prost-build = "0.13.5" [dev-dependencies] argh = "0.1.13" diff --git a/README.md b/README.md index 90813d7..5888792 100644 --- a/README.md +++ b/README.md @@ -1,291 +1,146 @@

usls

- - - Rust MSRV - - - - ONNXRuntime MSRV - - - - CUDA MSRV - - - - cuDNN MSRV - - - - TensorRT MSRV - -

-

- - - Examples - - - - Documentation - -

-

- - + Rust CI Crates.io Version - - - Crates.io Downloads + + ONNXRuntime MSRV + + + Rust MSRV

-

- ⭐️ Star if helpful! ⭐️ -

-**usls** is an evolving Rust library focused on inference for advanced **vision** and **vision-language** models, along with practical vision utilities. +**usls** is a cross-platform Rust library powered by ONNX Runtime for efficient inference of SOTA vision and multi-modal models (typically under 1B parameters). -- **SOTA Model Inference:** Supports a wide range of state-of-the-art vision and multi-modal models (typically with fewer than 1B parameters). -- **Multi-backend Acceleration:** Supports CPU, CUDA, TensorRT, and CoreML. -- **Easy Data Handling:** Easily read images, video streams, and folders with iterator support. -- **Rich Result Types:** Built-in containers for common vision outputs like bounding boxes (Hbb, Obb), polygons, masks, etc. -- **Annotation & Visualization:** Draw and display inference results directly, similar to OpenCV's `imshow()`. +## πŸ“š Documentation +- [API Documentation](https://docs.rs/usls/latest/usls/) +- [Examples](./examples) -## 🧩 Supported Models +## πŸš€ Quick Start +```bash +# CPU +cargo run -r --example yolo # YOLOv8-n detect by default -- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics), [YOLOv12](https://github.com/sunsmarterjie/yolov12) -- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) -- **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone) -- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1-v2](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242), [Moondream2](https://github.com/vikhyat/moondream/tree/main) -- **OCR-Related Models**: [FAST](https://github.com/czczup/FAST), [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947), [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) +# NVIDIA CUDA +cargo run -r -F cuda --example yolo -- --device cuda:0 -
-Full list of supported models (click to expand) +# NVIDIA TensorRT +cargo run -r -F tensorrt --example yolo -- --device tensorrt:0 -| Model | Task / Description | Example | CoreML | CUDA
FP32 | CUDA
FP16 | TensorRT
FP32 | TensorRT
FP16 | -| -------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------- | ------ | -------------- | -------------- | ------------------ | ------------------ | -| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) | βœ… | βœ… | βœ… | | | -| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) | βœ… | βœ… | βœ… | | | -| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) | βœ… | βœ… | βœ… | | | -| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) | βœ… | βœ… | βœ… | | | -| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) | βœ… | βœ… | βœ… | | | -| [DINOv2](https://github.com/facebookresearch/dinov2) | VisionΒ Embedding | [demo](examples/dinov2) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification
Object Detection
Instance Segmentation | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv8
YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection
Instance Segmentation
Image Classification
Oriented Object Detection
Keypoint Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) | βœ… | βœ… | βœ… | | | -| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) | βœ… | βœ… | βœ… | | | -| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) | βœ… | βœ… | βœ… | | | -| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) | βœ… | βœ… | βœ… | | | -| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) | βœ… | βœ… | βœ… | | | -| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) | βœ… | βœ… | βœ… | | | -| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | βœ… | βœ… | βœ… | | | -| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | βœ… | βœ… | βœ… | | | -| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | βœ… | βœ… | βœ… | | | -| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | βœ… | βœ… | βœ… | | | -| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | βœ… | βœ… | βœ… | | | -| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | βœ… | βœ… | βœ… | | | -| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text 
Recognition | [demo](examples/svtr) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Tabel Recognition | [demo](examples/slanet) | βœ… | βœ… | βœ… | | | -| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) | βœ… | βœ… | βœ… | | | -| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [DepthAnything v1
DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | βœ… | βœ… | βœ… | ❌ | ❌ | -| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | βœ… | βœ… | βœ… | | | -| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | βœ… | βœ… | βœ… | βœ… | βœ… | -| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | βœ… | βœ… | βœ… | | | -| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | βœ… | βœ… | βœ… | | | -| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection
Open-Set Keypoints Detection
ImageΒ Caption
Visual Question Answering | [demo](examples/moondream2) | βœ… | βœ… | βœ… | | | -| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | βœ… | βœ… | βœ… | | | -| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | βœ… | βœ… | βœ… | | | -| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation
Background Removal | [demo](examples/rmbg) | βœ… | βœ… | βœ… | | | -| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation
Background Removal | [demo](examples/rmbg) | βœ… | βœ… | βœ… | | | +# Apple Silicon CoreML +cargo run -r -F coreml --example yolo -- --device coreml -
+# Intel OpenVINO +cargo run -r -F openvino -F ort-load-dynamic --example yolo -- --device openvino:CPU - -## πŸ› οΈ Installation - -To get started, you'll need: - -### 1. Protocol Buffers Compiler (`protoc`) -Required for building the project. [Official installation guide](https://protobuf.dev/installation/) -```shell -# Linux (apt) -sudo apt install -y protobuf-compiler - -# macOS (Homebrew) -brew install protobuf - -# Windows (Winget) -winget install protobuf - -# Verify installation -protoc --version # Should be 3.x or higher +# And other EPs... ``` -### 2. Rust Toolchain -```shell -# Install Rust and Cargo -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -``` -### 3. Add usls to Your Project +## βš™οΈ Installation Add the following to your `Cargo.toml`: + ```toml [dependencies] # Recommended: Use GitHub version -usls = { git = "https://github.com/jamjamjon/usls" } +usls = { git = "https://github.com/jamjamjon/usls", features = [ "cuda" ] } # Alternative: Use crates.io version usls = "latest-version" ``` -> **Note:** **The GitHub version is recommended as it contains the latest updates.** -## ⚑ Cargo Features -- **ONNXRuntime-related features (enabled by default)**, provide model inference and model zoo support: - - **`ort-download-binaries`** (**default**): Automatically downloads prebuilt `ONNXRuntime` binaries for supported platforms. Provides core model loading and inference capabilities using the `CPU` execution provider. - - **`ort-load-dynamic `** Dynamic linking. You'll need to compile `ONNXRuntime` from [source](https://github.com/microsoft/onnxruntime) or download a [precompiled package](https://github.com/microsoft/onnxruntime/releases), and then link it manually. [See the guide here](https://ort.pyke.io/setup/linking#dynamic-linking). - - - **`cuda`**: Enables the NVIDIA `CUDA` provider. Requires `CUDA` toolkit and `cuDNN` installed. - - **`trt`**: Enables the NVIDIA `TensorRT` provider. Requires `TensorRT` libraries installed. - - **`mps`**: Enables the Apple `CoreML` provider for macOS. +## ⚑ Supported Models +
+Click to expand -- **If you only need basic features** (such as image/video reading, result visualization, etc.), you can disable the default features to minimize dependencies: - ```shell - usls = { git = "https://github.com/jamjamjon/usls", default-features = false } - ``` - - **`video`** : Enable video stream reading, and video writing.(Note: Powered by [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb). Check their repositories for potential issues.) +| Model | Task / Description | Example | +| ----- | ----------------- | ------- | +| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) | +| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) | +| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) | +| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) | +| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) | +| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) | +| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification
Object Detection
Instance Segmentation | [demo](examples/yolo) | +| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | +| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | +| [YOLOv8
YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection
Instance Segmentation
Image Classification
Oriented Object Detection
Keypoint Detection | [demo](examples/yolo) | +| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | +| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | +| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) | +| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) | +| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) | +| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) | +| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) | +| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) | +| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) | +| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | +| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | +| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | +| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | +| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | +| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | +| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | +| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) | +| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | +| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) | +| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) | +| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) | +| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) | +| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) | +| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | +| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) | +| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) | +| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | +| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Table Recognition | [demo](examples/slanet) | +| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) | +| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | +| [DepthAnything v1
DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | +| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | +| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | +| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | +| [Florence2](https://arxiv.org/abs/2311.06242) | A Variety of Vision Tasks | [demo](examples/florence2) | +| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection
Open-Set Keypoints Detection
Image Caption
Visual Question Answering | [demo](examples/moondream2) | +| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | +| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | +| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation
Background Removal | [demo](examples/rmbg) | +| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation
Background Removal | [demo](examples/rmbg) | +| [MediaPipe: Selfie-segmentation](https://ai.google.dev/edge/mediapipe/solutions/vision/image_segmenter) | Image Segmentation | [demo](examples/mediapipe-selfie-segmentation) | -## ✨ Example - -- Model Inference - ```shell - cargo run -r --example yolo # CPU - cargo run -r -F cuda --example yolo -- --device cuda:0 # GPU - ``` - -- Reading Images - ```rust - // Read a single image - let image = DataLoader::try_read_one("./assets/bus.jpg")?; - - // Read multiple images - let images = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/cat.png"])?; - - // Read all images in a folder - let images = DataLoader::try_read_folder("./assets")?; - - // Read images matching a pattern (glob) - let images = DataLoader::try_read_pattern("./assets/*.Jpg")?; - - // Load images and iterate - let dl = DataLoader::new("./assets")?.with_batch(2).build()?; - for images in dl.iter() { - // Code here - } - ``` - -- Reading Video - ```rust - let dl = DataLoader::new("http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4")? - .with_batch(1) - .with_nf_skip(2) - .with_progress_bar(true) - .build()?; - for images in dl.iter() { - // Code here - } - ``` - -- Annotate - ```rust - let annotator = Annotator::default(); - let image = DataLoader::try_read_one("./assets/bus.jpg")?; - // hbb - let hbb = Hbb::default() - .with_xyxy(669.5233, 395.4491, 809.0367, 878.81226) - .with_id(0) - .with_name("person") - .with_confidence(0.87094545); - let _ = annotator.annotate(&image, &hbb)?; - - // keypoints - let keypoints: Vec = vec![ - Keypoint::default() - .with_xy(139.35767, 443.43655) - .with_id(0) - .with_name("nose") - .with_confidence(0.9739332), - Keypoint::default() - .with_xy(147.38545, 434.34055) - .with_id(1) - .with_name("left_eye") - .with_confidence(0.9098319), - Keypoint::default() - .with_xy(128.5701, 434.07516) - .with_id(2) - .with_name("right_eye") - .with_confidence(0.9320564), - ]; - let _ = annotator.annotate(&image, &keypoints)?; - ``` +
-- Visualizing Inference Results and Exporting Video - ```rust - let dl = DataLoader::new(args.source.as_str())?.build()?; - let mut viewer = Viewer::default().with_window_scale(0.5); +## πŸ“¦ Cargo Features +- **`ort-download-binaries`** (**default**): Automatically downloads prebuilt ONNXRuntime binaries for supported platforms +- **`ort-load-dynamic`**: Dynamic linking to ONNXRuntime libraries ([Guide](https://ort.pyke.io/setup/linking#dynamic-linking)) +- **`video`**: Enable video stream reading and writing (via [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb)) +- **`cuda`**: NVIDIA CUDA GPU acceleration support +- **`tensorrt`**: NVIDIA TensorRT optimization for inference acceleration +- **`coreml`**: Apple CoreML acceleration for macOS/iOS devices +- **`openvino`**: Intel OpenVINO toolkit for CPU/GPU/VPU acceleration +- **`onednn`**: Intel oneDNN (formerly MKL-DNN) for CPU optimization +- **`directml`**: Microsoft DirectML for Windows GPU acceleration +- **`xnnpack`**: Google XNNPACK for mobile and edge device optimization +- **`rocm`**: AMD ROCm platform for GPU acceleration +- **`cann`**: Huawei CANN (Compute Architecture for Neural Networks) support +- **`rknpu`**: Rockchip NPU acceleration +- **`acl`**: Arm Compute Library for Arm processors +- **`nnapi`**: Android Neural Networks API support +- **`armnn`**: Arm NN inference engine +- **`tvm`**: Apache TVM tensor compiler stack +- **`qnn`**: Qualcomm Neural Network SDK +- **`migraphx`**: AMD MIGraphX for GPU acceleration +- **`vitis`**: Xilinx Vitis AI for FPGA acceleration +- **`azure`**: Azure Machine Learning integration - for images in &dl { - // Check if the window exists and is open - if viewer.is_window_exist() && !viewer.is_window_open() { - break; - } - - // Show image in window - viewer.imshow(&images[0])?; - - // Handle key events and delay - if let Some(key) = viewer.wait_key(1) { - if key == usls::Key::Escape { - break; - } - } - - // Your custom code here - - // Write video frame (requires video feature) - // if args.save_video { - // viewer.write_video_frame(&images[0])?; - // } - } - ``` - -**All examples are located in the [examples](./examples/) directory.** ## ❓ FAQ -See issues or open a new discussion. +See [issues](https://github.com/jamjamjon/usls/issues) or open a new discussion. 
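+## ✨ Example
+
+A minimal end-to-end sketch, condensed from the bundled demos (`examples/rtmo`, `examples/mediapipe-selfie-segmentation`); the input path and output name here are illustrative placeholders:
+
+```rust
+use anyhow::Result;
+use usls::{models::RTMO, Annotator, Config, DataLoader};
+
+fn main() -> Result<()> {
+    // Build a model from one of its preset configs.
+    let mut model = RTMO::new(Config::rtmo_s().commit()?)?;
+
+    // Read one or more images.
+    let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
+
+    // Run inference.
+    let ys = model.forward(&xs)?;
+
+    // Draw each result next to its input image.
+    let annotator = Annotator::default();
+    for (x, y) in xs.iter().zip(ys.iter()) {
+        annotator.annotate(x, y)?.save("annotated.jpg")?; // illustrative output path
+    }
+
+    // To stream a folder or video in batches instead:
+    // let dl = DataLoader::new("./assets")?.with_batch(2).build()?;
+    // for batch in dl.iter() { /* ... */ }
+
+    Ok(())
+}
+```
+
+Each model's demo under [examples](./examples) shows the full set of config, style, and annotator options.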
## 🀝 Contributing diff --git a/build.rs b/build.rs deleted file mode 100644 index 6558285..0000000 --- a/build.rs +++ /dev/null @@ -1,7 +0,0 @@ -use std::io::Result; - -fn main() -> Result<()> { - prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?; - - Ok(()) -} diff --git a/examples/mediapipe-selfie-segmentation/README.md b/examples/mediapipe-selfie-segmentation/README.md new file mode 100644 index 0000000..cc68120 --- /dev/null +++ b/examples/mediapipe-selfie-segmentation/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r --example mediapipe-selfie-segmentation -- --dtype f16 +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/mediapipe/demo-selfie-segmentaion.jpg) diff --git a/examples/mediapipe-selfie-segmentation/main.rs b/examples/mediapipe-selfie-segmentation/main.rs new file mode 100644 index 0000000..adeb090 --- /dev/null +++ b/examples/mediapipe-selfie-segmentation/main.rs @@ -0,0 +1,49 @@ +use usls::{models::MediaPipeSegmenter, Annotator, Config, DataLoader}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + let args: Args = argh::from_env(); + + // build model + let config = Config::mediapipe_selfie_segmentater() + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) + .commit()?; + let mut model = MediaPipeSegmenter::new(config)?; + + // load image + let xs = DataLoader::try_read_n(&["images/selfie-segmenter.png"])?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = + Annotator::default().with_mask_style(usls::Style::mask().with_mask_cutout(true)); + for (x, y) in xs.iter().zip(ys.iter()) { + annotator.annotate(x, y)?.save(format!( + "{}.jpg", + usls::Dir::Current + .base_dir_with_subs(&["runs", model.spec()])? + .join(usls::timestamp(None)) + .display(), + ))?; + } + + Ok(()) +} diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index b68aef1..69ce542 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,21 +1,48 @@ use anyhow::Result; use usls::{models::RTMO, Annotator, Config, DataLoader, Style, SKELETON_COCO_19}; +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"fp16\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale: t, s, m, l + #[argh(option, default = "String::from(\"t\")")] + scale: String, +} + fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); + let args: Args = argh::from_env(); // build model - let mut model = RTMO::new(Config::rtmo_s().commit()?)?; + let config = match args.scale.as_str() { + "t" => Config::rtmo_t(), + "s" => Config::rtmo_s(), + "m" => Config::rtmo_m(), + "l" => Config::rtmo_l(), + _ => unreachable!(), + } + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) 
+ .commit()?; + let mut model = RTMO::new(config)?; // load image let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; // run let ys = model.forward(&xs)?; - println!("ys: {:?}", ys); + // println!("ys: {:?}", ys); // annotate let annotator = Annotator::default() @@ -37,5 +64,8 @@ fn main() -> Result<()> { ))?; } + // summary + model.summary(); + Ok(()) } diff --git a/src/utils/config.rs b/src/core/config.rs similarity index 99% rename from src/utils/config.rs rename to src/core/config.rs index b6799db..6ba7e76 100644 --- a/src/utils/config.rs +++ b/src/core/config.rs @@ -1,12 +1,11 @@ use aksr::Builder; use crate::{ - impl_ort_config_methods, impl_processor_config_methods, models::{SamKind, YOLOPredsFormat}, ORTConfig, ProcessorConfig, Scale, Task, Version, }; -/// Config for building models and inference +/// Configuration for model inference including engines, processors, and task settings. #[derive(Builder, Debug, Clone)] pub struct Config { // Basics diff --git a/src/io/dataloader.rs b/src/core/dataloader.rs similarity index 95% rename from src/io/dataloader.rs rename to src/core/dataloader.rs index 4292dcf..0e31e18 100644 --- a/src/io/dataloader.rs +++ b/src/core/dataloader.rs @@ -550,9 +550,19 @@ trait DataLoaderIterator { } } +/// An iterator implementation for `DataLoader` that enables batch processing of images. +/// +/// This struct is created by the `into_iter` method on `DataLoader`. +/// It provides functionality for: +/// - Receiving batches of images through a channel +/// - Tracking progress with an optional progress bar +/// - Processing images in configurable batch sizes pub struct DataLoaderIntoIterator { + /// Channel receiver for getting batches of images receiver: mpsc::Receiver>, + /// Optional progress bar for tracking iteration progress progress_bar: Option, + /// Number of images to process in each batch batch_size: u64, } @@ -593,6 +603,15 @@ impl IntoIterator for DataLoader { } } +/// A borrowing iterator for `DataLoader` that enables batch processing of images. +/// +/// This iterator is created by the `iter()` method on `DataLoader`, allowing iteration +/// over batches of images without taking ownership of the `DataLoader`. +/// +/// # Fields +/// - `receiver`: A reference to the channel receiver that provides batches of images +/// - `progress_bar`: An optional reference to a progress bar for tracking iteration progress +/// - `batch_size`: The number of images to process in each batch pub struct DataLoaderIter<'a> { receiver: &'a mpsc::Receiver>, progress_bar: Option<&'a ProgressBar>, diff --git a/src/utils/device.rs b/src/core/device.rs similarity index 99% rename from src/utils/device.rs rename to src/core/device.rs index 5743193..4f2accf 100644 --- a/src/utils/device.rs +++ b/src/core/device.rs @@ -1,4 +1,5 @@ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +/// Device types for model execution. pub enum Device { Cpu(usize), Cuda(usize), diff --git a/src/io/dir.rs b/src/core/dir.rs similarity index 100% rename from src/io/dir.rs rename to src/core/dir.rs diff --git a/src/utils/dtype.rs b/src/core/dtype.rs similarity index 98% rename from src/utils/dtype.rs rename to src/core/dtype.rs index ef73223..e195806 100644 --- a/src/utils/dtype.rs +++ b/src/core/dtype.rs @@ -1,4 +1,5 @@ #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +/// Data type enumeration for tensor elements. 
pub enum DType { #[default] Auto, diff --git a/src/utils/dynconf.rs b/src/core/dynconf.rs similarity index 100% rename from src/utils/dynconf.rs rename to src/core/dynconf.rs diff --git a/src/inference/engine.rs b/src/core/engine.rs similarity index 98% rename from src/inference/engine.rs rename to src/core/engine.rs index 6206f52..f11409a 100644 --- a/src/inference/engine.rs +++ b/src/core/engine.rs @@ -46,25 +46,37 @@ impl From for DType { } -/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions. +/// ONNX Runtime tensor attributes containing names, data types, and dimensions. #[derive(Builder, Debug, Clone, Default)] pub struct OrtTensorAttr { + /// Tensor names. pub names: Vec, + /// Tensor data types. pub dtypes: Vec, + /// Tensor dimensions for each tensor. pub dimss: Vec>, } +/// ONNX I/O structure containing input/output attributes and session. #[derive(Debug)] pub struct OnnxIo { + /// Input tensor attributes. pub inputs: OrtTensorAttr, + /// Output tensor attributes. pub outputs: OrtTensorAttr, + /// ONNX Runtime session. pub session: Session, + /// ONNX model protocol buffer. pub proto: onnx::ModelProto, } +/// ONNX Runtime inference engine with configuration and session management. #[derive(Debug, Builder)] pub struct Engine { + /// Model file path. pub file: String, + /// Model specification string. pub spec: String, + /// Execution device. pub device: Device, #[args(inc)] pub iiixs: Vec, @@ -72,9 +84,13 @@ pub struct Engine { pub params: Option, #[args(aka = "memory")] pub wbmems: Option, + /// Input min-opt-max configurations. pub inputs_minoptmax: Vec>, + /// ONNX I/O structure. pub onnx: Option, + /// Timing statistics. pub ts: Ts, + /// Number of dry runs for warmup. pub num_dry_run: usize, // global @@ -158,7 +174,7 @@ impl Default for Engine { // cann cann_graph_inference: true, cann_dump_graphs: false, - cann_dump_om_model: false, + cann_dump_om_model: true, // nnapi nnapi_cpu_only: false, nnapi_disable_cpu: false, diff --git a/src/io/hub.rs b/src/core/hub.rs similarity index 100% rename from src/io/hub.rs rename to src/core/hub.rs diff --git a/src/utils/iiix.rs b/src/core/iiix.rs similarity index 80% rename from src/utils/iiix.rs rename to src/core/iiix.rs index 6db1626..8077032 100644 --- a/src/utils/iiix.rs +++ b/src/core/iiix.rs @@ -3,8 +3,11 @@ use crate::MinOptMax; /// A struct for input composed of the i-th input, the ii-th dimension, and the value. #[derive(Clone, Debug, Default)] pub struct Iiix { + /// Input index. pub i: usize, + /// Dimension index. pub ii: usize, + /// Min-Opt-Max value specification. pub x: MinOptMax, } diff --git a/src/inference/image.rs b/src/core/image.rs similarity index 95% rename from src/inference/image.rs rename to src/core/image.rs index 23b113f..cb47135 100644 --- a/src/inference/image.rs +++ b/src/core/image.rs @@ -9,6 +9,7 @@ use std::path::{Path, PathBuf}; use crate::{build_resizer_filter, Hub, Location, MediaType, X}; +/// Information about image transformation including source and destination dimensions. #[derive(Builder, Debug, Clone, Default)] pub struct ImageTransformInfo { pub width_src: u32, @@ -19,6 +20,7 @@ pub struct ImageTransformInfo { pub width_scale: f32, } +/// Image resize modes for different scaling strategies. #[derive(Debug, Clone, Default)] pub enum ResizeMode { /// StretchToFit @@ -30,6 +32,7 @@ pub enum ResizeMode { Letterbox, } +/// Image wrapper with metadata and transformation capabilities. 
#[derive(Builder, Clone)] pub struct Image { image: RgbImage, @@ -376,8 +379,13 @@ impl Image { } } +/// Extension trait for converting between vectors of different image types. +/// Provides methods to convert between `Vec` and `Vec`. pub trait ImageVecExt { + /// Converts the vector into a vector of `DynamicImage`s. fn into_dyns(self) -> Vec; + + /// Converts the vector into a vector of `Image`s. fn into_images(self) -> Vec; } diff --git a/src/utils/logits_sampler.rs b/src/core/logits_sampler.rs similarity index 93% rename from src/utils/logits_sampler.rs rename to src/core/logits_sampler.rs index a95435d..947349f 100644 --- a/src/utils/logits_sampler.rs +++ b/src/core/logits_sampler.rs @@ -1,9 +1,12 @@ use anyhow::Result; use rand::distr::{weighted::WeightedIndex, Distribution}; +/// Logits sampler for text generation with temperature and nucleus sampling. #[derive(Debug, Clone)] pub struct LogitsSampler { + /// Temperature parameter for controlling randomness in sampling. temperature: f32, + /// Top-p parameter for nucleus sampling. p: f32, } diff --git a/src/io/media.rs b/src/core/media.rs similarity index 96% rename from src/io/media.rs rename to src/core/media.rs index cbbb6fe..7322551 100644 --- a/src/io/media.rs +++ b/src/core/media.rs @@ -11,6 +11,7 @@ pub(crate) const STREAM_PROTOCOLS: &[&str] = &[ "rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://", ]; +/// Media location type indicating local or remote source. #[derive(Debug, Clone, Default, Copy)] pub enum Location { #[default] @@ -18,6 +19,7 @@ pub enum Location { Remote, } +/// Stream type for media content. #[derive(Debug, Clone, Copy, Default)] pub enum StreamType { #[default] @@ -25,6 +27,7 @@ pub enum StreamType { Live, } +/// Media type classification for different content formats. 
#[derive(Debug, Clone, Copy, Default)] pub enum MediaType { #[default] diff --git a/src/utils/min_opt_max.rs b/src/core/min_opt_max.rs similarity index 100% rename from src/utils/min_opt_max.rs rename to src/core/min_opt_max.rs diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..a4d739e --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,61 @@ +#[macro_use] +mod ort_config; +#[macro_use] +mod processor_config; +mod config; +mod dataloader; +mod device; +mod dir; +mod dtype; +mod dynconf; +#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] +mod engine; +mod hub; +mod iiix; +mod image; +mod logits_sampler; +mod media; +mod min_opt_max; +mod names; +#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] +#[allow(clippy::all)] +pub(crate) mod onnx; +mod ops; +mod processor; +mod retry; +mod scale; +mod task; +mod traits; +mod ts; +mod utils; +mod version; +mod x; +mod xs; + +pub use config::*; +pub use dataloader::*; +pub use device::Device; +pub use dir::*; +pub use dtype::DType; +pub use dynconf::DynConf; +#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] +pub use engine::*; +pub use hub::*; +pub(crate) use iiix::Iiix; +pub use image::*; +pub use logits_sampler::LogitsSampler; +pub use media::*; +pub use min_opt_max::MinOptMax; +pub use names::*; +pub use ops::*; +pub use ort_config::ORTConfig; +pub use processor::*; +pub use processor_config::ProcessorConfig; +pub use scale::Scale; +pub use task::Task; +pub use traits::*; +pub use ts::Ts; +pub use utils::*; +pub use version::Version; +pub use x::X; +pub use xs::Xs; diff --git a/src/utils/names.rs b/src/core/names.rs similarity index 96% rename from src/utils/names.rs rename to src/core/names.rs index e4affef..cf25bbc 100644 --- a/src/utils/names.rs +++ b/src/core/names.rs @@ -1,3 +1,4 @@ +/// A comprehensive list of 4585 object categories used in the YOLOE object detection model. pub static NAMES_YOLOE_4585: [&str; 4585] = [ "3D CG rendering", "3D glasses", @@ -4586,6 +4587,13 @@ pub static NAMES_YOLOE_4585: [&str; 4585] = [ "zoo", ]; +/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.5 with 16 categories. +/// +/// DOTA is a large-scale dataset for object detection in aerial images. This version includes: +/// - Transportation objects (planes, ships, vehicles) +/// - Sports facilities (courts, fields) +/// - Infrastructure (bridges, harbors, storage tanks) +/// - Other aerial view objects pub const NAMES_DOTA_V1_5_16: [&str; 16] = [ "plane", "ship", "storage tank", "baseball diamond", "tennis court", "basketball court", "ground track field", "harbor", "bridge", "large vehicle", "small vehicle", "helicopter", "roundabout", "soccer ball field", "swimming pool", "container crane", ]; +/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.0 with 15 categories. +/// +/// Similar to DOTA v1.5 but excludes the "container crane" category. Includes: +/// - Transportation objects +/// - Sports facilities +/// - Infrastructure +/// - Other aerial view objects pub const NAMES_DOTA_V1_15: [&str; 15] = [ "plane", "ship", "storage tank", "baseball diamond", "tennis court", "basketball court", "ground track field", "harbor", "bridge", "large vehicle", "small vehicle", "helicopter", "roundabout", "soccer ball field", "swimming pool", ]; +/// Labels for document layout analysis using YOLO with 10 categories. 
+/// +/// These labels are used to identify different components in document layout analysis: +/// - Text elements (title, plain text) +/// - Visual elements (figures and their captions) +/// - Tabular elements (tables, captions, footnotes) +/// - Mathematical elements (formulas and captions) pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [ "title", "plain text", @@ -4634,6 +4657,14 @@ pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [ "isolate_formula", "formula_caption", ]; +/// Labels for PicoDet document layout analysis with 17 categories. +/// +/// A comprehensive set of labels for detailed document structure analysis, including: +/// - Textual elements (titles, paragraphs, abstracts) +/// - Visual elements (images, figures) +/// - Structural elements (headers, footers) +/// - Reference elements (citations, footnotes) +/// - Special elements (algorithms, seals) pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [ "paragraph_title", "image", @@ -4653,8 +4684,26 @@ pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [ "footer", "seal", ]; +/// Simplified PicoDet document layout labels with 3 basic categories. +/// +/// A minimal set of labels for basic document layout analysis: +/// - Images +/// - Tables +/// - Seals (official stamps or marks) pub const NAMES_PICODET_LAYOUT_3: [&str; 3] = ["image", "table", "seal"]; +/// Core PicoDet document layout labels with 5 essential categories. +/// +/// A balanced set of labels for common document layout analysis tasks: +/// - Textual content (Text, Title, List) +/// - Visual content (Figure) +/// - Structured content (Table) pub const NAMES_PICODET_LAYOUT_5: [&str; 5] = ["Text", "Title", "List", "Table", "Figure"]; +/// COCO dataset keypoint labels for human pose estimation with 17 points. +/// +/// Standard keypoints for human body pose detection including: +/// - Facial features (eyes, nose, ears) +/// - Upper body joints (shoulders, elbows, wrists) +/// - Lower body joints (hips, knees, ankles) pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [ "nose", "left_eye", @@ -4675,6 +4724,15 @@ pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [ "right_ankle", ]; +/// COCO dataset object detection labels with 80 categories. +/// +/// A widely-used set of object categories for general object detection, including: +/// - Living beings (person, animals) +/// - Vehicles and transportation +/// - Common objects and furniture +/// - Food and kitchen items +/// - Sports and recreational equipment +/// - Electronics and appliances pub const NAMES_COCO_80: [&str; 80] = [ "person", "bicycle", @@ -4758,6 +4816,14 @@ pub const NAMES_COCO_80: [&str; 80] = [ "toothbrush", ]; +/// Extended COCO dataset labels with 91 categories including background and unused slots. +/// +/// An extended version of COCO labels that includes: +/// - Background class (index 0) +/// - All standard COCO categories +/// - Reserved "unused" slots for dataset compatibility +/// +/// Note: Indices are preserved for compatibility with pre-trained models, hence the "unused" entries. pub const NAMES_COCO_91: [&str; 91] = [ "background", // 0 "person", // 1 @@ -4852,6 +4918,13 @@ pub const NAMES_COCO_91: [&str; 91] = [ "toothbrush", // 90 ]; +/// Human body parts segmentation labels with 28 categories. 
+/// +/// Detailed categorization of human body parts for segmentation tasks: +/// - Basic body parts (face, neck, hair, torso) +/// - Left/Right symmetrical parts (hands, arms, legs, feet) +/// - Clothing regions (apparel, upper/lower clothing) +/// - Facial features (lips, teeth, tongue) pub const NAMES_BODY_PARTS_28: [&str; 28] = [ "Background", "Apparel", @@ -4883,6 +4956,16 @@ pub const NAMES_BODY_PARTS_28: [&str; 28] = [ "Tongue", ]; +/// ImageNet ILSVRC 1000-class classification labels. +/// +/// The standard ImageNet 1K classification dataset includes: +/// - Animals (mammals, birds, fish, insects) +/// - Objects (tools, vehicles, furniture) +/// - Foods and plants +/// - Natural scenes +/// - Man-made structures +/// +/// Note: Labels include both common names and scientific nomenclature where applicable. pub const NAMES_IMAGENET_1K: [&str; 1000] = [ "tench, Tinca tinca", "goldfish, Carassius auratus", diff --git a/src/core/onnx.rs b/src/core/onnx.rs new file mode 100644 index 0000000..1eb6e5e --- /dev/null +++ b/src/core/onnx.rs @@ -0,0 +1,1210 @@ +// This file is @generated by prost-build. +/// Attributes +/// +/// A named attribute containing either singular float, integer, string, graph, +/// and tensor values, or repeated float, integer, string, graph, and tensor values. +/// An AttributeProto MUST contain the name field, and *only one* of the +/// following content fields, effectively enforcing a C/C++ union equivalent. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AttributeProto { + /// The name field MUST be present for this version of the IR. + /// + /// namespace Attribute + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + /// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + /// In this case, this AttributeProto does not contain data, and it's a reference of attribute + /// in parent scope. + /// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + #[prost(string, tag = "21")] + pub ref_attr_name: ::prost::alloc::string::String, + /// A human-readable documentation for this attribute. Markdown is allowed. + #[prost(string, tag = "13")] + pub doc_string: ::prost::alloc::string::String, + /// The type field MUST be present for this version of the IR. + /// For 0.0.1 versions of the IR, this field was not defined, and + /// implementations needed to use has_field heuristics to determine + /// which value field was in use. For IR_VERSION 0.0.2 or later, this + /// field MUST be set and match the f|i|s|t|... field in use. This + /// change was made to accommodate proto3 implementations. + /// + /// discriminator that indicates which field below is in use + #[prost(enumeration = "attribute_proto::AttributeType", tag = "20")] + pub r#type: i32, + /// Exactly ONE of the following fields must be present for this version of the IR + /// + /// float + #[prost(float, tag = "2")] + pub f: f32, + /// int + #[prost(int64, tag = "3")] + pub i: i64, + /// UTF-8 string + #[prost(bytes = "vec", tag = "4")] + pub s: ::prost::alloc::vec::Vec, + /// tensor value + #[prost(message, optional, tag = "5")] + pub t: ::core::option::Option, + /// graph + #[prost(message, optional, tag = "6")] + pub g: ::core::option::Option, + /// sparse tensor value + #[prost(message, optional, tag = "22")] + pub sparse_tensor: ::core::option::Option, + /// Do not use field below, it's deprecated. 
+    /// optional ValueProto v = 12; // value - subsumes everything but graph
+    ///
+    /// type proto
+    #[prost(message, optional, tag = "14")]
+    pub tp: ::core::option::Option<TypeProto>,
+    /// list of floats
+    #[prost(float, repeated, tag = "7")]
+    pub floats: ::prost::alloc::vec::Vec<f32>,
+    /// list of ints
+    #[prost(int64, repeated, tag = "8")]
+    pub ints: ::prost::alloc::vec::Vec<i64>,
+    /// list of UTF-8 strings
+    #[prost(bytes = "vec", repeated, tag = "9")]
+    pub strings: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>,
+    /// list of tensors
+    #[prost(message, repeated, tag = "10")]
+    pub tensors: ::prost::alloc::vec::Vec<TensorProto>,
+    /// list of graph
+    #[prost(message, repeated, tag = "11")]
+    pub graphs: ::prost::alloc::vec::Vec<GraphProto>,
+    /// list of sparse tensors
+    #[prost(message, repeated, tag = "23")]
+    pub sparse_tensors: ::prost::alloc::vec::Vec<SparseTensorProto>,
+    /// list of type protos
+    #[prost(message, repeated, tag = "15")]
+    pub type_protos: ::prost::alloc::vec::Vec<TypeProto>,
+}
+/// Nested message and enum types in `AttributeProto`.
+pub mod attribute_proto {
+    /// Note: this enum is structurally identical to the OpSchema::AttrType
+    /// enum defined in schema.h. If you rev one, you likely need to rev the other.
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+    #[repr(i32)]
+    pub enum AttributeType {
+        Undefined = 0,
+        Float = 1,
+        Int = 2,
+        String = 3,
+        Tensor = 4,
+        Graph = 5,
+        SparseTensor = 11,
+        TypeProto = 13,
+        Floats = 6,
+        Ints = 7,
+        Strings = 8,
+        Tensors = 9,
+        Graphs = 10,
+        SparseTensors = 12,
+        TypeProtos = 14,
+    }
+    impl AttributeType {
+        /// String value of the enum field names used in the ProtoBuf definition.
+        ///
+        /// The values are not transformed in any way and thus are considered stable
+        /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+        pub fn as_str_name(&self) -> &'static str {
+            match self {
+                Self::Undefined => "UNDEFINED",
+                Self::Float => "FLOAT",
+                Self::Int => "INT",
+                Self::String => "STRING",
+                Self::Tensor => "TENSOR",
+                Self::Graph => "GRAPH",
+                Self::SparseTensor => "SPARSE_TENSOR",
+                Self::TypeProto => "TYPE_PROTO",
+                Self::Floats => "FLOATS",
+                Self::Ints => "INTS",
+                Self::Strings => "STRINGS",
+                Self::Tensors => "TENSORS",
+                Self::Graphs => "GRAPHS",
+                Self::SparseTensors => "SPARSE_TENSORS",
+                Self::TypeProtos => "TYPE_PROTOS",
+            }
+        }
+        /// Creates an enum from field names used in the ProtoBuf definition.
+        pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+            match value {
+                "UNDEFINED" => Some(Self::Undefined),
+                "FLOAT" => Some(Self::Float),
+                "INT" => Some(Self::Int),
+                "STRING" => Some(Self::String),
+                "TENSOR" => Some(Self::Tensor),
+                "GRAPH" => Some(Self::Graph),
+                "SPARSE_TENSOR" => Some(Self::SparseTensor),
+                "TYPE_PROTO" => Some(Self::TypeProto),
+                "FLOATS" => Some(Self::Floats),
+                "INTS" => Some(Self::Ints),
+                "STRINGS" => Some(Self::Strings),
+                "TENSORS" => Some(Self::Tensors),
+                "GRAPHS" => Some(Self::Graphs),
+                "SPARSE_TENSORS" => Some(Self::SparseTensors),
+                "TYPE_PROTOS" => Some(Self::TypeProtos),
+                _ => None,
+            }
+        }
+    }
+}
+/// Defines information on value, including the name, the type, and
+/// the shape of the value.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ValueInfoProto {
+    /// This field MUST be present in this version of the IR.
+    ///
+    /// namespace Value
+    #[prost(string, tag = "1")]
+    pub name: ::prost::alloc::string::String,
+    /// This field MUST be present in this version of the IR for
+    /// inputs and outputs of the top-level graph.
+    #[prost(message, optional, tag = "2")]
+    pub r#type: ::core::option::Option<TypeProto>,
+    /// A human-readable documentation for this value. Markdown is allowed.
+    #[prost(string, tag = "3")]
+    pub doc_string: ::prost::alloc::string::String,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "4")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Nodes
+///
+/// Computation graphs are made up of a DAG of nodes, which represent what is
+/// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+///
+/// For example, it can be a node of type "Conv" that takes in an image, a filter
+/// tensor and a bias tensor, and produces the convolved output.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NodeProto {
+    /// namespace Value
+    #[prost(string, repeated, tag = "1")]
+    pub input: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// namespace Value
+    #[prost(string, repeated, tag = "2")]
+    pub output: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// An optional identifier for this node in a graph.
+    /// This field MAY be absent in this version of the IR.
+    ///
+    /// namespace Node
+    #[prost(string, tag = "3")]
+    pub name: ::prost::alloc::string::String,
+    /// The symbolic identifier of the Operator to execute.
+    ///
+    /// namespace Operator
+    #[prost(string, tag = "4")]
+    pub op_type: ::prost::alloc::string::String,
+    /// The domain of the OperatorSet that specifies the operator named by op_type.
+    ///
+    /// namespace Domain
+    #[prost(string, tag = "7")]
+    pub domain: ::prost::alloc::string::String,
+    /// Overload identifier, used only to map this to a model-local function.
+    #[prost(string, tag = "8")]
+    pub overload: ::prost::alloc::string::String,
+    /// Additional named attributes.
+    #[prost(message, repeated, tag = "5")]
+    pub attribute: ::prost::alloc::vec::Vec<AttributeProto>,
+    /// A human-readable documentation for this node. Markdown is allowed.
+    #[prost(string, tag = "6")]
+    pub doc_string: ::prost::alloc::string::String,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "9")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+    /// Configuration of multi-device annotations.
+    #[prost(message, repeated, tag = "10")]
+    pub device_configurations: ::prost::alloc::vec::Vec<NodeDeviceConfigurationProto>,
+}
+/// IntIntListEntryProto follows the pattern for cross-proto-version maps.
+/// See <https://developers.google.com/protocol-buffers/docs/proto3#maps>
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct IntIntListEntryProto {
+    #[prost(int64, tag = "1")]
+    pub key: i64,
+    #[prost(int64, repeated, tag = "2")]
+    pub value: ::prost::alloc::vec::Vec<i64>,
+}
+/// Multi-device configuration proto for NodeProto.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NodeDeviceConfigurationProto {
+    /// This field MUST be present for this version of the IR.
+    /// ID of the configuration. MUST match the name of a DeviceConfigurationProto.
+    #[prost(string, tag = "1")]
+    pub configuration_id: ::prost::alloc::string::String,
+    /// Sharding spec for the node.
+    #[prost(message, repeated, tag = "2")]
+    pub sharding_spec: ::prost::alloc::vec::Vec<ShardingSpecProto>,
+    /// Pipeline stage of this node.
+    #[prost(int32, tag = "3")]
+    pub pipeline_stage: i32,
+}
+/// ShardingSpecProto: This describes the sharding spec for a specific
+/// input or output tensor of a node.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ShardingSpecProto {
+    /// This field MUST be present for this version of the IR.
+    /// Identifies the input or output of the node that is being sharded.
+    /// Required to match a name specified in the node's input or output list of ValueInfoProtos.
+    /// It is called `logical tensor` in subsequent descriptions.
+    #[prost(string, tag = "1")]
+    pub tensor_name: ::prost::alloc::string::String,
+    /// The following is the list of devices across which the logical
+    /// tensor is sharded or replicated.
+    #[prost(int64, repeated, tag = "2")]
+    pub device: ::prost::alloc::vec::Vec<i64>,
+    /// Each element v in above field devices may represent either a
+    /// device or a set of devices (when we want the same shard/tensor
+    /// to be replicated across a subset of devices), as indicated by
+    /// the following optional map. If the map contains an entry for v,
+    /// then v represents a device group, and the map indicates the set
+    /// of devices in that group.
+    #[prost(message, repeated, tag = "3")]
+    pub index_to_device_group_map: ::prost::alloc::vec::Vec<IntIntListEntryProto>,
+    /// The following is the sharded-shape of the tensor, consisting of
+    /// the sharding-spec for each axis of the tensor.
+    #[prost(message, repeated, tag = "4")]
+    pub sharded_dim: ::prost::alloc::vec::Vec<ShardedDimProto>,
+}
+/// ShardedDimProto: This describes the sharding spec for a single
+/// axis of a sharded tensor.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ShardedDimProto {
+    /// This field MUST be present for this version of the IR.
+    /// The axis this sharding corresponds to. Must be in the range of
+    /// \[-r, r - 1\], where r is the rank of the tensor. Negative axis values mean
+    /// counting from the back.
+    #[prost(int64, tag = "1")]
+    pub axis: i64,
+    /// Describes how the tensor on the provided axis is sharded.
+    /// The common case is described by a single instance of SimpleShardedDimProto.
+    /// Multiple instances can be used to handle cases where a sharded
+    /// tensor is reshaped, fusing multiple axes into one.
+    #[prost(message, repeated, tag = "2")]
+    pub simple_sharding: ::prost::alloc::vec::Vec<SimpleShardedDimProto>,
+}
+/// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
+/// N is allowed to be symbolic where M is required to be a constant.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SimpleShardedDimProto {
+    /// This field MUST be present for this version of the IR.
+    /// Number of shards to split dim into.
+    #[prost(int64, tag = "3")]
+    pub num_shards: i64,
+    /// Dimension value to be sharded.
+    #[prost(oneof = "simple_sharded_dim_proto::Dim", tags = "1, 2")]
+    pub dim: ::core::option::Option<simple_sharded_dim_proto::Dim>,
+}
+/// Nested message and enum types in `SimpleShardedDimProto`.
+pub mod simple_sharded_dim_proto {
+    /// Dimension value to be sharded.
+    #[derive(Clone, PartialEq, ::prost::Oneof)]
+    pub enum Dim {
+        #[prost(int64, tag = "1")]
+        DimValue(i64),
+        #[prost(string, tag = "2")]
+        DimParam(::prost::alloc::string::String),
+    }
+}
+/// Training information
+/// TrainingInfoProto stores information for training a model.
+/// In particular, this defines two functionalities: an initialization-step
+/// and a training-algorithm-step. Initialization resets the model
+/// back to its original state as if no training has been performed.
+/// The training algorithm improves the model based on input data.
+///
+/// The semantics of the initialization-step are that the initializers
+/// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+/// initialized as specified by the initializers in the graph, and then
+/// updated by the "initialization_binding" in every instance in
+/// ModelProto.training_info.
+///
+/// The field "algorithm" defines a computation graph which represents a
+/// training algorithm's step. After the execution of a
+/// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+/// may be immediately updated. If the targeted training algorithm contains
+/// consecutive update steps (such as block coordinate descent methods),
+/// the user needs to create a TrainingInfoProto for each step.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TrainingInfoProto {
+    /// This field describes a graph to compute the initial tensors
+    /// upon starting the training process. The initialization graph has no input
+    /// and can have multiple outputs. Usually, trainable tensors in neural
+    /// networks are randomly initialized. To achieve that, for each tensor,
+    /// the user can put a random number operator such as RandomNormal or
+    /// RandomUniform in TrainingInfoProto.initialization.node and assign its
+    /// random output to the specific tensor using "initialization_binding".
+    /// This graph can also set the initializers in "algorithm" in the same
+    /// TrainingInfoProto; a use case is resetting the number of training
+    /// iterations to zero.
+    ///
+    /// By default, this field is an empty graph and its evaluation does not
+    /// produce any output. Thus, no initializer would be changed by default.
+    #[prost(message, optional, tag = "1")]
+    pub initialization: ::core::option::Option<GraphProto>,
+    /// This field represents a training algorithm step. Given required inputs,
+    /// it computes outputs to update initializers in its own or the inference graph's
+    /// initializer lists. In general, this field contains the loss node, gradient node,
+    /// optimizer node, and an increment of the iteration count.
+    ///
+    /// An execution of the training algorithm step is performed by executing the
+    /// graph obtained by combining the inference graph (namely "ModelProto.graph")
+    /// and the "algorithm" graph. That is, the actual
+    /// input/initializer/output/node/value_info/sparse_initializer list of
+    /// the training graph is the concatenation of
+    /// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+    /// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+    /// in that order. This combined graph must satisfy the normal ONNX conditions.
+    /// Now, let's provide a visualization of graph combination for clarity.
+    /// Let the inference graph (i.e., "ModelProto.graph") be
+    ///    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+    /// and the "algorithm" graph be
+    ///    tensor_d -> Add -> tensor_e
+    /// The combination process results in
+    ///    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+    ///
+    /// Notice that an input of a node in the "algorithm" graph may reference the
+    /// output of a node in the inference graph (but not the other way round). Also, an inference
+    /// node cannot reference inputs of "algorithm". With these restrictions, the inference graph
+    /// can always be run independently without training information.
+    ///
+    /// By default, this field is an empty graph and its evaluation does not
+    /// produce any output. Evaluating the default training step never
+    /// updates any initializers.
+    #[prost(message, optional, tag = "2")]
+    pub algorithm: ::core::option::Option<GraphProto>,
+    /// This field specifies the bindings from the outputs of "initialization" to
+    /// some initializers in "ModelProto.graph.initializer" and
+    /// the "algorithm.initializer" in the same TrainingInfoProto.
+    /// See "update_binding" below for details.
+    ///
+    /// By default, this field is empty and no initializer would be changed
+    /// by the execution of "initialization".
+    #[prost(message, repeated, tag = "3")]
+    pub initialization_binding: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+    /// Gradient-based training is usually an iterative procedure. In one gradient
+    /// descent iteration, we apply
+    ///
+    /// x = x - r * g
+    ///
+    /// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+    /// the gradient of "x" with respect to a chosen loss. To avoid adding assignments
+    /// into the training graph, we split the update equation into
+    ///
+    /// y = x - r * g
+    /// x = y
+    ///
+    /// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+    /// tell that "y" should be assigned to "x", the field "update_binding" may
+    /// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+    /// and "y" (value of StringStringEntryProto).
+    /// For a neural network with multiple trainable (mutable) tensors, there can
+    /// be multiple key-value pairs in "update_binding".
+    ///
+    /// The initializers that appear as keys in "update_binding" are considered
+    /// mutable variables. This implies some behaviors
+    /// as described below.
+    ///
+    /// 1. We have only unique keys across all "update_binding"s so that two
+    ///    variables may not have the same name. This ensures that one
+    ///    variable is assigned at most once.
+    /// 2. The keys must appear in names of "ModelProto.graph.initializer" or
+    ///    "TrainingInfoProto.algorithm.initializer".
+    /// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+    /// 4. Mutable variables are initialized to the value specified by the
+    ///    corresponding initializer, and then potentially updated by
+    ///    "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+    ///
+    /// This field usually contains names of trainable tensors
+    /// (in ModelProto.graph), optimizer states such as momentums in advanced
+    /// stochastic gradient methods (in TrainingInfoProto.graph),
+    /// and the number of training iterations (in TrainingInfoProto.graph).
+    ///
+    /// By default, this field is empty and no initializer would be changed
+    /// by the execution of "algorithm".
+    #[prost(message, repeated, tag = "4")]
+    pub update_binding: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Models
+///
+/// ModelProto is a top-level file/container format for bundling a ML model and
+/// associating its computation graph with metadata.
+///
+/// The semantics of the model are described by the associated GraphProto's.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ModelProto {
+    /// The version of the IR this model targets. See Version enum above.
+    /// This field MUST be present.
+    #[prost(int64, tag = "1")]
+    pub ir_version: i64,
+    /// The OperatorSets this model relies on.
+    /// All ModelProtos MUST have at least one entry that
+    /// specifies which version of the ONNX OperatorSet is
+    /// being imported.
+    ///
+    /// All nodes in the ModelProto's graph will bind against the operator
+    /// with the same-domain/same-op_type operator with the HIGHEST version
+    /// in the referenced operator sets.
+    #[prost(message, repeated, tag = "8")]
+    pub opset_import: ::prost::alloc::vec::Vec<OperatorSetIdProto>,
+    /// The name of the framework or tool used to generate this model.
+    /// This field SHOULD be present to indicate which implementation/tool/framework
+    /// emitted the model.
+    #[prost(string, tag = "2")]
+    pub producer_name: ::prost::alloc::string::String,
+    /// The version of the framework or tool used to generate this model.
+    /// This field SHOULD be present to indicate which implementation/tool/framework
+    /// emitted the model.
+    #[prost(string, tag = "3")]
+    pub producer_version: ::prost::alloc::string::String,
+    /// Domain name of the model.
+    /// We use reverse domain names as name space indicators. For example:
+    /// `com.facebook.fair` or `com.microsoft.cognitiveservices`
+    ///
+    /// Together with `model_version` and GraphProto.name, this forms the unique identity of
+    /// the graph.
+    #[prost(string, tag = "4")]
+    pub domain: ::prost::alloc::string::String,
+    /// The version of the graph encoded. See Version enum below.
+    #[prost(int64, tag = "5")]
+    pub model_version: i64,
+    /// A human-readable documentation for this model. Markdown is allowed.
+    #[prost(string, tag = "6")]
+    pub doc_string: ::prost::alloc::string::String,
+    /// The parameterized graph that is evaluated to execute the model.
+    #[prost(message, optional, tag = "7")]
+    pub graph: ::core::option::Option<GraphProto>,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "14")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+    /// Training-specific information. Sequentially executing all stored
+    /// `TrainingInfoProto.algorithm`s and assigning their outputs following
+    /// the corresponding `TrainingInfoProto.update_binding`s is one training
+    /// iteration. Similarly, to initialize the model
+    /// (as if training hasn't happened), the user should sequentially execute
+    /// all stored `TrainingInfoProto.initialization`s and assign their outputs
+    /// using `TrainingInfoProto.initialization_binding`s.
+    ///
+    /// If this field is empty, the training behavior of the model is undefined.
+    #[prost(message, repeated, tag = "20")]
+    pub training_info: ::prost::alloc::vec::Vec<TrainingInfoProto>,
+    /// A list of function protos local to the model.
+    ///
+    /// The (domain, name, overload) tuple must be unique across the function protos in this list.
+    /// In case of any conflicts, the behavior (whether the model-local functions are given higher priority,
+    /// or standard operator sets are given higher priority, or this is treated as an error) is defined by
+    /// the runtimes.
+    ///
+    /// The operator sets imported by FunctionProto should be compatible with the ones
+    /// imported by ModelProto and other model-local FunctionProtos.
+    /// For example, if the same operator set, say 'A', is imported by a FunctionProto and ModelProto,
+    /// or by 2 FunctionProtos, then the versions for the operator set may be different, but
+    /// the operator schema returned for the op_type, domain, version combination
+    /// for both versions should be the same for every node in the function body.
+    ///
+    /// One FunctionProto can reference another FunctionProto in the model; however,
+    /// recursive references are not allowed.
+    #[prost(message, repeated, tag = "25")]
+    pub functions: ::prost::alloc::vec::Vec<FunctionProto>,
+    /// Describes different target configurations for a multi-device use case.
+    /// A model MAY describe multiple multi-device configurations for execution.
+    #[prost(message, repeated, tag = "26")]
+    pub configuration: ::prost::alloc::vec::Vec<DeviceConfigurationProto>,
+}
+/// DeviceConfigurationProto describes a multi-device configuration for a model.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DeviceConfigurationProto {
+    /// This field MUST be present for this version of the IR.
+    /// Name of the configuration.
+    #[prost(string, tag = "1")]
+    pub name: ::prost::alloc::string::String,
+    /// This field MUST be present for this version of the IR.
+    /// Number of devices inside this configuration.
+    #[prost(int32, tag = "2")]
+    pub num_devices: i32,
+    /// Optional names of the devices. MUST have length num_devices if provided.
+    #[prost(string, repeated, tag = "3")]
+    pub device: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+}
+/// StringStringEntryProto follows the pattern for cross-proto-version maps.
+/// See <https://developers.google.com/protocol-buffers/docs/proto3#maps>
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StringStringEntryProto {
+    #[prost(string, tag = "1")]
+    pub key: ::prost::alloc::string::String,
+    #[prost(string, tag = "2")]
+    pub value: ::prost::alloc::string::String,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorAnnotation {
+    #[prost(string, tag = "1")]
+    pub tensor_name: ::prost::alloc::string::String,
+    /// (key, value) pairs to annotate the tensor specified by tensor_name above.
+    /// The keys used in the mapping below must be pre-defined in ONNX spec.
+    /// For example, for the 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+    /// quantization parameter keys.
+    #[prost(message, repeated, tag = "2")]
+    pub quant_parameter_tensor_names: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Graphs
+///
+/// A graph defines the computational logic of a model and is comprised of a parameterized
+/// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+/// This is the equivalent of the "network" or "graph" in many deep learning
+/// frameworks.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GraphProto {
+    /// The nodes in the graph, sorted topologically.
+    #[prost(message, repeated, tag = "1")]
+    pub node: ::prost::alloc::vec::Vec<NodeProto>,
+    /// The name of the graph.
+    ///
+    /// namespace Graph
+    #[prost(string, tag = "2")]
+    pub name: ::prost::alloc::string::String,
+    /// A list of named tensor values, used to specify constant inputs of the graph.
+    /// Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
+    /// The name MUST be unique across both initializer and sparse_initializer,
+    /// but the name MAY also appear in the input list.
+    #[prost(message, repeated, tag = "5")]
+    pub initializer: ::prost::alloc::vec::Vec<TensorProto>,
+    /// Initializers (see above) stored in sparse format.
+    #[prost(message, repeated, tag = "15")]
+    pub sparse_initializer: ::prost::alloc::vec::Vec<SparseTensorProto>,
+    /// A human-readable documentation for this graph. Markdown is allowed.
+    #[prost(string, tag = "10")]
+    pub doc_string: ::prost::alloc::string::String,
+    /// The inputs and outputs of the graph.
+    #[prost(message, repeated, tag = "11")]
+    pub input: ::prost::alloc::vec::Vec<ValueInfoProto>,
+    #[prost(message, repeated, tag = "12")]
+    pub output: ::prost::alloc::vec::Vec<ValueInfoProto>,
+    /// Information for the values in the graph. The ValueInfoProto.name's
+    /// must be distinct. It is optional for a value to appear in the value_info list.
+    #[prost(message, repeated, tag = "13")]
+    pub value_info: ::prost::alloc::vec::Vec<ValueInfoProto>,
+    /// This field carries information to indicate the mapping among a tensor and its
+    /// quantization parameter tensors. For example:
+    /// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+    /// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+    #[prost(message, repeated, tag = "14")]
+    pub quantization_annotation: ::prost::alloc::vec::Vec<TensorAnnotation>,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "16")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Tensors
+///
+/// A serialized tensor value.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorProto {
+    /// The shape of the tensor.
+    #[prost(int64, repeated, tag = "1")]
+    pub dims: ::prost::alloc::vec::Vec<i64>,
+    /// The data type of the tensor.
+    /// This field MUST have a valid TensorProto.DataType value
+    #[prost(int32, tag = "2")]
+    pub data_type: i32,
+    #[prost(message, optional, tag = "3")]
+    pub segment: ::core::option::Option<tensor_proto::Segment>,
+    /// For float and complex64 values
+    /// Complex64 tensors are encoded as a single array of floats,
+    /// with the real components appearing in odd numbered positions,
+    /// and the corresponding imaginary component appearing in the
+    /// subsequent even numbered position. (e.g., \[1.0 + 2.0i, 3.0 + 4.0i\]
+    /// is encoded as \[1.0, 2.0, 3.0, 4.0\].)
+    /// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+    #[prost(float, repeated, tag = "4")]
+    pub float_data: ::prost::alloc::vec::Vec<f32>,
+    /// For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
+    /// - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
+    ///   representation before being written to the buffer.
+    /// - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
+    ///   The first element is stored in the 4 least significant bits (LSB),
+    ///   and the second element is stored in the 4 most significant bits (MSB).
+    ///
+    /// Consequently:
+    /// - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
+    /// - For 4-bit data types, each `int32_data` stores two elements.
+    ///
+    /// When this field is present, the data_type field MUST be
+    /// INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
+    #[prost(int32, repeated, tag = "5")]
+    pub int32_data: ::prost::alloc::vec::Vec<i32>,
+    /// For strings.
+    /// Each element of string_data is a UTF-8 encoded Unicode
+    /// string. No trailing null, no leading BOM. The protobuf "string"
+    /// scalar type is not used to match ML community conventions.
+    /// When this field is present, the data_type field MUST be STRING
+    #[prost(bytes = "vec", repeated, tag = "6")]
+    pub string_data: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>,
+    /// For int64.
+    /// When this field is present, the data_type field MUST be INT64
+    #[prost(int64, repeated, tag = "7")]
+    pub int64_data: ::prost::alloc::vec::Vec<i64>,
+    /// Optionally, a name for the tensor.
+    ///
+    /// namespace Value
+    #[prost(string, tag = "8")]
+    pub name: ::prost::alloc::string::String,
+    /// A human-readable documentation for this tensor. Markdown is allowed.
+    #[prost(string, tag = "12")]
+    pub doc_string: ::prost::alloc::string::String,
+    /// Serializations can either use one of the fields above, or use this
+    /// raw bytes field. The only exception is the string case, where one is
+    /// required to store the content in the repeated bytes string_data field.
+    ///
+    /// When this raw_data field is used to store tensor value, elements MUST
+    /// be stored in fixed-width, little-endian order.
+    /// Floating-point data types MUST be stored in IEEE 754 format.
+    /// Complex64 elements must be written as two consecutive FLOAT values, real component first.
+    /// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+    /// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+    /// uint4 and int4 values must be packed to 4bitx2; the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
+    ///
+    /// Note: the advantage of the type-specific fields over the raw_data field is
+    /// that in some cases (e.g. int data) protobuf does a better packing via
+    /// variable length storage, which may lead to a smaller binary footprint.
+    /// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+    #[prost(bytes = "vec", tag = "9")]
+    pub raw_data: ::prost::alloc::vec::Vec<u8>,
+    /// Data can be stored inside the protobuf file using type-specific fields or raw_data.
+    /// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+    /// external_data stores key-value pairs describing data location. Recognized keys are:
+    /// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+    ///   protobuf model was stored
+    /// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+    ///   Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+    /// - "length" (optional) - number of bytes containing data. Integer stored as string.
+    /// - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
+    #[prost(message, repeated, tag = "13")]
+    pub external_data: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+    /// If this value is not set, data is stored in raw_data (if set), otherwise in the type-specified field.
+    #[prost(enumeration = "tensor_proto::DataLocation", tag = "14")]
+    pub data_location: i32,
+    /// For double
+    /// Complex128 tensors are encoded as a single array of doubles,
+    /// with the real components appearing in odd numbered positions,
+    /// and the corresponding imaginary component appearing in the
+    /// subsequent even numbered position. (e.g., \[1.0 + 2.0i, 3.0 + 4.0i\]
+    /// is encoded as \[1.0, 2.0, 3.0, 4.0\].)
+    /// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+    #[prost(double, repeated, tag = "10")]
+    pub double_data: ::prost::alloc::vec::Vec<f64>,
+    /// For uint64 and uint32 values
+    /// When this field is present, the data_type field MUST be
+    /// UINT32 or UINT64
+    #[prost(uint64, repeated, tag = "11")]
+    pub uint64_data: ::prost::alloc::vec::Vec<u64>,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "16")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Nested message and enum types in `TensorProto`.
+pub mod tensor_proto {
+    /// For very large tensors, we may want to store them in chunks, in which
+    /// case the following fields will specify the segment that is stored in
+    /// the current TensorProto.
+    #[derive(Clone, Copy, PartialEq, ::prost::Message)]
+    pub struct Segment {
+        #[prost(int64, tag = "1")]
+        pub begin: i64,
+        #[prost(int64, tag = "2")]
+        pub end: i64,
+    }
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+    #[repr(i32)]
+    pub enum DataType {
+        Undefined = 0,
+        /// Basic types.
+        ///
+        /// float
+        Float = 1,
+        /// uint8_t
+        Uint8 = 2,
+        /// int8_t
+        Int8 = 3,
+        /// uint16_t
+        Uint16 = 4,
+        /// int16_t
+        Int16 = 5,
+        /// int32_t
+        Int32 = 6,
+        /// int64_t
+        Int64 = 7,
+        /// string
+        String = 8,
+        /// bool
+        Bool = 9,
+        /// IEEE754 half-precision floating-point format (16 bits wide).
+        /// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+        Float16 = 10,
+        Double = 11,
+        Uint32 = 12,
+        Uint64 = 13,
+        /// complex with float32 real and imaginary components
+        Complex64 = 14,
+        /// complex with float64 real and imaginary components
+        Complex128 = 15,
+        /// Non-IEEE floating-point format based on IEEE754 single-precision
+        /// floating-point number truncated to 16 bits.
+        /// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+        Bfloat16 = 16,
+        /// Non-IEEE floating-point formats based on the papers
+        /// "FP8 Formats for Deep Learning" and
+        /// "8-bit Numerical Formats For Deep Neural Networks".
+        /// The operators that support FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+        /// The computation usually happens inside a block quantize / dequantize
+        /// fused by the runtime.
+        ///
+        /// float 8, mostly used for coefficients, supports nan, not inf
+        Float8e4m3fn = 17,
+        /// float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+        Float8e4m3fnuz = 18,
+        /// follows IEEE 754, supports nan, inf, mostly used for gradients
+        Float8e5m2 = 19,
+        /// follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+        Float8e5m2fnuz = 20,
+        /// 4-bit integer data types
+        ///
+        /// Unsigned integer in range \[0, 15\]
+        Uint4 = 21,
+        /// Signed integer in range \[-8, 7\], using two's-complement representation
+        Int4 = 22,
+        /// 4-bit floating point data types
+        Float4e2m1 = 23,
+    }
+    impl DataType {
+        /// String value of the enum field names used in the ProtoBuf definition.
+        ///
+        /// The values are not transformed in any way and thus are considered stable
+        /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+        pub fn as_str_name(&self) -> &'static str {
+            match self {
+                Self::Undefined => "UNDEFINED",
+                Self::Float => "FLOAT",
+                Self::Uint8 => "UINT8",
+                Self::Int8 => "INT8",
+                Self::Uint16 => "UINT16",
+                Self::Int16 => "INT16",
+                Self::Int32 => "INT32",
+                Self::Int64 => "INT64",
+                Self::String => "STRING",
+                Self::Bool => "BOOL",
+                Self::Float16 => "FLOAT16",
+                Self::Double => "DOUBLE",
+                Self::Uint32 => "UINT32",
+                Self::Uint64 => "UINT64",
+                Self::Complex64 => "COMPLEX64",
+                Self::Complex128 => "COMPLEX128",
+                Self::Bfloat16 => "BFLOAT16",
+                Self::Float8e4m3fn => "FLOAT8E4M3FN",
+                Self::Float8e4m3fnuz => "FLOAT8E4M3FNUZ",
+                Self::Float8e5m2 => "FLOAT8E5M2",
+                Self::Float8e5m2fnuz => "FLOAT8E5M2FNUZ",
+                Self::Uint4 => "UINT4",
+                Self::Int4 => "INT4",
+                Self::Float4e2m1 => "FLOAT4E2M1",
+            }
+        }
+        /// Creates an enum from field names used in the ProtoBuf definition.
+        pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+            match value {
+                "UNDEFINED" => Some(Self::Undefined),
+                "FLOAT" => Some(Self::Float),
+                "UINT8" => Some(Self::Uint8),
+                "INT8" => Some(Self::Int8),
+                "UINT16" => Some(Self::Uint16),
+                "INT16" => Some(Self::Int16),
+                "INT32" => Some(Self::Int32),
+                "INT64" => Some(Self::Int64),
+                "STRING" => Some(Self::String),
+                "BOOL" => Some(Self::Bool),
+                "FLOAT16" => Some(Self::Float16),
+                "DOUBLE" => Some(Self::Double),
+                "UINT32" => Some(Self::Uint32),
+                "UINT64" => Some(Self::Uint64),
+                "COMPLEX64" => Some(Self::Complex64),
+                "COMPLEX128" => Some(Self::Complex128),
+                "BFLOAT16" => Some(Self::Bfloat16),
+                "FLOAT8E4M3FN" => Some(Self::Float8e4m3fn),
+                "FLOAT8E4M3FNUZ" => Some(Self::Float8e4m3fnuz),
+                "FLOAT8E5M2" => Some(Self::Float8e5m2),
+                "FLOAT8E5M2FNUZ" => Some(Self::Float8e5m2fnuz),
+                "UINT4" => Some(Self::Uint4),
+                "INT4" => Some(Self::Int4),
+                "FLOAT4E2M1" => Some(Self::Float4e2m1),
+                _ => None,
+            }
+        }
+    }
+    /// Location of the data for this tensor. MUST be one of:
+    /// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+    /// - EXTERNAL - data stored in an external location as described by external_data field.
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+    #[repr(i32)]
+    pub enum DataLocation {
+        Default = 0,
+        External = 1,
+    }
+    impl DataLocation {
+        /// String value of the enum field names used in the ProtoBuf definition.
+        ///
+        /// The values are not transformed in any way and thus are considered stable
+        /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+        pub fn as_str_name(&self) -> &'static str {
+            match self {
+                Self::Default => "DEFAULT",
+                Self::External => "EXTERNAL",
+            }
+        }
+        /// Creates an enum from field names used in the ProtoBuf definition.
+        pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+            match value {
+                "DEFAULT" => Some(Self::Default),
+                "EXTERNAL" => Some(Self::External),
+                _ => None,
+            }
+        }
+    }
+}
+/// A serialized sparse-tensor value
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SparseTensorProto {
+    /// The sequence of non-default values are encoded as a tensor of shape \[NNZ\].
+    /// The default-value is zero for numeric tensors, and empty-string for string tensors.
+    /// values must have a non-empty name present which serves as a name for SparseTensorProto
+    /// when used in the sparse_initializer list.
+    #[prost(message, optional, tag = "1")]
+    pub values: ::core::option::Option<TensorProto>,
+    /// The indices of the non-default values, which may be stored in one of two formats.
+    /// (a) Indices can be a tensor of shape \[NNZ, rank\] with the \[i,j\]-th value
+    ///     corresponding to the j-th index of the i-th value (in the values tensor).
+    /// (b) Indices can be a tensor of shape \[NNZ\], in which case the i-th value
+    ///     must be the linearized-index of the i-th value (in the values tensor).
+    ///     The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+    ///     using the shape provided below.
+    /// The indices must appear in ascending order without duplication.
+    /// In the first format, the ordering is lexicographic-ordering:
+    /// e.g., index-value \[1,4\] must appear before \[2,1\]
+    #[prost(message, optional, tag = "2")]
+    pub indices: ::core::option::Option<TensorProto>,
+    /// The shape of the underlying dense-tensor: \[dim_1, dim_2, ... dim_rank\]
+    #[prost(int64, repeated, tag = "3")]
+    pub dims: ::prost::alloc::vec::Vec<i64>,
+}
+/// Defines a tensor shape. A dimension can be either an integer value
+/// or a symbolic variable. A symbolic variable represents an unknown
+/// dimension.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorShapeProto {
+    #[prost(message, repeated, tag = "1")]
+    pub dim: ::prost::alloc::vec::Vec<tensor_shape_proto::Dimension>,
+}
+/// Nested message and enum types in `TensorShapeProto`.
+pub mod tensor_shape_proto {
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Dimension {
+        /// Standard denotation can optionally be used to denote tensor
+        /// dimensions with standard semantic descriptions to ensure
+        /// that operations are applied to the correct axis of a tensor.
+        /// Refer to <https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition>
+        /// for pre-defined dimension denotations.
+        #[prost(string, tag = "3")]
+        pub denotation: ::prost::alloc::string::String,
+        #[prost(oneof = "dimension::Value", tags = "1, 2")]
+        pub value: ::core::option::Option<dimension::Value>,
+    }
+    /// Nested message and enum types in `Dimension`.
+    pub mod dimension {
+        #[derive(Clone, PartialEq, ::prost::Oneof)]
+        pub enum Value {
+            #[prost(int64, tag = "1")]
+            DimValue(i64),
+            /// namespace Shape
+            #[prost(string, tag = "2")]
+            DimParam(::prost::alloc::string::String),
+        }
+    }
+}
+/// Types
+///
+/// The standard ONNX data types.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TypeProto {
+    /// An optional denotation can be used to denote the whole
+    /// type with a standard semantic description as to what is
+    /// stored inside. Refer to <https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition>
+    /// for pre-defined type denotations.
+    #[prost(string, tag = "6")]
+    pub denotation: ::prost::alloc::string::String,
+    #[prost(oneof = "type_proto::Value", tags = "1, 4, 5, 9, 8")]
+    pub value: ::core::option::Option<type_proto::Value>,
+}
+/// Nested message and enum types in `TypeProto`.
+pub mod type_proto {
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Tensor {
+        /// This field MUST NOT have the value of UNDEFINED
+        /// This field MUST have a valid TensorProto.DataType value
+        /// This field MUST be present for this version of the IR.
+        #[prost(int32, tag = "1")]
+        pub elem_type: i32,
+        #[prost(message, optional, tag = "2")]
+        pub shape: ::core::option::Option<super::TensorShapeProto>,
+    }
+    /// repeated T
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Sequence {
+        /// The type and optional shape of each element of the sequence.
+        /// This field MUST be present for this version of the IR.
+        #[prost(message, optional, boxed, tag = "1")]
+        pub elem_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+    }
+    /// map<K,V>
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Map {
+        /// This field MUST have a valid TensorProto.DataType value
+        /// This field MUST be present for this version of the IR.
+        /// This field MUST refer to an integral type (\[U\]INT{8|16|32|64}) or STRING
+        #[prost(int32, tag = "1")]
+        pub key_type: i32,
+        /// This field MUST be present for this version of the IR.
+        #[prost(message, optional, boxed, tag = "2")]
+        pub value_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+    }
+    /// wrapper for Tensor, Sequence, or Map
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Optional {
+        /// The type and optional shape of the element wrapped.
+        /// This field MUST be present for this version of the IR.
+        /// Possible values correspond to OptionalProto.DataType enum
+        #[prost(message, optional, boxed, tag = "1")]
+        pub elem_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+    }
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct SparseTensor {
+        /// This field MUST NOT have the value of UNDEFINED
+        /// This field MUST have a valid TensorProto.DataType value
+        /// This field MUST be present for this version of the IR.
+        #[prost(int32, tag = "1")]
+        pub elem_type: i32,
+        #[prost(message, optional, tag = "2")]
+        pub shape: ::core::option::Option<super::TensorShapeProto>,
+    }
+    #[derive(Clone, PartialEq, ::prost::Oneof)]
+    pub enum Value {
+        /// The type of a tensor.
+        #[prost(message, tag = "1")]
+        TensorType(Tensor),
+        /// The type of a sequence.
+        #[prost(message, tag = "4")]
+        SequenceType(::prost::alloc::boxed::Box<Sequence>),
+        /// The type of a map.
+        #[prost(message, tag = "5")]
+        MapType(::prost::alloc::boxed::Box<Map>),
+        /// The type of an optional.
+        #[prost(message, tag = "9")]
+        OptionalType(::prost::alloc::boxed::Box<Optional>),
+        /// Type of the sparse tensor
+        #[prost(message, tag = "8")]
+        SparseTensorType(SparseTensor),
+    }
+}
+/// Operator Sets
+///
+/// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct OperatorSetIdProto {
+    /// The domain of the operator set being identified.
+    /// The empty string ("") or absence of this field implies the operator
+    /// set that is defined as part of the ONNX specification.
+    /// This field MUST be present in this version of the IR when referring to any other operator set.
+    #[prost(string, tag = "1")]
+    pub domain: ::prost::alloc::string::String,
+    /// The version of the operator set being identified.
+    /// This field MUST be present in this version of the IR.
+    #[prost(int64, tag = "2")]
+    pub version: i64,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FunctionProto {
+    /// The name of the function, similar to op_type in NodeProto.
+    /// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
+    #[prost(string, tag = "1")]
+    pub name: ::prost::alloc::string::String,
+    /// The inputs and outputs of the function.
+    #[prost(string, repeated, tag = "4")]
+    pub input: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    #[prost(string, repeated, tag = "5")]
+    pub output: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// The attribute parameters of the function.
+    /// It is for function parameters without default values.
+    #[prost(string, repeated, tag = "6")]
+    pub attribute: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// The attribute protos of the function.
+    /// It is for function attributes with default values.
+    /// A function attribute shall be represented either as
+    /// a string attribute or an AttributeProto, not both.
+    #[prost(message, repeated, tag = "11")]
+    pub attribute_proto: ::prost::alloc::vec::Vec<AttributeProto>,
+    /// The nodes in the function.
+    #[prost(message, repeated, tag = "7")]
+    pub node: ::prost::alloc::vec::Vec<NodeProto>,
+    /// A human-readable documentation for this function. Markdown is allowed.
+    #[prost(string, tag = "8")]
+    pub doc_string: ::prost::alloc::string::String,
+    #[prost(message, repeated, tag = "9")]
+    pub opset_import: ::prost::alloc::vec::Vec<OperatorSetIdProto>,
+    /// The domain which this function belongs to.
+    /// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
+    #[prost(string, tag = "10")]
+    pub domain: ::prost::alloc::string::String,
+    /// The overload identifier of the function.
+    /// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
+    #[prost(string, tag = "13")]
+    pub overload: ::prost::alloc::string::String,
+    /// Information for the values in the function. The ValueInfoProto.name's
+    /// must be distinct and refer to names in the function (including inputs,
+    /// outputs, and intermediate values). It is optional for a value to appear
+    /// in value_info list.
+    #[prost(message, repeated, tag = "12")]
+    pub value_info: ::prost::alloc::vec::Vec<ValueInfoProto>,
+    /// Named metadata values; keys should be distinct.
+    #[prost(message, repeated, tag = "14")]
+    pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Versioning
+///
+/// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+///
+/// To be compatible with both proto2 and proto3, we will use a version number
+/// that is not defined by the default value but an explicit enum number.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum Version {
+    /// proto3 requires the first enum value to be zero.
+    /// We add this just to appease the compiler.
+    StartVersion = 0,
+    /// The version field is always serialized and we will use it to store the
+    /// version that the graph is generated from. This helps us set up version
+    /// control.
+    /// For the IR, we are using simple numbers starting with 0x00000001,
+    /// which was the version we published on Oct 10, 2017.
+    IrVersion20171010 = 1,
+    /// IR_VERSION 2 published on Oct 30, 2017
+    /// - Added type discriminator to AttributeProto to support proto3 users
+    IrVersion20171030 = 2,
+    /// IR VERSION 3 published on Nov 3, 2017
+    /// - For operator versioning:
+    ///   - Added new message OperatorSetIdProto
+    ///   - Added opset_import in ModelProto
+    /// - For vendor extensions, added domain in NodeProto
+    IrVersion2017113 = 3,
+    /// IR VERSION 4 published on Jan 22, 2019
+    /// - Relax constraint that initializers should be a subset of graph inputs
+    /// - Add type BFLOAT16
+    IrVersion2019122 = 4,
+    /// IR VERSION 5 published on March 18, 2019
+    /// - Add message TensorAnnotation.
+    /// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+    IrVersion2019318 = 5,
+    /// IR VERSION 6 published on Sep 19, 2019
+    /// - Add support for sparse tensor constants stored in model.
+    ///   - Add message SparseTensorProto
+    ///   - Add sparse initializers
+    IrVersion2019919 = 6,
+    /// IR VERSION 7 published on May 8, 2020
+    /// - Add support to allow function body graph to rely on multiple external operator sets.
+    /// - Add a list to promote inference graph's initializers to global and
+    ///   mutable variables. Global variables are visible in all graphs of the
+    ///   stored models.
+    /// - Add message TrainingInfoProto to store initialization
+    ///   method and training algorithm. The execution of TrainingInfoProto
+    ///   can modify the values of mutable variables.
+    /// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+    IrVersion202058 = 7,
+    /// IR VERSION 8 published on July 30, 2021
+    /// Introduce TypeProto.SparseTensor
+    /// Introduce TypeProto.Optional
+    /// Added a list of FunctionProtos local to the model
+    /// Deprecated since_version and operator status from FunctionProto
+    IrVersion2021730 = 8,
+    /// IR VERSION 9 published on May 5, 2023
+    /// Added AttributeProto to FunctionProto so that default attribute values can be set.
+    /// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+    IrVersion202355 = 9,
+    /// IR VERSION 10 published on March 25, 2024
+    /// Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions.
+    IrVersion2024325 = 10,
+    /// IR VERSION 11 published on May 12, 2025
+    /// Added FLOAT4E2M1, multi-device protobuf classes.
+    IrVersion = 11,
+}
+impl Version {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            Self::StartVersion => "_START_VERSION",
+            Self::IrVersion20171010 => "IR_VERSION_2017_10_10",
+            Self::IrVersion20171030 => "IR_VERSION_2017_10_30",
+            Self::IrVersion2017113 => "IR_VERSION_2017_11_3",
+            Self::IrVersion2019122 => "IR_VERSION_2019_1_22",
+            Self::IrVersion2019318 => "IR_VERSION_2019_3_18",
+            Self::IrVersion2019919 => "IR_VERSION_2019_9_19",
+            Self::IrVersion202058 => "IR_VERSION_2020_5_8",
+            Self::IrVersion2021730 => "IR_VERSION_2021_7_30",
+            Self::IrVersion202355 => "IR_VERSION_2023_5_5",
+            Self::IrVersion2024325 => "IR_VERSION_2024_3_25",
+            Self::IrVersion => "IR_VERSION",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "_START_VERSION" => Some(Self::StartVersion),
+            "IR_VERSION_2017_10_10" => Some(Self::IrVersion20171010),
+            "IR_VERSION_2017_10_30" => Some(Self::IrVersion20171030),
+            "IR_VERSION_2017_11_3" => Some(Self::IrVersion2017113),
+            "IR_VERSION_2019_1_22" => Some(Self::IrVersion2019122),
+            "IR_VERSION_2019_3_18" => Some(Self::IrVersion2019318),
+            "IR_VERSION_2019_9_19" => Some(Self::IrVersion2019919),
+            "IR_VERSION_2020_5_8" => Some(Self::IrVersion202058),
+            "IR_VERSION_2021_7_30" => Some(Self::IrVersion2021730),
+            "IR_VERSION_2023_5_5" => Some(Self::IrVersion202355),
+            "IR_VERSION_2024_3_25" => Some(Self::IrVersion2024325),
+            "IR_VERSION" => Some(Self::IrVersion),
+            _ => None,
+        }
+    }
+}
+/// Operator/function status.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum OperatorStatus {
+    Experimental = 0,
+    Stable = 1,
+}
+impl OperatorStatus {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            Self::Experimental => "EXPERIMENTAL",
+            Self::Stable => "STABLE",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "EXPERIMENTAL" => Some(Self::Experimental),
+            "STABLE" => Some(Self::Stable),
+            _ => None,
+        }
+    }
+}
diff --git a/src/utils/ops.rs b/src/core/ops.rs
similarity index 96%
rename from src/utils/ops.rs
rename to src/core/ops.rs
index af0492c..7652c0d 100644
--- a/src/utils/ops.rs
+++ b/src/core/ops.rs
@@ -11,19 +11,33 @@ use ndarray::{concatenate, s, Array, Array3, ArrayView1, Axis, IntoDimension, Ix
 use rayon::prelude::*;
 
+/// Image and tensor operations for preprocessing and postprocessing.
 pub enum Ops<'a> {
+    /// Resize images to exact dimensions.
     FitExact(&'a [DynamicImage], u32, u32, &'a str),
+    /// Apply letterbox padding to maintain aspect ratio.
     Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool),
+    /// Normalize values to a specific range.
     Normalize(f32, f32),
+    /// Standardize using mean and standard deviation.
     Standardize(&'a [f32], &'a [f32], usize),
+    /// Permute tensor dimensions.
     Permute(&'a [usize]),
+    /// Insert a new axis at specified position.
     InsertAxis(usize),
+    /// Convert from NHWC to NCHW format.
     Nhwc2nchw,
+    /// Convert from NCHW to NHWC format.
     Nchw2nhwc,
+    /// Apply L2 normalization.
     Norm,
+    /// Apply sigmoid activation function.
     Sigmoid,
+    /// Broadcast tensor to larger dimensions.
     Broadcast,
+    /// Reshape tensor to specified shape.
     ToShape,
+    /// Repeat tensor elements.
     Repeat,
 }
diff --git a/src/utils/ort_config.rs b/src/core/ort_config.rs
similarity index 99%
rename from src/utils/ort_config.rs
rename to src/core/ort_config.rs
index cc4a08b..95be40d 100644
--- a/src/utils/ort_config.rs
+++ b/src/core/ort_config.rs
@@ -3,6 +3,7 @@ use anyhow::Result;
 
 use crate::{try_fetch_file_stem, DType, Device, Hub, Iiix, MinOptMax};
 
+/// ONNX Runtime configuration with device and optimization settings.
 #[derive(Builder, Debug, Clone)]
 pub struct ORTConfig {
     pub file: String,
@@ -71,7 +72,7 @@ impl Default for ORTConfig {
             tensorrt_timing_cache: false,
             cann_graph_inference: true,
             cann_dump_graphs: false,
-            cann_dump_om_model: false,
+            cann_dump_om_model: true,
             onednn_arena_allocator: true,
             nnapi_cpu_only: false,
             nnapi_disable_cpu: false,
@@ -159,7 +160,6 @@ impl ORTConfig {
     }
 }
 
-#[macro_export]
 macro_rules! impl_ort_config_methods {
     ($ty:ty, $field:ident) => {
         impl $ty {
diff --git a/src/utils/processor.rs b/src/core/processor.rs
similarity index 99%
rename from src/utils/processor.rs
rename to src/core/processor.rs
index 54fe366..ad3a3d3 100644
--- a/src/utils/processor.rs
+++ b/src/core/processor.rs
@@ -7,6 +7,7 @@ use tokenizers::{Encoding, Tokenizer};
 
 use crate::{Hub, Image, ImageTransformInfo, LogitsSampler, ProcessorConfig, ResizeMode, X};
 
+/// Image and text processing pipeline with tokenization and transformation capabilities.
 #[derive(Builder, Debug, Clone)]
 pub struct Processor {
     pub image_width: u32,
diff --git a/src/utils/processor_config.rs b/src/core/processor_config.rs
similarity index 90%
rename from src/utils/processor_config.rs
rename to src/core/processor_config.rs
index 6a5898e..4d022b8 100644
--- a/src/utils/processor_config.rs
+++ b/src/core/processor_config.rs
@@ -4,30 +4,51 @@ use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
 
 use crate::{Hub, ResizeMode};
 
+/// Configuration for image and text processing pipelines.
 #[derive(Builder, Debug, Clone)]
 pub struct ProcessorConfig {
     // Vision
+    /// Target image width for resizing.
     pub image_width: Option<u32>,
+    /// Target image height for resizing.
     pub image_height: Option<u32>,
+    /// Image resizing mode.
     pub resize_mode: ResizeMode,
+    /// Image resize filter algorithm.
     pub resize_filter: Option<&'static str>,
+    /// Padding value for image borders.
     pub padding_value: u8,
+    /// Whether to normalize image values.
     pub normalize: bool,
+    /// Standard deviation values for normalization.
     pub image_std: Vec<f32>,
+    /// Mean values for normalization.
     pub image_mean: Vec<f32>,
+    /// Whether to use NCHW format (channels first).
     pub nchw: bool,
+    /// Whether to use unsigned integer format.
     pub unsigned: bool,
     // Text
+    /// Maximum sequence length for tokenization.
     pub model_max_length: Option<u64>,
+    /// Path to tokenizer file.
     pub tokenizer_file: Option<String>,
+    /// Path to model configuration file.
     pub config_file: Option<String>,
+    /// Path to special tokens mapping file.
     pub special_tokens_map_file: Option<String>,
+    /// Path to tokenizer configuration file.
     pub tokenizer_config_file: Option<String>,
+    /// Path to generation configuration file.
     pub generation_config_file: Option<String>,
+    /// Path to vocabulary file.
     pub vocab_file: Option<String>,
+    /// Path to vocabulary text file.
     pub vocab_txt: Option<String>,
+    /// Temperature parameter for text generation.
     pub temperature: f32,
+    /// Top-p parameter for nucleus sampling.
     pub topp: f32,
 }
@@ -150,7 +171,6 @@ impl ProcessorConfig {
     }
 }
 
-#[macro_export]
macro_rules! impl_processor_config_methods {
     ($ty:ty, $field:ident) => {
         impl $ty {
diff --git a/src/utils/retry.rs b/src/core/retry.rs
similarity index 100%
rename from src/utils/retry.rs
rename to src/core/retry.rs
diff --git a/src/utils/scale.rs b/src/core/scale.rs
similarity index 98%
rename from src/utils/scale.rs
rename to src/core/scale.rs
index 0949bbc..8d1f92a 100644
--- a/src/utils/scale.rs
+++ b/src/core/scale.rs
@@ -1,5 +1,6 @@
 use std::str::FromStr;
 
+/// Model scale variants for different model sizes.
 #[derive(Debug, Clone, PartialEq, PartialOrd)]
 pub enum Scale {
     N,
diff --git a/src/utils/task.rs b/src/core/task.rs
similarity index 100%
rename from src/utils/task.rs
rename to src/core/task.rs
diff --git a/src/utils/traits.rs b/src/core/traits.rs
similarity index 70%
rename from src/utils/traits.rs
rename to src/core/traits.rs
index c9f50ed..53505eb 100644
--- a/src/utils/traits.rs
+++ b/src/core/traits.rs
@@ -1,15 +1,22 @@
 use crate::{Hbb, Obb};
 
+/// Trait for objects that have a confidence score.
 pub trait HasScore {
+    /// Returns the confidence score.
     fn score(&self) -> f32;
 }
 
+/// Trait for objects that can calculate Intersection over Union (IoU).
 pub trait HasIoU {
+    /// Calculates IoU with another object.
     fn iou(&self, other: &Self) -> f32;
 }
 
+/// Trait for Non-Maximum Suppression operations.
 pub trait NmsOps {
+    /// Applies NMS in-place with the given IoU threshold.
     fn apply_nms_inplace(&mut self, iou_threshold: f32);
+    /// Applies NMS and returns the filtered result.
     fn apply_nms(self, iou_threshold: f32) -> Self;
 }
 
@@ -68,11 +75,17 @@ impl HasIoU for Obb {
     }
 }
 
+/// Trait for geometric regions with area and intersection calculations.
 pub trait Region {
+    /// Calculates the area of the region.
     fn area(&self) -> f32;
+    /// Calculates the perimeter of the region.
     fn perimeter(&self) -> f32;
+    /// Calculates the intersection area with another region.
     fn intersect(&self, other: &Self) -> f32;
+    /// Calculates the union area with another region.
     fn union(&self, other: &Self) -> f32;
+    /// Calculates Intersection over Union (IoU) with another region.
     fn iou(&self, other: &Self) -> f32 {
         self.intersect(other) / self.union(other)
     }
diff --git a/src/utils/ts.rs b/src/core/ts.rs
similarity index 99%
rename from src/utils/ts.rs
rename to src/core/ts.rs
index dbc55c3..d0beacf 100644
--- a/src/utils/ts.rs
+++ b/src/core/ts.rs
@@ -19,6 +19,7 @@ macro_rules! elapsed {
     }};
 }
 
+/// Time series collection for performance measurement and profiling.
 #[derive(aksr::Builder, Debug, Default, Clone, PartialEq)]
 pub struct Ts {
     // { k1: [d1,d1,d1,..], k2: [d2,d2,d2,..], k3: [d3,d3,d3,..], ..}
diff --git a/src/utils/mod.rs b/src/core/utils.rs
similarity index 55%
rename from src/utils/mod.rs
rename to src/core/utils.rs
index 7c33ac0..29b3026 100644
--- a/src/utils/mod.rs
+++ b/src/core/utils.rs
@@ -1,52 +1,21 @@
-mod config;
-mod device;
-mod dtype;
-mod dynconf;
-mod iiix;
-mod logits_sampler;
-mod min_opt_max;
-mod names;
-mod ops;
-mod ort_config;
-mod processor;
-mod processor_config;
-mod retry;
-mod scale;
-mod task;
-mod traits;
-mod ts;
-mod version;
-
-pub use config::*;
-pub use device::Device;
-pub use dtype::DType;
-pub use dynconf::DynConf;
-pub(crate) use iiix::Iiix;
-pub use logits_sampler::LogitsSampler;
-pub use min_opt_max::MinOptMax;
-pub use names::*;
-pub use ops::*;
-pub use ort_config::ORTConfig;
-pub use processor::*;
-pub use processor_config::ProcessorConfig;
-pub use scale::Scale;
-pub use task::Task;
-pub use traits::*;
-pub use ts::Ts;
-pub use version::Version;
-
+/// The name of the current crate.
 pub const CRATE_NAME: &str = env!("CARGO_PKG_NAME");
+/// Standard prefix length for progress bar formatting.
 pub const PREFIX_LENGTH: usize = 12;
+/// Progress bar style for completion with iteration count.
 pub const PROGRESS_BAR_STYLE_FINISH: &str =
     "{prefix:>12.green.bold} {msg} for {human_len} iterations in {elapsed}";
+/// Progress bar style for completion with multiplier format.
 pub const PROGRESS_BAR_STYLE_FINISH_2: &str =
     "{prefix:>12.green.bold} {msg} x{human_len} in {elapsed}";
+/// Progress bar style for completion with byte size information.
 pub const PROGRESS_BAR_STYLE_FINISH_3: &str =
     "{prefix:>12.green.bold} {msg} ({binary_total_bytes}) in {elapsed}";
+/// Progress bar style for ongoing operations with position indicator.
 pub const PROGRESS_BAR_STYLE_CYAN_2: &str =
     "{prefix:>12.cyan.bold} {human_pos}/{human_len} |{bar}| {msg}";
 
-pub fn build_resizer_filter(
+pub(crate) fn build_resizer_filter(
     ty: &str,
 ) -> anyhow::Result<(fast_image_resize::Resizer, fast_image_resize::ResizeOptions)> {
     use fast_image_resize::{FilterType, ResizeAlg, ResizeOptions, Resizer};
@@ -66,7 +35,7 @@ pub fn build_resizer_filter(
     ))
 }
 
-pub fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<String> {
+pub(crate) fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<String> {
     let p = p.as_ref();
     let stem = p
         .file_stem()
@@ -80,8 +49,7 @@ pub fn try_fetch_file_stem<P: AsRef<std::path::Path>>(p: P) -> anyhow::Result<String>
@@ -95,11 +63,41 @@
     Ok(pb)
 }
 
+/// Formats a byte size into a human-readable string using decimal (base-1000) units.
+///
+/// # Arguments
+/// * `size` - The size in bytes to format
+/// * `decimal_places` - Number of decimal places to show in the formatted output
+///
+/// # Returns
+/// A string representing the size with appropriate decimal unit (B, KB, MB, etc.)
+/// +/// # Example +/// ```ignore +/// let size = 1500000.0; +/// let formatted = human_bytes_decimal(size, 2); +/// assert_eq!(formatted, "1.50 MB"); +/// ``` pub fn human_bytes_decimal(size: f64, decimal_places: usize) -> String { const DECIMAL_UNITS: [&str; 7] = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; format_bytes_internal(size, 1000.0, &DECIMAL_UNITS, decimal_places) } +/// Formats a byte size into a human-readable string using binary (base-1024) units. +/// +/// # Arguments +/// * `size` - The size in bytes to format +/// * `decimal_places` - Number of decimal places to show in the formatted output +/// +/// # Returns +/// A string representing the size with appropriate binary unit (B, KiB, MiB, etc.) +/// +/// # Example +/// ```ignore +/// let size = 1024.0 * 1024.0; // 1 MiB in bytes +/// let formatted = human_bytes_binary(size, 2); +/// assert_eq!(formatted, "1.00 MiB"); +/// ``` pub fn human_bytes_binary(size: f64, decimal_places: usize) -> String { const BINARY_UNITS: [&str; 7] = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]; format_bytes_internal(size, 1024.0, &BINARY_UNITS, decimal_places) @@ -125,6 +123,22 @@ fn format_bytes_internal( ) } +/// Generates a random alphanumeric string of the specified length. +/// +/// # Arguments +/// * `length` - The desired length of the random string +/// +/// # Returns +/// A string containing random alphanumeric characters (A-Z, a-z, 0-9). +/// Returns an empty string if length is 0. +/// +/// # Example +/// ```ignore +/// let random = generate_random_string(8); +/// assert_eq!(random.len(), 8); +/// // Each character in the string will be alphanumeric +/// assert!(random.chars().all(|c| c.is_ascii_alphanumeric())); +/// ``` pub fn generate_random_string(length: usize) -> String { use rand::{distr::Alphanumeric, rng, Rng}; if length == 0 { @@ -136,6 +150,28 @@ pub fn generate_random_string(length: usize) -> String { result } +/// Generates a timestamp string in the format "YYYYMMDDHHmmSSffffff" with an optional delimiter. +/// +/// # Arguments +/// * `delimiter` - Optional string to insert between each component of the timestamp. +/// If None or empty string, components will be joined without delimiter. +/// +/// # Returns +/// A string containing the current timestamp with the specified format. +/// +/// # Example +/// ```ignore +/// use chrono::TimeZone; +/// +/// // Without delimiter +/// let ts = timestamp(None); +/// assert_eq!(ts.len(), 20); // YYYYMMDDHHmmSSffffff +/// +/// // With delimiter +/// let ts = timestamp(Some("-")); +/// // Format: YYYY-MM-DD-HH-mm-SS-ffffff +/// assert_eq!(ts.split('-').count(), 7); +/// ``` pub fn timestamp(delimiter: Option<&str>) -> String { let delimiter = delimiter.unwrap_or(""); let format = format!("%Y{0}%m{0}%d{0}%H{0}%M{0}%S{0}%f", delimiter); diff --git a/src/utils/version.rs b/src/core/version.rs similarity index 95% rename from src/utils/version.rs rename to src/core/version.rs index 97e59ea..696ad82 100644 --- a/src/utils/version.rs +++ b/src/core/version.rs @@ -1,6 +1,7 @@ use aksr::Builder; use std::convert::TryFrom; +/// Version representation with major, minor, and optional patch numbers. 
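+///
+/// A minimal sketch of the comparison semantics implied by the derived
+/// `PartialOrd`/`Ord` below (tuple structs compare field by field,
+/// lexicographically; the values here are hypothetical):
+/// ```ignore
+/// assert!(Version(1, 2, None) < Version(1, 3, Some(0)));
+/// ```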
#[derive(Debug, Builder, PartialEq, Eq, Copy, Clone, Hash, Default, PartialOrd, Ord)] pub struct Version(pub u8, pub u8, pub Option<u8>); diff --git a/src/inference/x.rs b/src/core/x.rs similarity index 100% rename from src/inference/x.rs rename to src/core/x.rs diff --git a/src/inference/xs.rs b/src/core/xs.rs similarity index 97% rename from src/inference/xs.rs rename to src/core/xs.rs index 54735c6..7b6334a 100644 --- a/src/inference/xs.rs +++ b/src/core/xs.rs @@ -6,6 +6,7 @@ use std::ops::{Deref, Index}; use crate::{generate_random_string, X}; +/// Collection of named tensors with associated images and texts. #[derive(Builder, Debug, Default, Clone)] pub struct Xs { map: HashMap<String, X>, diff --git a/src/inference/mod.rs b/src/inference/mod.rs deleted file mode 100644 index 7cd2231..0000000 --- a/src/inference/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] -mod engine; -mod hbb; -mod image; -mod instance_meta; -mod keypoint; -mod mask; -mod obb; -mod polygon; -mod prob; -mod skeleton; -mod text; -mod x; -mod xs; -mod y; -#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] -#[allow(clippy::all)] -pub(crate) mod onnx { - include!(concat!(env!("OUT_DIR"), "/onnx.rs")); -} - -#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] -pub use engine::*; -pub use hbb::*; -pub use image::*; -pub use instance_meta::*; -pub use keypoint::*; -pub use mask::*; -pub use obb::*; -pub use polygon::*; -pub use prob::*; -pub use skeleton::*; -pub use text::*; -pub use x::X; -pub use xs::Xs; -pub use y::*; diff --git a/src/io/mod.rs b/src/io/mod.rs deleted file mode 100644 index 183a926..0000000 --- a/src/io/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod dataloader; -mod dir; -mod hub; -mod media; - -pub use dataloader::*; -pub use dir::*; -pub use hub::*; -pub use media::*; diff --git a/src/lib.rs b/src/lib.rs index 5a3e869..2d138a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,64 @@ -mod inference; -mod io; +//! # usls +//! +//! `usls` is a cross-platform Rust library that provides efficient inference for SOTA vision and multi-modal models (typically under 1B parameters) using ONNX Runtime. +//! +//! ## πŸ“š Documentation +//! - [API Documentation](https://docs.rs/usls/latest/usls/) +//! - [Examples](https://github.com/jamjamjon/usls/tree/main/examples) +//! ## πŸš€ Quick Start +//! +//! ```bash +//! # CPU +//! cargo run -r --example yolo # YOLOv8 detect by default +//! +//! # NVIDIA CUDA +//! cargo run -r -F cuda --example yolo -- --device cuda:0 +//! +//! # NVIDIA TensorRT +//! cargo run -r -F tensorrt --example yolo -- --device tensorrt:0 +//! +//! # Apple Silicon CoreML +//! cargo run -r -F coreml --example yolo -- --device coreml +//! +//! # Intel OpenVINO +//! cargo run -r -F openvino -F ort-load-dynamic --example yolo -- --device openvino:CPU +//! ``` +//! +//! ## ⚑ Cargo Features +//! - **`ort-download-binaries`** (**default**): Automatically downloads prebuilt ONNXRuntime binaries for supported platforms +//! - **`ort-load-dynamic`**: Dynamic linking to ONNXRuntime libraries ([Guide](https://ort.pyke.io/setup/linking#dynamic-linking)) +//! - **`video`**: Enable video stream reading and writing (via [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb)) +//! - **`cuda`**: NVIDIA CUDA GPU acceleration support +//! - **`tensorrt`**: NVIDIA TensorRT optimization for inference acceleration +//!
- **`coreml`**: Apple CoreML acceleration for macOS/iOS devices +//! - **`openvino`**: Intel OpenVINO toolkit for CPU/GPU/VPU acceleration +//! - **`onednn`**: Intel oneDNN (formerly MKL-DNN) for CPU optimization +//! - **`directml`**: Microsoft DirectML for Windows GPU acceleration +//! - **`xnnpack`**: Google XNNPACK for mobile and edge device optimization +//! - **`rocm`**: AMD ROCm platform for GPU acceleration +//! - **`cann`**: Huawei CANN (Compute Architecture for Neural Networks) support +//! - **`rknpu`**: Rockchip NPU acceleration +//! - **`acl`**: Arm Compute Library for Arm processors +//! - **`nnapi`**: Android Neural Networks API support +//! - **`armnn`**: Arm NN inference engine +//! - **`tvm`**: Apache TVM tensor compiler stack +//! - **`qnn`**: Qualcomm Neural Network SDK +//! - **`migraphx`**: AMD MIGraphX for GPU acceleration +//! - **`vitis`**: Xilinx Vitis AI for FPGA acceleration +//! - **`azure`**: Azure Machine Learning integration +//! + +mod core; +/// Model Zoo #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] pub mod models; -mod utils; +#[macro_use] +mod results; mod viz; -pub use inference::*; -pub use io::*; +pub use core::*; pub use minifb::Key; #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] pub use models::*; -pub use utils::*; +pub use results::*; pub use viz::*; diff --git a/src/models/db/impl.rs b/src/models/db/impl.rs index 781a2aa..05596ea 100644 --- a/src/models/db/impl.rs +++ b/src/models/db/impl.rs @@ -7,6 +7,7 @@ use crate::{ elapsed, Config, DynConf, Engine, Hbb, Image, Mask, Obb, Ops, Polygon, Processor, Ts, Xs, Y, }; +/// DB (Differentiable Binarization) model for text detection. #[derive(Debug, Builder)] pub struct DB { engine: Engine, diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs index efebd19..3df9a5f 100644 --- a/src/models/grounding_dino/impl.rs +++ b/src/models/grounding_dino/impl.rs @@ -7,6 +7,7 @@ use std::fmt::Write; use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; #[derive(Builder, Debug)] +/// Grounding DINO model for open-vocabulary object detection. 
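To complement the CLI-oriented Quick Start above, here is a minimal in-code sketch of the same pipeline. It leans only on APIs visible in this patch (`Config::yolo_detect`, `YOLO::new`, `forward`); the `Image::try_read` loader and the asset path are assumptions for illustration:

```rust
use usls::{Config, Image, YOLO};

fn main() -> anyhow::Result<()> {
    // Task-specific config; model file/version/scale selection is elided
    // here and would be set via the Config builder methods.
    let mut model = YOLO::new(Config::yolo_detect())?;

    // Image::try_read is an assumed loader for this sketch.
    let xs = [Image::try_read("./assets/bus.jpg")?];

    // forward() = preprocess -> inference -> postprocess.
    let ys = model.forward(&xs)?;
    println!("{} result(s)", ys.len());
    Ok(())
}
```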
pub struct GroundingDINO { pub engine: Engine, height: usize, diff --git a/src/models/mediapipe_segmenter/README.md b/src/models/mediapipe_segmenter/README.md new file mode 100644 index 0000000..83cf4a8 --- /dev/null +++ b/src/models/mediapipe_segmenter/README.md @@ -0,0 +1,9 @@ +# MediaPipe: Selfie segmentation model + +## Official Repository + +The official guide can be found at: [GUIDE](https://ai.google.dev/edge/mediapipe/solutions/vision/image_segmenter) + +## Example + +Refer to the [example](../../../examples/mediapipe-selfie-segmentation) diff --git a/src/models/mediapipe_segmenter/config.rs b/src/models/mediapipe_segmenter/config.rs new file mode 100644 index 0000000..19a98c3 --- /dev/null +++ b/src/models/mediapipe_segmenter/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `MediaPipeSegmenter` +impl crate::Config { + pub fn mediapipe() -> Self { + Self::default() + .with_name("mediapipe") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 256.into()) + .with_model_ixx(0, 3, 256.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_normalize(true) + } + + pub fn mediapipe_selfie_segmentater() -> Self { + Self::mediapipe().with_model_file("selfie-segmenter.onnx") + } + + pub fn mediapipe_selfie_segmentater_landscape() -> Self { + Self::mediapipe().with_model_file("selfie-segmenter-landscape.onnx") + } +} diff --git a/src/models/mediapipe_segmenter/impl.rs b/src/models/mediapipe_segmenter/impl.rs new file mode 100644 index 0000000..f492328 --- /dev/null +++ b/src/models/mediapipe_segmenter/impl.rs @@ -0,0 +1,91 @@ +use aksr::Builder; +use anyhow::Result; +use ndarray::Axis; + +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y}; + +#[derive(Builder, Debug)] +pub struct MediaPipeSegmenter { + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, +} + +impl MediaPipeSegmenter { + pub fn new(config: Config) -> Result<Self> { + let engine = Engine::try_from_config(&config.model)?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&256.into()).opt(), + engine.try_width().unwrap_or(&256.into()).opt(), + engine.ts().clone(), + ); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + ts, + spec, + processor, + }) + } + + fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result<Xs> { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)?
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> { + let mut ys: Vec<Y> = Vec::new(); + for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { + let (h1, w1) = ( + self.processor.images_transform_info[idx].height_src, + self.processor.images_transform_info[idx].width_src, + ); + + let luma = luma.mapv(|x| (x * 255.0) as u8); + let luma = Ops::resize_luma8_u8( + &luma.into_raw_vec_and_offset().0, + self.width as _, + self.height as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer<image::Luma<u8>, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys) + } +} diff --git a/src/models/mediapipe_segmenter/mod.rs b/src/models/mediapipe_segmenter/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/mediapipe_segmenter/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/mod.rs b/src/models/mod.rs index b0768cf..1391c99 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -15,6 +15,7 @@ mod fastvit; mod florence2; mod grounding_dino; mod linknet; +mod mediapipe_segmenter; mod mobileone; mod modnet; mod moondream2; @@ -43,6 +44,7 @@ pub use depth_pro::*; pub use dinov2::*; pub use florence2::*; pub use grounding_dino::*; +pub use mediapipe_segmenter::*; pub use modnet::*; pub use moondream2::*; pub use owl::*; diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs index ad7a4e1..087b8a0 100644 --- a/src/models/owl/impl.rs +++ b/src/models/owl/impl.rs @@ -5,6 +5,7 @@ use rayon::prelude::*; use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; +/// OWL-ViT v2 model for open-vocabulary object detection.
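For orientation, a usage sketch of the new model, built only from the constructors and methods added above; the image loader, asset path, and the `masks()` accessor name (inferred from the aksr builder API) are assumptions:

```rust
use usls::{Config, Image, MediaPipeSegmenter};

fn main() -> anyhow::Result<()> {
    // Config and model constructors are introduced in this patch.
    let mut model = MediaPipeSegmenter::new(Config::mediapipe_selfie_segmentater())?;

    let xs = [Image::try_read("./assets/portrait.jpg")?]; // loader assumed
    let ys = model.forward(&xs)?; // one Y per input image

    // postprocess() attaches one grayscale Mask per image, resized back
    // to the source resolution.
    for y in &ys {
        if let Some(masks) = y.masks() {
            println!("{} mask(s)", masks.len());
        }
    }
    model.summary(); // per-stage timings collected via elapsed!
    Ok(())
}
```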
#[derive(Debug, Builder)] pub struct OWLv2 { engine: Engine, diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs index 03f1d44..768d97e 100644 --- a/src/models/rtmo/config.rs +++ b/src/models/rtmo/config.rs @@ -15,6 +15,13 @@ impl crate::Config { .with_keypoint_confs(&[0.5]) } + pub fn rtmo_t() -> Self { + Self::rtmo() + .with_model_ixx(0, 2, 416.into()) + .with_model_ixx(0, 3, 416.into()) + .with_model_file("t.onnx") + } + pub fn rtmo_s() -> Self { Self::rtmo().with_model_file("s.onnx") } diff --git a/src/models/rtmo/impl.rs b/src/models/rtmo/impl.rs index 14cf442..5ec5c9d 100644 --- a/src/models/rtmo/impl.rs +++ b/src/models/rtmo/impl.rs @@ -1,6 +1,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; +use rayon::prelude::*; use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Keypoint, Processor, Ts, Xs, Y}; @@ -68,69 +69,67 @@ impl RTMO { } fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> { - let mut ys: Vec<Y> = Vec::new(); - // let (preds_bboxes, preds_kpts) = (&xs["dets"], &xs["keypoints"]); - let (preds_bboxes, preds_kpts) = (&xs[0], &xs[1]); - - for (idx, (batch_bboxes, batch_kpts)) in preds_bboxes + let ys: Vec<Y> = xs[0] .axis_iter(Axis(0)) - .zip(preds_kpts.axis_iter(Axis(0))) + .into_par_iter() + .zip(xs[1].axis_iter(Axis(0)).into_par_iter()) .enumerate() - { - let (height_original, width_original) = ( - self.processor.images_transform_info[idx].height_src, - self.processor.images_transform_info[idx].width_src, - ); - let ratio = self.processor.images_transform_info[idx].height_scale; - - let mut y_bboxes = Vec::new(); - let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new(); - for (xyxyc, kpts) in batch_bboxes - .axis_iter(Axis(0)) - .zip(batch_kpts.axis_iter(Axis(0))) - { - // bbox - let x1 = xyxyc[0] / ratio; - let y1 = xyxyc[1] / ratio; - let x2 = xyxyc[2] / ratio; - let y2 = xyxyc[3] / ratio; - let confidence = xyxyc[4]; - - if confidence < self.confs[0] { - continue; - } - y_bboxes.push( - Hbb::default() - .with_xyxy( - x1.max(0.0f32).min(width_original as _), - y1.max(0.0f32).min(height_original as _), - x2, - y2, - ) - .with_confidence(confidence) - .with_id(0) - .with_name("Person"), + .map(|(idx, (batch_bboxes, batch_kpts))| { + let (height_original, width_original) = ( + self.processor.images_transform_info[idx].height_src, + self.processor.images_transform_info[idx].width_src, ); + let ratio = self.processor.images_transform_info[idx].height_scale; - // keypoints - let mut kpts_ = Vec::new(); - for (i, kpt) in kpts.axis_iter(Axis(0)).enumerate() { - let x = kpt[0] / ratio; - let y = kpt[1] / ratio; - let c = kpt[2]; - if c < self.kconfs[i] { - kpts_.push(Keypoint::default()); - } else { - kpts_.push(Keypoint::default().with_id(i).with_confidence(c).with_xy( - x.max(0.0f32).min(width_original as _), - y.max(0.0f32).min(height_original as _), - )); + let mut y_bboxes = Vec::new(); + let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new(); + for (xyxyc, kpts) in batch_bboxes + .axis_iter(Axis(0)) + .zip(batch_kpts.axis_iter(Axis(0))) + { + // bbox + let x1 = xyxyc[0] / ratio; + let y1 = xyxyc[1] / ratio; + let x2 = xyxyc[2] / ratio; + let y2 = xyxyc[3] / ratio; + let confidence = xyxyc[4]; + + if confidence < self.confs[0] { + continue; } + y_bboxes.push( + Hbb::default() + .with_xyxy( + x1.max(0.0f32).min(width_original as _), + y1.max(0.0f32).min(height_original as _), + x2, + y2, + ) + .with_confidence(confidence) + .with_id(0) + .with_name("Person"), + ); + + // keypoints + let mut kpts_ = Vec::new(); + for (i, kpt) in kpts.axis_iter(Axis(0)).enumerate() { + let x = kpt[0] / ratio; + let
y = kpt[1] / ratio; + let c = kpt[2]; + if c < self.kconfs[i] { + kpts_.push(Keypoint::default()); + } else { + kpts_.push(Keypoint::default().with_id(i).with_confidence(c).with_xy( + x.max(0.0f32).min(width_original as _), + y.max(0.0f32).min(height_original as _), + )); + } + } + y_kpts.push(kpts_); } - y_kpts.push(kpts_); - } - ys.push(Y::default().with_hbbs(&y_bboxes).with_keypointss(&y_kpts)); - } + Y::default().with_hbbs(&y_bboxes).with_keypointss(&y_kpts) + }) + .collect(); Ok(ys) } diff --git a/src/models/sam/config.rs b/src/models/sam/config.rs index bd03825..102e61f 100644 --- a/src/models/sam/config.rs +++ b/src/models/sam/config.rs @@ -2,6 +2,13 @@ use crate::{models::SamKind, Config}; /// Model configuration for `Segment Anything Model` impl Config { + /// Creates a base SAM configuration with common settings. + /// + /// Sets up default parameters for image preprocessing and model architecture: + /// - 1024x1024 input resolution + /// - Adaptive resize mode + /// - Image normalization parameters + /// - Contour finding enabled pub fn sam() -> Self { Self::default() .with_name("sam") @@ -19,6 +26,8 @@ impl Config { .with_find_contours(true) } + /// Creates a configuration for SAM v1 base model. + /// Uses the original ViT-B architecture. pub fn sam_v1_base() -> Self { Self::sam() .with_encoder_file("sam-vit-b-encoder.onnx") @@ -29,6 +38,8 @@ impl Config { // Self::sam().with_decoder_file("sam-vit-b-decoder-singlemask.onnx") // } + /// Creates a configuration for SAM 2.0 tiny model. + /// Uses a hierarchical architecture with tiny backbone. pub fn sam2_tiny() -> Self { Self::sam() .with_encoder_file("sam2-hiera-tiny-encoder.onnx") @@ -36,6 +47,8 @@ impl Config { .with_decoder_file("sam2-hiera-tiny-decoder.onnx") } + /// Creates a configuration for SAM 2.0 small model. + /// Uses a hierarchical architecture with small backbone. pub fn sam2_small() -> Self { Self::sam() .with_encoder_file("sam2-hiera-small-encoder.onnx") @@ -43,6 +56,8 @@ impl Config { .with_sam_kind(SamKind::Sam2) } + /// Creates a configuration for SAM 2.0 base-plus model. + /// Uses a hierarchical architecture with enhanced base backbone. pub fn sam2_base_plus() -> Self { Self::sam() .with_encoder_file("sam2-hiera-base-plus-encoder.onnx") @@ -50,6 +65,8 @@ impl Config { .with_sam_kind(SamKind::Sam2) } + /// Creates a configuration for MobileSAM tiny model. + /// Lightweight model optimized for mobile devices. pub fn mobile_sam_tiny() -> Self { Self::sam() .with_encoder_file("mobile-sam-vit-t-encoder.onnx") @@ -57,6 +74,8 @@ impl Config { .with_decoder_file("mobile-sam-vit-t-decoder.onnx") } + /// Creates a configuration for SAM-HQ tiny model. + /// High-quality variant focused on better mask quality. pub fn sam_hq_tiny() -> Self { Self::sam() .with_encoder_file("sam-hq-vit-t-encoder.onnx") @@ -64,6 +83,8 @@ impl Config { .with_decoder_file("sam-hq-vit-t-decoder.onnx") } + /// Creates a configuration for EdgeSAM 3x model. + /// Edge-based variant optimized for speed and efficiency. pub fn edge_sam_3x() -> Self { Self::sam() .with_encoder_file("edge-sam-3x-encoder.onnx") diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs index 2abea4f..fe5a9ae 100644 --- a/src/models/sam/impl.rs +++ b/src/models/sam/impl.rs @@ -8,12 +8,18 @@ use crate::{ elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y, }; +/// SAM model variants for different use cases. 
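The refactor above swaps RTMO's sequential batch loop for rayon's `into_par_iter`, so each image in the batch is post-processed on its own thread with identical per-item logic and the results are collected in order. A usage sketch for the new tiny variant follows; `RTMO::new` and `forward` are assumed to follow the same Config-to-model pattern as the other models in this patch, and the loader/path are illustrative:

```rust
use usls::{Config, Image, RTMO};

fn main() -> anyhow::Result<()> {
    // rtmo_t() (added above) selects the tiny variant with a 416x416 input.
    let mut model = RTMO::new(Config::rtmo_t())?;

    let xs = [Image::try_read("./assets/crowd.jpg")?]; // loader assumed
    let ys = model.forward(&xs)?;

    // Each Y carries person boxes (hbbs) and 17-keypoint sets (keypointss),
    // exactly as assembled in postprocess() above.
    println!("{} result(s)", ys.len());
    Ok(())
}
```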
#[derive(Debug, Clone)] pub enum SamKind { + /// Original SAM model Sam, - Sam2, // 2.0 + /// SAM 2.0 with hierarchical architecture + Sam2, + /// Mobile optimized SAM MobileSam, + /// High quality SAM with better segmentation SamHq, + /// Efficient SAM with edge-based segmentation EdgeSam, } @@ -32,6 +38,10 @@ impl FromStr for SamKind { } } +/// Segment Anything Model (SAM) for image segmentation. +/// +/// A foundation model for generating high-quality object masks from input prompts such as points or boxes. +/// Supports multiple variants including the original SAM, SAM2, MobileSAM, SAM-HQ and EdgeSAM. #[derive(Builder, Debug)] pub struct SAM { encoder: Engine, @@ -49,6 +59,10 @@ pub struct SAM { } impl SAM { + /// Creates a new SAM model instance from the provided configuration. + /// + /// Initializes the model based on the specified SAM variant (original SAM, SAM2, MobileSAM etc.) + /// and configures its encoder-decoder architecture. pub fn new(config: Config) -> Result<Self> { let encoder = Engine::try_from_config(&config.encoder)?; let decoder = Engine::try_from_config(&config.decoder)?; @@ -94,6 +108,11 @@ impl SAM { }) } + /// Runs the complete segmentation pipeline on a batch of images. + /// + /// The pipeline consists of: + /// 1. Encoding the images into embeddings + /// 2. Decoding the embeddings with input prompts to generate segmentation masks pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> { let ys = elapsed!("encode", self.ts, { self.encode(xs)? }); let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? }); @@ -101,11 +120,16 @@ impl SAM { Ok(ys) } + /// Encodes input images into image embeddings. pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> { let xs_ = self.processor.process_images(xs)?; self.encoder.run(Xs::from(xs_)) } + /// Generates segmentation masks from image embeddings and input prompts. + /// + /// Takes the image embeddings from the encoder and input prompts (points or boxes) + /// to generate binary segmentation masks for the prompted objects. pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> { let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind { SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])), @@ -285,10 +309,12 @@ impl SAM { Ok(ys) } + /// Returns the width of the low-resolution feature maps. pub fn width_low_res(&self) -> usize { self.width / 4 } + /// Returns the height of the low-resolution feature maps. pub fn height_low_res(&self) -> usize { self.height / 4 } diff --git a/src/models/sam/mod.rs b/src/models/sam/mod.rs index bce941d..0e7849a 100644 --- a/src/models/sam/mod.rs +++ b/src/models/sam/mod.rs @@ -3,9 +3,12 @@ mod r#impl; pub use r#impl::*; +/// SAM prompt containing coordinates and labels for segmentation. #[derive(Debug, Default, Clone)] pub struct SamPrompt { + /// Point coordinates for prompting. pub coords: Vec>, + /// Labels corresponding to the coordinates. pub labels: Vec>, } diff --git a/src/models/sam2/config.rs b/src/models/sam2/config.rs index 1ca09c8..a0388f5 100644 --- a/src/models/sam2/config.rs +++ b/src/models/sam2/config.rs @@ -2,24 +2,36 @@ use crate::Config; /// Model configuration for `SAM2.1` impl Config { + /// Creates a configuration for SAM 2.1 tiny model. + /// + /// The smallest variant of the hierarchical architecture, optimized for speed.
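A prompt-driven usage sketch, using only the signatures shown above (`SAM::new`, `forward(&[Image], &[SamPrompt])`). Since the element types of `SamPrompt`'s fields are elided in this diff, the sketch sticks to `SamPrompt::default()`; real prompts would populate `coords` with points or boxes and `labels` to match:

```rust
use usls::{Config, Image, SamPrompt, SAM};

fn main() -> anyhow::Result<()> {
    // Any of the variant configs above works here; MobileSAM is the lightest.
    let mut sam = SAM::new(Config::mobile_sam_tiny())?;

    let xs = [Image::try_read("./assets/dog.jpg")?]; // loader assumed

    // One prompt per image. forward() = encode (image embeddings)
    // + decode (prompted mask generation).
    let prompts = vec![SamPrompt::default()];
    let ys = sam.forward(&xs, &prompts)?;
    println!("{} result(s)", ys.len());
    Ok(())
}
```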
pub fn sam2_1_tiny() -> Self { Self::sam() .with_encoder_file("sam2.1-hiera-tiny-encoder.onnx") .with_decoder_file("sam2.1-hiera-tiny-decoder.onnx") } + /// Creates a configuration for SAM 2.1 small model. + /// + /// A balanced variant offering good performance and efficiency. pub fn sam2_1_small() -> Self { Self::sam() .with_encoder_file("sam2.1-hiera-small-encoder.onnx") .with_decoder_file("sam2.1-hiera-small-decoder.onnx") } + /// Creates a configuration for SAM 2.1 base-plus model. + /// + /// An enhanced base model with improved segmentation quality. pub fn sam2_1_base_plus() -> Self { Self::sam() .with_encoder_file("sam2.1-hiera-base-plus-encoder.onnx") .with_decoder_file("sam2.1-hiera-base-plus-decoder.onnx") } + /// Creates a configuration for SAM 2.1 large model. + /// + /// The most powerful variant with highest segmentation quality. pub fn sam2_1_large() -> Self { Self::sam() .with_encoder_file("sam2.1-hiera-large-encoder.onnx") diff --git a/src/models/sam2/impl.rs b/src/models/sam2/impl.rs index 822ad06..365d7ce 100644 --- a/src/models/sam2/impl.rs +++ b/src/models/sam2/impl.rs @@ -6,6 +6,10 @@ use crate::{ elapsed, Config, DynConf, Engine, Image, Mask, Ops, Processor, SamPrompt, Ts, Xs, X, Y, }; +/// SAM2 (Segment Anything Model 2.1) for advanced image segmentation. +/// +/// A hierarchical vision foundation model with improved efficiency and quality. +/// Features enhanced backbone architecture and optimized prompt handling. #[derive(Builder, Debug)] pub struct SAM2 { encoder: Engine, @@ -20,6 +24,12 @@ pub struct SAM2 { } impl SAM2 { + /// Creates a new SAM2 model instance from the provided configuration. + /// + /// Initializes the model with: + /// - Encoder-decoder architecture + /// - Image preprocessing settings + /// - Confidence thresholds pub fn new(config: Config) -> Result<Self> { let encoder = Engine::try_from_config(&config.encoder)?; let decoder = Engine::try_from_config(&config.decoder)?; @@ -48,6 +58,11 @@ impl SAM2 { }) } + /// Runs the complete segmentation pipeline on a batch of images. + /// + /// The pipeline consists of: + /// 1. Image encoding into hierarchical features + /// 2. Prompt-guided mask generation pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> { let ys = elapsed!("encode", self.ts, { self.encode(xs)? }); let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? }); @@ -55,11 +70,16 @@ impl SAM2 { Ok(ys) } + /// Encodes input images into hierarchical feature representations. pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> { let xs_ = self.processor.process_images(xs)?; self.encoder.run(Xs::from(xs_)) } + /// Generates segmentation masks from encoded features and prompts. + /// + /// Takes hierarchical image features and user prompts (points/boxes) + /// to generate accurate object masks. pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> { let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]); @@ -153,10 +173,12 @@ impl SAM2 { Ok(ys) } + /// Returns the width of the low-resolution feature maps. pub fn width_low_res(&self) -> usize { self.width / 4 } + /// Returns the height of the low-resolution feature maps.
pub fn height_low_res(&self) -> usize { self.height / 4 } diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs index 65707c7..b1cf82d 100644 --- a/src/models/svtr/impl.rs +++ b/src/models/svtr/impl.rs @@ -5,6 +5,7 @@ use rayon::prelude::*; use crate::{elapsed, Config, DynConf, Engine, Image, Processor, Text, Ts, Xs, Y}; +/// SVTR (Scene Text Recognition) model for text recognition. #[derive(Builder, Debug)] pub struct SVTR { engine: Engine, diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs index 3b6f6c9..4a31405 100644 --- a/src/models/trocr/config.rs +++ b/src/models/trocr/config.rs @@ -1,7 +1,15 @@ use crate::Scale; -/// Model configuration for `TrOCR` +/// Model configuration for `TrOCR`. impl crate::Config { + /// Creates a base configuration for TrOCR models with default settings. + /// + /// This includes: + /// - Batch size of 1 + /// - Image input dimensions of 384x384 with 3 channels + /// - Image normalization with mean and std of [0.5, 0.5, 0.5] + /// - Lanczos3 resize filter + /// - Default tokenizer and model configuration files pub fn trocr() -> Self { Self::default() .with_name("trocr") @@ -18,6 +26,9 @@ impl crate::Config { .with_tokenizer_config_file("trocr/tokenizer_config.json") } + /// Creates a configuration for the small TrOCR model variant optimized for printed text. + /// + /// Uses the small scale model files and tokenizer configuration. pub fn trocr_small_printed() -> Self { Self::trocr() .with_scale(Scale::S) @@ -27,6 +38,9 @@ impl crate::Config { .with_tokenizer_file("trocr/tokenizer-small.json") } + /// Creates a configuration for the base TrOCR model variant optimized for handwritten text. + /// + /// Uses the base scale model files and tokenizer configuration. pub fn trocr_base_handwritten() -> Self { Self::trocr() .with_scale(Scale::B) @@ -36,6 +50,9 @@ impl crate::Config { .with_tokenizer_file("trocr/tokenizer-base.json") } + /// Creates a configuration for the small TrOCR model variant optimized for handwritten text. + /// + /// Modifies the small printed configuration to use handwritten-specific model files. pub fn trocr_small_handwritten() -> Self { Self::trocr_small_printed() .with_visual_file("s-encoder-handwritten.onnx") @@ -43,6 +60,9 @@ impl crate::Config { .with_textual_decoder_merged_file("s-decoder-merged-handwritten.onnx") } + /// Creates a configuration for the base TrOCR model variant optimized for printed text. + /// + /// Modifies the base handwritten configuration to use printed-specific model files. pub fn trocr_base_printed() -> Self { Self::trocr_base_handwritten() .with_visual_file("b-encoder-printed.onnx") diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs index f1a3679..c5b2787 100644 --- a/src/models/trocr/impl.rs +++ b/src/models/trocr/impl.rs @@ -6,9 +6,12 @@ use std::str::FromStr; use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y}; +/// TrOCR model variants for different text types. #[derive(Debug, Copy, Clone)] pub enum TrOCRKind { + /// Model variant optimized for machine-printed text recognition Printed, + /// Model variant optimized for handwritten text recognition HandWritten, } @@ -24,23 +27,59 @@ impl FromStr for TrOCRKind { } } +/// TrOCR model for optical character recognition. +/// +/// TrOCR is a transformer-based OCR model that combines an image encoder with +/// a text decoder for end-to-end text recognition. It supports both printed and +/// handwritten text recognition through different model variants. 
+/// +/// The model consists of: +/// - An encoder that processes the input image +/// - A decoder that generates text tokens +/// - A merged decoder variant for optimized inference +/// - A processor for image preprocessing and text postprocessing #[derive(Debug, Builder)] pub struct TrOCR { + /// Image encoder engine encoder: Engine, + /// Text decoder engine for token generation decoder: Engine, + /// Optimized merged decoder engine decoder_merged: Engine, + /// Maximum length of generated text sequence max_length: u32, + /// Token ID representing end of sequence eos_token_id: u32, + /// Token ID used to start text generation decoder_start_token_id: u32, + /// Timestamp tracking for performance analysis ts: Ts, + /// Number of key-value pairs in decoder attention n_kvs: usize, + /// Image and text processor for pre/post processing processor: Processor, + /// Batch size for inference batch: usize, + /// Input image height height: usize, + /// Input image width width: usize, } impl TrOCR { + /// Creates a new TrOCR model instance from the given configuration. + /// + /// # Arguments + /// * `config` - The model configuration containing paths to model files and parameters + /// + /// # Returns + /// * `Result<Self>` - A new TrOCR instance if initialization succeeds + /// + /// # Errors + /// Returns an error if: + /// - Required model files cannot be loaded + /// - Model configuration is invalid + /// - Tokenizer initialization fails pub fn new(config: Config) -> Result<Self> { let encoder = Engine::try_from_config(&config.visual)?; let decoder = Engine::try_from_config(&config.textual_decoder)?; @@ -86,6 +125,16 @@ impl TrOCR { }) } + /// Encodes the given images into feature vectors using the TrOCR encoder. + /// + /// This method processes the images through the image processor and then + /// encodes them using the encoder engine. + /// + /// # Arguments + /// * `xs` - A slice of `Image` instances to be encoded. + /// + /// # Errors + /// Returns an error if image processing or encoding fails. pub fn encode(&mut self, xs: &[Image]) -> Result<X> { let ys = self.processor.process_images(xs)?; self.batch = xs.len(); // update @@ -93,6 +142,16 @@ impl TrOCR { Ok(ys[0].to_owned()) } + /// Performs the forward pass of the TrOCR model, from encoding images to decoding text. + /// + /// This method encodes the input images, generates token IDs using the decoder, + /// and finally decodes the token IDs into text. + /// + /// # Arguments + /// * `xs` - A slice of `Image` instances to be processed. + /// + /// # Errors + /// Returns an error if any step in the forward pass fails. pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> { let encoder_hidden_states = elapsed!("encode", self.ts, { self.encode(xs)? }); let generated = elapsed!("generate", self.ts, { @@ -182,6 +241,13 @@ impl TrOCR { Ok(token_ids) } + /// Decodes the given token IDs into text using the TrOCR processor. + /// + /// # Arguments + /// * `token_ids` - Batches of token IDs to be decoded, one sequence per input image. + /// + /// # Errors + /// Returns an error if decoding fails. pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Vec<Y>> { // decode let texts = self.processor.decode_tokens_batch(&token_ids, false)?; @@ -195,6 +261,7 @@ impl TrOCR { Ok(texts) } + /// Displays a summary of the TrOCR model's configuration and state.
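Putting the pieces above together, a minimal recognition sketch; the constructors come from this patch, while the loader and asset path are assumptions:

```rust
use usls::{Config, Image, TrOCR};

fn main() -> anyhow::Result<()> {
    // Small printed-text variant; see the trocr config constructors in this patch.
    let mut model = TrOCR::new(Config::trocr_small_printed())?;

    let xs = [Image::try_read("./assets/text-line.png")?]; // loader assumed
    // forward() = encode image -> generate token IDs -> decode to text.
    let ys = model.forward(&xs)?;
    println!("{} recognized item(s)", ys.len());
    model.summary();
    Ok(())
}
```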
pub fn summary(&self) { self.ts.summary(); } diff --git a/src/models/yolo/config.rs b/src/models/yolo/config.rs index 1f240e9..956ccdd 100644 --- a/src/models/yolo/config.rs +++ b/src/models/yolo/config.rs @@ -5,6 +5,9 @@ use crate::{ }; impl Config { + /// Creates a base YOLO configuration with common settings. + /// + /// Sets up default input dimensions (640x640) and image processing parameters. pub fn yolo() -> Self { Self::default() .with_name("yolo") @@ -16,6 +19,12 @@ impl Config { .with_resize_filter("CatmullRom") } + /// Creates a configuration for YOLO image classification. + /// + /// Configures the model for ImageNet classification with: + /// - 224x224 input size + /// - Exact resize mode with bilinear interpolation + /// - ImageNet 1000 class names pub fn yolo_classify() -> Self { Self::yolo() .with_task(Task::ImageClassification) @@ -26,24 +35,38 @@ impl Config { .with_class_names(&NAMES_IMAGENET_1K) } + /// Creates a configuration for YOLO object detection. + /// + /// Configures the model for COCO dataset object detection with 80 classes. pub fn yolo_detect() -> Self { Self::yolo() .with_task(Task::ObjectDetection) .with_class_names(&NAMES_COCO_80) } + /// Creates a configuration for YOLO pose estimation. + /// + /// Configures the model for human keypoint detection with 17 COCO keypoints. pub fn yolo_pose() -> Self { Self::yolo() .with_task(Task::KeypointsDetection) .with_keypoint_names(&NAMES_COCO_KEYPOINTS_17) } + /// Creates a configuration for YOLO instance segmentation. + /// + /// Configures the model for COCO dataset instance segmentation with 80 classes. pub fn yolo_segment() -> Self { Self::yolo() .with_task(Task::InstanceSegmentation) .with_class_names(&NAMES_COCO_80) } + /// Creates a configuration for YOLO oriented object detection. + /// + /// Configures the model for detecting rotated objects with: + /// - 1024x1024 input size + /// - DOTA v1 dataset classes pub fn yolo_obb() -> Self { Self::yolo() .with_model_ixx(0, 2, 1024.into()) @@ -52,6 +75,11 @@ impl Config { .with_class_names(&NAMES_DOTA_V1_15) } + /// Creates a configuration for document layout analysis using YOLOv10. + /// + /// Configures the model for detecting document structure elements with: + /// - Variable input size up to 1024x1024 + /// - 10 document layout classes pub fn doclayout_yolo_docstructbench() -> Self { Self::yolo_detect() .with_version(10.into()) @@ -62,12 +90,16 @@ impl Config { .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1 } - // YOLOE models + /// Creates a base YOLOE configuration with 4585 classes. + /// + /// Configures the model for instance segmentation with a large class vocabulary. pub fn yoloe() -> Self { Self::yolo() .with_task(Task::InstanceSegmentation) .with_class_names(&NAMES_YOLOE_4585) } + /// Creates a configuration for YOLOE-v8s segmentation model. + /// Uses the small variant of YOLOv8 architecture. pub fn yoloe_v8s_seg_pf() -> Self { Self::yoloe() .with_version(8.into()) @@ -75,6 +107,8 @@ impl Config { .with_model_file("yoloe-v8s-seg-pf.onnx") } + /// Creates a configuration for YOLOE-v8m segmentation model. + /// Uses the medium variant of YOLOv8 architecture. pub fn yoloe_v8m_seg_pf() -> Self { Self::yoloe() .with_version(8.into()) @@ -82,6 +116,8 @@ impl Config { .with_model_file("yoloe-v8m-seg-pf.onnx") } + /// Creates a configuration for YOLOE-v8l segmentation model. + /// Uses the large variant of YOLOv8 architecture. 
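Since every constructor above layers task-specific settings onto the shared `Config::yolo()` base, selecting a task is just a matter of picking the right entry point. A compact sketch (version, scale, and model-file selection are elided and would still need to be supplied before building the model):

```rust
use usls::Config;

// Each constructor starts from Config::yolo() and adds task-specific
// settings, per the definitions in this patch.
fn task_configs() -> (Config, Config, Config, Config) {
    (
        Config::yolo_classify(), // 224x224, ImageNet-1k names
        Config::yolo_detect(),   // 640x640, COCO-80 names
        Config::yolo_pose(),     // 17 COCO keypoints
        Config::yolo_obb(),      // 1024x1024, DOTA v1 names
    )
}
```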
pub fn yoloe_v8l_seg_pf() -> Self { Self::yoloe() .with_version(8.into()) @@ -89,6 +125,8 @@ impl Config { .with_model_file("yoloe-v8l-seg-pf.onnx") } + /// Creates a configuration for YOLOE-11s segmentation model. + /// Uses the small variant of YOLOv11 architecture. pub fn yoloe_11s_seg_pf() -> Self { Self::yoloe() .with_version(11.into()) diff --git a/src/models/yolo/impl.rs b/src/models/yolo/impl.rs index 5f56d11..5440d8b 100644 --- a/src/models/yolo/impl.rs +++ b/src/models/yolo/impl.rs @@ -12,6 +12,14 @@ use crate::{ Ts, Version, Xs, Y, }; +/// YOLO (You Only Look Once) object detection model. +/// +/// A versatile deep learning model that can perform multiple computer vision tasks including: +/// - Object Detection +/// - Instance Segmentation +/// - Keypoint Detection +/// - Image Classification +/// - Oriented Object Detection #[derive(Debug, Builder)] pub struct YOLO { engine: Engine, @@ -45,6 +53,12 @@ impl TryFrom<Config> for YOLO { } } impl YOLO { + /// Creates a new YOLO model instance from the provided configuration. + /// + /// This function initializes the model by: + /// - Loading the ONNX model file + /// - Setting up input dimensions and processing parameters + /// - Configuring task-specific settings and output formats pub fn new(config: Config) -> Result<Self> { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts, spec) = ( @@ -284,16 +298,24 @@ impl YOLO { }) } + /// Pre-processes input images before model inference. fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> { let x = self.processor.process_images(xs)?; Ok(x.into()) } + /// Performs model inference on the pre-processed input. fn inference(&mut self, xs: Xs) -> Result<Xs> { self.engine.run(xs) } + /// Performs the complete inference pipeline on a batch of images. + /// + /// The pipeline consists of: + /// 1. Pre-processing the input images + /// 2. Running model inference + /// 3. Post-processing the outputs to generate final predictions pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> { let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); @@ -302,6 +324,7 @@ impl YOLO { Ok(ys) } + /// Post-processes model outputs to generate final predictions. fn postprocess(&self, xs: Xs) -> Result<Vec<Y>> { // let protos = if xs.len() == 2 { Some(&xs[1]) } else { None }; let ys: Vec<Y> = xs[0] @@ -605,6 +628,7 @@ impl YOLO { // Ok(ys.into()) } + /// Extracts class names from the ONNX model metadata if available. fn fetch_names_from_onnx(engine: &Engine) -> Option<Vec<String>> { // fetch class names from onnx metadata // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}` @@ -616,6 +640,7 @@ impl YOLO { .into() } + /// Extracts the number of keypoints from the ONNX model metadata if available. fn fetch_nk_from_onnx(engine: &Engine) -> Option<usize> { Regex::new(r"(\d+), \d+") .ok()? @@ -624,6 +649,7 @@ impl YOLO { .and_then(|m| m.as_str().parse::<usize>().ok()) } + /// Prints a summary of the model configuration and parameters. pub fn summary(&mut self) { self.ts.summary(); } diff --git a/src/models/yolo/preds.rs b/src/models/yolo/preds.rs index 12029bd..d9ba538 100644 --- a/src/models/yolo/preds.rs +++ b/src/models/yolo/preds.rs @@ -2,6 +2,7 @@ use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; use crate::Task; +/// Bounding box coordinate format types.
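`fetch_names_from_onnx` above pulls class names out of a Python-dict-style metadata string. The crate's exact extraction isn't shown in this hunk, so here is a standalone sketch of the idea with the `regex` crate (the pattern is illustrative, not the one usls ships):

```rust
use regex::Regex;

fn parse_names(meta: &str) -> Vec<String> {
    // Metadata looks like: {0: 'person', 1: 'bicycle', ..., 27: "yellow_lady's_slipper"}
    // Capture the quoted value after each `index:` key; both quote styles occur.
    let re = Regex::new(r#"\d+:\s*(?:'([^']*)'|"([^"]*)")"#).unwrap();
    re.captures_iter(meta)
        .filter_map(|c| c.get(1).or_else(|| c.get(2)))
        .map(|m| m.as_str().to_string())
        .collect()
}

fn main() {
    let meta = r#"{0: 'person', 1: 'bicycle', 2: "sports ball"}"#;
    assert_eq!(parse_names(meta), ["person", "bicycle", "sports ball"]);
}
```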
#[derive(Debug, Clone, PartialEq)] pub enum BoxType { Cxcywh, @@ -11,6 +12,7 @@ pub enum BoxType { XyCxcy, } +/// Classification output format types. #[derive(Debug, Clone, PartialEq)] pub enum ClssType { Clss, @@ -20,18 +22,21 @@ pub enum ClssType { ClssConf, } +/// Keypoint output format types. #[derive(Debug, Clone, PartialEq)] pub enum KptsType { Xys, Xycs, } +/// Anchor position in the prediction pipeline. #[derive(Debug, Clone, PartialEq)] pub enum AnchorsPosition { Before, After, } +/// YOLO prediction format configuration. #[derive(Debug, Clone, PartialEq)] pub struct YOLOPredsFormat { pub clss: ClssType, diff --git a/src/inference/hbb.rs b/src/results/hbb.rs similarity index 98% rename from src/inference/hbb.rs rename to src/results/hbb.rs index 04e5115..ca4753a 100644 --- a/src/inference/hbb.rs +++ b/src/results/hbb.rs @@ -1,7 +1,8 @@ use aksr::Builder; -use crate::{impl_meta_methods, InstanceMeta, Keypoint, Style}; +use crate::{InstanceMeta, Keypoint, Style}; +/// Horizontal bounding box with position, size, and metadata. #[derive(Builder, Clone, Default)] pub struct Hbb { x: f32, diff --git a/src/inference/instance_meta.rs b/src/results/instance_meta.rs similarity index 97% rename from src/inference/instance_meta.rs rename to src/results/instance_meta.rs index ef149eb..7fa3ce4 100644 --- a/src/inference/instance_meta.rs +++ b/src/results/instance_meta.rs @@ -1,4 +1,5 @@ #[derive(aksr::Builder, Clone, PartialEq)] +/// Metadata for detection instances including ID, confidence, and name. pub struct InstanceMeta { uid: usize, id: Option, @@ -78,7 +79,6 @@ impl InstanceMeta { } } -#[macro_export] macro_rules! impl_meta_methods { () => { pub fn with_uid(mut self, uid: usize) -> Self { diff --git a/src/inference/keypoint.rs b/src/results/keypoint.rs similarity index 99% rename from src/inference/keypoint.rs rename to src/results/keypoint.rs index 8ba23c3..b486535 100644 --- a/src/inference/keypoint.rs +++ b/src/results/keypoint.rs @@ -1,7 +1,7 @@ use aksr::Builder; use std::ops::{Add, Div, Mul, Sub}; -use crate::{impl_meta_methods, InstanceMeta, Style}; +use crate::{InstanceMeta, Style}; /// Represents a keypoint in a 2D space with optional metadata. #[derive(Builder, Default, Clone)] diff --git a/src/inference/mask.rs b/src/results/mask.rs similarity index 93% rename from src/inference/mask.rs rename to src/results/mask.rs index 47bb1c0..f1b9380 100644 --- a/src/inference/mask.rs +++ b/src/results/mask.rs @@ -3,13 +3,16 @@ use anyhow::Result; use image::GrayImage; use rayon::prelude::*; -use crate::{impl_meta_methods, InstanceMeta, Polygon, Style}; +use crate::{InstanceMeta, Polygon, Style}; /// Mask: Gray Image. #[derive(Builder, Default, Clone)] pub struct Mask { + /// The grayscale image representing the mask. mask: GrayImage, + /// Metadata associated with the mask instance. meta: InstanceMeta, + /// Optional styling information for visualization. style: Option