diff --git a/Cargo.toml b/Cargo.toml
index 8dc658b..dc42d58 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "usls"
edition = "2021"
-version = "0.1.0-beta.4"
+version = "0.1.0-rc.1"
rust-version = "1.82"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@@ -10,6 +10,7 @@ license = "MIT"
readme = "README.md"
exclude = ["assets/*", "examples/*", "runs/*", "benches/*", "tests/*"]
+
[dependencies]
anyhow = { version = "1" }
aksr = { version = "0.0.3" }
@@ -47,8 +48,6 @@ tokenizers = { version = "0.21.1" }
paste = "1.0.15"
base64ct = "=1.7.3"
-[build-dependencies]
-prost-build = "0.13.5"
[dev-dependencies]
argh = "0.1.13"
diff --git a/README.md b/README.md
index 90813d7..5888792 100644
--- a/README.md
+++ b/README.md
@@ -1,291 +1,146 @@
usls
- β­οΈ Star if helpful! β­οΈ
-**usls** is an evolving Rust library focused on inference for advanced **vision** and **vision-language** models, along with practical vision utilities.
+**usls** is a cross-platform Rust library powered by ONNX Runtime for efficient inference of SOTA vision and multi-modal models (typically under 1B parameters).
-- **SOTA Model Inference:** Supports a wide range of state-of-the-art vision and multi-modal models (typically with fewer than 1B parameters).
-- **Multi-backend Acceleration:** Supports CPU, CUDA, TensorRT, and CoreML.
-- **Easy Data Handling:** Easily read images, video streams, and folders with iterator support.
-- **Rich Result Types:** Built-in containers for common vision outputs like bounding boxes (Hbb, Obb), polygons, masks, etc.
-- **Annotation & Visualization:** Draw and display inference results directly, similar to OpenCV's `imshow()`.
+## π Documentation
+- [API Documentation](https://docs.rs/usls/latest/usls/)
+- [Examples](./examples)
-## π§© Supported Models
+## π Quick Start
+```bash
+# CPU
+cargo run -r --example yolo # YOLOv8-n detect by default
-- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics), [YOLOv12](https://github.com/sunsmarterjie/yolov12)
-- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
-- **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone)
-- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1-v2](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242), [Moondream2](https://github.com/vikhyat/moondream/tree/main)
-- **OCR-Related Models**: [FAST](https://github.com/czczup/FAST), [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947), [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
+# NVIDIA CUDA
+cargo run -r -F cuda --example yolo -- --device cuda:0
-<details>
-<summary>Full list of supported models (click to expand)</summary>
+# NVIDIA TensorRT
+cargo run -r -F tensorrt --example yolo -- --device tensorrt:0
-| Model | Task / Description | Example | CoreML | CUDA<br />FP32 | CUDA<br />FP16 | TensorRT<br />FP32 | TensorRT<br />FP16 |
-| ----- | ------------------ | ------- | ------ | -------------- | -------------- | ------------------ | ------------------ |
-| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) | β
 | β
 | β
 | | |
-| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) | β
 | β
 | β
 | | |
-| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) | β
 | β
 | β
 | | |
-| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) | β
 | β
 | β
 | | |
-| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) | β
 | β
 | β
 | | |
-| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification<br />Object Detection<br />Instance Segmentation | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv8<br />YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Image Classification<br />Oriented Object Detection<br />Keypoint Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) | β
 | β
 | β
 | | |
-| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) | β
 | β
 | β
 | | |
-| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) | β
 | β
 | β
 | | |
-| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) | β
 | β
 | β
 | | |
-| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) | β
 | β
 | β
 | | |
-| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) | β
 | β
 | β
 | | |
-| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | β
 | β
 | β
 | β | β |
-| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | β
 | β
 | β
 | | |
-| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) | β
 | β
 | β
 | | |
-| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | β
 | β
 | β
 | | |
-| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | β
 | β
 | β
 | | |
-| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | β
 | β
 | β
 | | |
-| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) | β
 | β
 | β
 | β
 | β
 |
-| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | β
 | β
 | β
 | | |
-| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) | β
 | β
 | β
 | β | β |
-| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) | β
 | β
 | β
 | β | β |
-| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) | β
 | β
 | β
 | β | β |
-| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) | β
 | β
 | β
 | β | β |
-| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) | β
 | β
 | β
 | β | β |
-| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) | β
 | β
 | β
 | β
 | β
 |
-| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) | β
 | β
 | β
 | β
 | β
 |
-| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) | β
 | β
 | β
 | β
 | β
 |
-| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | β
 | β
 | β
 | β
 | β
 |
-| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Tabel Recognition | [demo](examples/slanet) | β
 | β
 | β
 | | |
-| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) | β
 | β
 | β
 | | |
-| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | β
 | β
 | β
 | β
 | β
 |
-| [DepthAnything v1<br />DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | β
 | β
 | β
 | β | β |
-| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | β
 | β
 | β
 | | |
-| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | β
 | β
 | β
 | β
 | β
 |
-| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | β
 | β
 | β
 | | |
-| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | β
 | β
 | β
 | | |
-| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | β
 | β
 | β
 | | |
-| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | β
 | β
 | β
 | | |
-| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | β
 | β
 | β
 | | |
-| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) | β
 | β
 | β
 | | |
-| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) | β
 | β
 | β
 | | |
+# Apple Silicon CoreML
+cargo run -r -F coreml --example yolo -- --device coreml
-
+# Intel OpenVINO
+cargo run -r -F openvino -F ort-load-dynamic --example yolo -- --device openvino:CPU
-
-## π οΈ Installation
-
-To get started, you'll need:
-
-### 1. Protocol Buffers Compiler (`protoc`)
-Required for building the project. [Official installation guide](https://protobuf.dev/installation/)
-```shell
-# Linux (apt)
-sudo apt install -y protobuf-compiler
-
-# macOS (Homebrew)
-brew install protobuf
-
-# Windows (Winget)
-winget install protobuf
-
-# Verify installation
-protoc --version # Should be 3.x or higher
+# And other EPs...
```
-### 2. Rust Toolchain
-```shell
-# Install Rust and Cargo
-curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-```
-### 3. Add usls to Your Project
+## βοΈ Installation
Add the following to your `Cargo.toml`:
+
```toml
[dependencies]
# Recommended: Use GitHub version
-usls = { git = "https://github.com/jamjamjon/usls" }
+usls = { git = "https://github.com/jamjamjon/usls", features = [ "cuda" ] }
# Alternative: Use crates.io version
usls = "latest-version"
```
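For orientation, here is a minimal end-to-end sketch assembled from the RTMO example touched later in this diff (`examples/rtmo/main.rs`); the constructor and method names (`Config::rtmo_s`, `RTMO::new`, `DataLoader::try_read_n`, `forward`, `Annotator::annotate`) are taken from that example rather than from canonical API docs, so treat it as an illustration.

```rust
use anyhow::Result;
use usls::{models::RTMO, Annotator, Config, DataLoader};

fn main() -> Result<()> {
    // Build a model from a preset config (constructor names from examples/rtmo/main.rs).
    let config = Config::rtmo_s().commit()?;
    let mut model = RTMO::new(config)?;

    // Read a batch of images and run inference.
    let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
    let ys = model.forward(&xs)?;

    // Draw the predictions onto the inputs.
    let annotator = Annotator::default();
    for (x, y) in xs.iter().zip(ys.iter()) {
        let _ = annotator.annotate(x, y)?;
    }
    Ok(())
}
```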
-> **Note:** **The GitHub version is recommended as it contains the latest updates.**
-## β‘ Cargo Features
-- **ONNXRuntime-related features (enabled by default)**, provide model inference and model zoo support:
- - **`ort-download-binaries`** (**default**): Automatically downloads prebuilt `ONNXRuntime` binaries for supported platforms. Provides core model loading and inference capabilities using the `CPU` execution provider.
- - **`ort-load-dynamic `** Dynamic linking. You'll need to compile `ONNXRuntime` from [source](https://github.com/microsoft/onnxruntime) or download a [precompiled package](https://github.com/microsoft/onnxruntime/releases), and then link it manually. [See the guide here](https://ort.pyke.io/setup/linking#dynamic-linking).
-
- - **`cuda`**: Enables the NVIDIA `CUDA` provider. Requires `CUDA` toolkit and `cuDNN` installed.
- - **`trt`**: Enables the NVIDIA `TensorRT` provider. Requires `TensorRT` libraries installed.
- - **`mps`**: Enables the Apple `CoreML` provider for macOS.
+## β‘ Supported Models
+<details>
+<summary>Click to expand</summary>
-- **If you only need basic features** (such as image/video reading, result visualization, etc.), you can disable the default features to minimize dependencies:
- ```shell
- usls = { git = "https://github.com/jamjamjon/usls", default-features = false }
- ```
- - **`video`** : Enable video stream reading, and video writing.(Note: Powered by [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb). Check their repositories for potential issues.)
+| Model | Task / Description | Example |
+| ----- | ----------------- | ------- |
+| [BEiT](https://github.com/microsoft/unilm/tree/master/beit) | Image Classification | [demo](examples/beit) |
+| [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) | Image Classification | [demo](examples/convnext) |
+| [FastViT](https://github.com/apple/ml-fastvit) | Image Classification | [demo](examples/fastvit) |
+| [MobileOne](https://github.com/apple/ml-mobileone) | Image Classification | [demo](examples/mobileone) |
+| [DeiT](https://github.com/facebookresearch/deit) | Image Classification | [demo](examples/deit) |
+| [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](examples/dinov2) |
+| [YOLOv5](https://github.com/ultralytics/yolov5) | Image Classification<br />Object Detection<br />Instance Segmentation | [demo](examples/yolo) |
+| [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | [demo](examples/yolo) |
+| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) |
+| [YOLOv8<br />YOLO11](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Image Classification<br />Oriented Object Detection<br />Keypoint Detection | [demo](examples/yolo) |
+| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) |
+| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) |
+| [YOLOv12](https://github.com/sunsmarterjie/yolov12) | Object Detection | [demo](examples/yolo) |
+| [RT-DETR](https://github.com/lyuwenyu/RT-DETR) | Object Detection | [demo](examples/rtdetr) |
+| [RF-DETR](https://github.com/roboflow/rf-detr) | Object Detection | [demo](examples/rfdetr) |
+| [PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.8/configs/picodet) | Object Detection | [demo](examples/picodet-layout) |
+| [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO) | Object Detection | [demo](examples/picodet-layout) |
+| [D-FINE](https://github.com/manhbd-22022602/D-FINE) | Object Detection | [demo](examples/d-fine) |
+| [DEIM](https://github.com/ShihuaHuang95/DEIM) | Object Detection | [demo](examples/deim) |
+| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) |
+| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) |
+| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | [demo](examples/sam) |
+| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) |
+| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) |
+| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) |
+| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) |
+| [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Open-Set Detection With Language | [demo](examples/yolo) |
+| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) |
+| [CLIP](https://github.com/openai/CLIP) | Vision-Language Embedding | [demo](examples/clip) |
+| [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](examples/clip) |
+| [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](examples/clip) |
+| [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](examples/clip) |
+| [BLIP](https://github.com/salesforce/BLIP) | Image Captioning | [demo](examples/blip) |
+| [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947) | Text Detection | [demo](examples/db) |
+| [FAST](https://github.com/czczup/FAST) | Text Detection | [demo](examples/fast) |
+| [LinkNet](https://arxiv.org/abs/1707.03718) | Text Detection | [demo](examples/linknet) |
+| [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) |
+| [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html) | Table Recognition | [demo](examples/slanet) |
+| [TrOCR](https://huggingface.co/microsoft/trocr-base-printed) | Text Recognition | [demo](examples/trocr) |
+| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) |
+| [DepthAnything v1<br />DepthAnything v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) |
+| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) |
+| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) |
+| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) |
+| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) |
+| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) |
+| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) |
+| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) |
+| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) |
+| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation<br />Background Removal | [demo](examples/rmbg) |
+| [MediaPipe: Selfie-segmentation](https://ai.google.dev/edge/mediapipe/solutions/vision/image_segmenter) | Image Segmentation | [demo](examples/mediapipe-selfie-segmentation) |
-## β¨ Example
-
-- Model Inference
- ```shell
- cargo run -r --example yolo # CPU
- cargo run -r -F cuda --example yolo -- --device cuda:0 # GPU
- ```
-
-- Reading Images
- ```rust
- // Read a single image
- let image = DataLoader::try_read_one("./assets/bus.jpg")?;
-
- // Read multiple images
- let images = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/cat.png"])?;
-
- // Read all images in a folder
- let images = DataLoader::try_read_folder("./assets")?;
-
- // Read images matching a pattern (glob)
- let images = DataLoader::try_read_pattern("./assets/*.Jpg")?;
-
- // Load images and iterate
- let dl = DataLoader::new("./assets")?.with_batch(2).build()?;
- for images in dl.iter() {
- // Code here
- }
- ```
-
-- Reading Video
- ```rust
- let dl = DataLoader::new("http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4")?
- .with_batch(1)
- .with_nf_skip(2)
- .with_progress_bar(true)
- .build()?;
- for images in dl.iter() {
- // Code here
- }
- ```
-
-- Annotate
- ```rust
- let annotator = Annotator::default();
- let image = DataLoader::try_read_one("./assets/bus.jpg")?;
- // hbb
- let hbb = Hbb::default()
- .with_xyxy(669.5233, 395.4491, 809.0367, 878.81226)
- .with_id(0)
- .with_name("person")
- .with_confidence(0.87094545);
- let _ = annotator.annotate(&image, &hbb)?;
-
- // keypoints
- let keypoints: Vec<Keypoint> = vec![
- Keypoint::default()
- .with_xy(139.35767, 443.43655)
- .with_id(0)
- .with_name("nose")
- .with_confidence(0.9739332),
- Keypoint::default()
- .with_xy(147.38545, 434.34055)
- .with_id(1)
- .with_name("left_eye")
- .with_confidence(0.9098319),
- Keypoint::default()
- .with_xy(128.5701, 434.07516)
- .with_id(2)
- .with_name("right_eye")
- .with_confidence(0.9320564),
- ];
- let _ = annotator.annotate(&image, &keypoints)?;
- ```
+</details>
-- Visualizing Inference Results and Exporting Video
- ```rust
- let dl = DataLoader::new(args.source.as_str())?.build()?;
- let mut viewer = Viewer::default().with_window_scale(0.5);
+## π¦ Cargo Features
+- **`ort-download-binaries`** (**default**): Automatically downloads prebuilt ONNXRuntime binaries for supported platforms
+- **`ort-load-dynamic`**: Dynamic linking to ONNXRuntime libraries ([Guide](https://ort.pyke.io/setup/linking#dynamic-linking))
+- **`video`**: Enable video stream reading and writing (via [video-rs](https://github.com/oddity-ai/video-rs) and [minifb](https://github.com/emoon/rust_minifb))
+- **`cuda`**: NVIDIA CUDA GPU acceleration support
+- **`tensorrt`**: NVIDIA TensorRT optimization for inference acceleration
+- **`coreml`**: Apple CoreML acceleration for macOS/iOS devices
+- **`openvino`**: Intel OpenVINO toolkit for CPU/GPU/VPU acceleration
+- **`onednn`**: Intel oneDNN (formerly MKL-DNN) for CPU optimization
+- **`directml`**: Microsoft DirectML for Windows GPU acceleration
+- **`xnnpack`**: Google XNNPACK for mobile and edge device optimization
+- **`rocm`**: AMD ROCm platform for GPU acceleration
+- **`cann`**: Huawei CANN (Compute Architecture for Neural Networks) support
+- **`rknpu`**: Rockchip NPU acceleration
+- **`acl`**: Arm Compute Library for Arm processors
+- **`nnapi`**: Android Neural Networks API support
+- **`armnn`**: Arm NN inference engine
+- **`tvm`**: Apache TVM tensor compiler stack
+- **`qnn`**: Qualcomm Neural Network SDK
+- **`migraphx`**: AMD MIGraphX for GPU acceleration
+- **`vitis`**: Xilinx Vitis AI for FPGA acceleration
+- **`azure`**: Azure Machine Learning integration
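These features mostly decide which execution providers get compiled in; the device itself is still picked at runtime. Below is a small sketch of that pattern as the examples in this diff use it; it assumes `Device` and `DType` implement `FromStr`, which the examples' `args.device.parse()?` / `args.dtype.parse()?` calls imply.

```rust
use usls::{Config, DType, Device};

// Sketch: map CLI-style strings onto a Config, mirroring the examples.
// Assumes Device and DType implement FromStr (implied by `args.device.parse()?`).
fn configured(device: &str, dtype: &str) -> anyhow::Result<Config> {
    Ok(Config::rtmo_s() // any preset works; rtmo_s is the one used in examples/rtmo
        .with_model_device(device.parse::<Device>()?)
        .with_model_dtype(dtype.parse::<DType>()?)
        .commit()?)
}
```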
- for images in &dl {
- // Check if the window exists and is open
- if viewer.is_window_exist() && !viewer.is_window_open() {
- break;
- }
-
- // Show image in window
- viewer.imshow(&images[0])?;
-
- // Handle key events and delay
- if let Some(key) = viewer.wait_key(1) {
- if key == usls::Key::Escape {
- break;
- }
- }
-
- // Your custom code here
-
- // Write video frame (requires video feature)
- // if args.save_video {
- // viewer.write_video_frame(&images[0])?;
- // }
- }
- ```
-
-**All examples are located in the [examples](./examples/) directory.**
## β FAQ
-See issues or open a new discussion.
+See [issues](https://github.com/jamjamjon/usls/issues) or open a new discussion.
## π€ Contributing
diff --git a/build.rs b/build.rs
deleted file mode 100644
index 6558285..0000000
--- a/build.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-use std::io::Result;
-
-fn main() -> Result<()> {
- prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;
-
- Ok(())
-}
diff --git a/examples/mediapipe-selfie-segmentation/README.md b/examples/mediapipe-selfie-segmentation/README.md
new file mode 100644
index 0000000..cc68120
--- /dev/null
+++ b/examples/mediapipe-selfie-segmentation/README.md
@@ -0,0 +1,10 @@
+## Quick Start
+
+```shell
+cargo run -r --example mediapipe-selfie-segmentation -- --dtype f16
+```
+
+
+## Results
+
+
diff --git a/examples/mediapipe-selfie-segmentation/main.rs b/examples/mediapipe-selfie-segmentation/main.rs
new file mode 100644
index 0000000..adeb090
--- /dev/null
+++ b/examples/mediapipe-selfie-segmentation/main.rs
@@ -0,0 +1,49 @@
+use usls::{models::MediaPipeSegmenter, Annotator, Config, DataLoader};
+
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// dtype
+ #[argh(option, default = "String::from(\"auto\")")]
+ dtype: String,
+
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+}
+
+fn main() -> anyhow::Result<()> {
+ tracing_subscriber::fmt()
+ .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+ .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
+ .init();
+ let args: Args = argh::from_env();
+
+ // build model
+ let config = Config::mediapipe_selfie_segmentater()
+ .with_model_dtype(args.dtype.parse()?)
+ .with_model_device(args.device.parse()?)
+ .commit()?;
+ let mut model = MediaPipeSegmenter::new(config)?;
+
+ // load image
+ let xs = DataLoader::try_read_n(&["images/selfie-segmenter.png"])?;
+
+ // run
+ let ys = model.forward(&xs)?;
+
+ // annotate
+ let annotator =
+ Annotator::default().with_mask_style(usls::Style::mask().with_mask_cutout(true));
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ annotator.annotate(x, y)?.save(format!(
+ "{}.jpg",
+ usls::Dir::Current
+ .base_dir_with_subs(&["runs", model.spec()])?
+ .join(usls::timestamp(None))
+ .display(),
+ ))?;
+ }
+
+ Ok(())
+}
diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs
index b68aef1..69ce542 100644
--- a/examples/rtmo/main.rs
+++ b/examples/rtmo/main.rs
@@ -1,21 +1,48 @@
use anyhow::Result;
use usls::{models::RTMO, Annotator, Config, DataLoader, Style, SKELETON_COCO_19};
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// dtype
+ #[argh(option, default = "String::from(\"fp16\")")]
+ dtype: String,
+
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+
+ /// scale: t, s, m, l
+ #[argh(option, default = "String::from(\"t\")")]
+ scale: String,
+}
+
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
+ let args: Args = argh::from_env();
// build model
- let mut model = RTMO::new(Config::rtmo_s().commit()?)?;
+ let config = match args.scale.as_str() {
+ "t" => Config::rtmo_t(),
+ "s" => Config::rtmo_s(),
+ "m" => Config::rtmo_m(),
+ "l" => Config::rtmo_l(),
+ _ => anyhow::bail!("Unsupported scale: {}. Expected one of: t, s, m, l", args.scale),
+ }
+ .with_model_dtype(args.dtype.parse()?)
+ .with_model_device(args.device.parse()?)
+ .commit()?;
+ let mut model = RTMO::new(config)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// run
let ys = model.forward(&xs)?;
- println!("ys: {:?}", ys);
+ // println!("ys: {:?}", ys);
// annotate
let annotator = Annotator::default()
@@ -37,5 +64,8 @@ fn main() -> Result<()> {
))?;
}
+ // summary
+ model.summary();
+
Ok(())
}
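With the new flags, any released RTMO variant can be selected from the command line, e.g. `cargo run -r --example rtmo -- --scale s --dtype fp16 --device cuda:0` (the flag names come from the argh struct above; running on CUDA additionally requires building with `-F cuda`, as in the README's quick start).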
diff --git a/src/utils/config.rs b/src/core/config.rs
similarity index 99%
rename from src/utils/config.rs
rename to src/core/config.rs
index b6799db..6ba7e76 100644
--- a/src/utils/config.rs
+++ b/src/core/config.rs
@@ -1,12 +1,11 @@
use aksr::Builder;
use crate::{
- impl_ort_config_methods, impl_processor_config_methods,
models::{SamKind, YOLOPredsFormat},
ORTConfig, ProcessorConfig, Scale, Task, Version,
};
-/// Config for building models and inference
+/// Configuration for model inference including engines, processors, and task settings.
#[derive(Builder, Debug, Clone)]
pub struct Config {
// Basics
diff --git a/src/io/dataloader.rs b/src/core/dataloader.rs
similarity index 95%
rename from src/io/dataloader.rs
rename to src/core/dataloader.rs
index 4292dcf..0e31e18 100644
--- a/src/io/dataloader.rs
+++ b/src/core/dataloader.rs
@@ -550,9 +550,19 @@ trait DataLoaderIterator {
}
}
+/// An iterator implementation for `DataLoader` that enables batch processing of images.
+///
+/// This struct is created by the `into_iter` method on `DataLoader`.
+/// It provides functionality for:
+/// - Receiving batches of images through a channel
+/// - Tracking progress with an optional progress bar
+/// - Processing images in configurable batch sizes
pub struct DataLoaderIntoIterator {
+ /// Channel receiver for getting batches of images
receiver: mpsc::Receiver<Vec<Image>>,
+ /// Optional progress bar for tracking iteration progress
progress_bar: Option<ProgressBar>,
+ /// Number of images to process in each batch
batch_size: u64,
}
@@ -593,6 +603,15 @@ impl IntoIterator for DataLoader {
}
}
+/// A borrowing iterator for `DataLoader` that enables batch processing of images.
+///
+/// This iterator is created by the `iter()` method on `DataLoader`, allowing iteration
+/// over batches of images without taking ownership of the `DataLoader`.
+///
+/// # Fields
+/// - `receiver`: A reference to the channel receiver that provides batches of images
+/// - `progress_bar`: An optional reference to a progress bar for tracking iteration progress
+/// - `batch_size`: The number of images to process in each batch
pub struct DataLoaderIter<'a> {
receiver: &'a mpsc::Receiver<Vec<Image>>,
progress_bar: Option<&'a ProgressBar>,
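For context, this borrowing iterator is what the README example (removed above in this diff) drives; a minimal sketch of that usage, assuming each batch arrives as a `Vec<Image>` as the field docs state:

```rust
use usls::DataLoader;

fn main() -> anyhow::Result<()> {
    // Build a loader over a folder and pull batches of two images at a time.
    let dl = DataLoader::new("./assets")?.with_batch(2).build()?;
    for images in dl.iter() {
        // Each `images` is one batch delivered over the internal channel.
        println!("batch of {} image(s)", images.len());
    }
    Ok(())
}
```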
diff --git a/src/utils/device.rs b/src/core/device.rs
similarity index 99%
rename from src/utils/device.rs
rename to src/core/device.rs
index 5743193..4f2accf 100644
--- a/src/utils/device.rs
+++ b/src/core/device.rs
@@ -1,4 +1,5 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+/// Device types for model execution.
pub enum Device {
Cpu(usize),
Cuda(usize),
diff --git a/src/io/dir.rs b/src/core/dir.rs
similarity index 100%
rename from src/io/dir.rs
rename to src/core/dir.rs
diff --git a/src/utils/dtype.rs b/src/core/dtype.rs
similarity index 98%
rename from src/utils/dtype.rs
rename to src/core/dtype.rs
index ef73223..e195806 100644
--- a/src/utils/dtype.rs
+++ b/src/core/dtype.rs
@@ -1,4 +1,5 @@
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+/// Data type enumeration for tensor elements.
pub enum DType {
#[default]
Auto,
diff --git a/src/utils/dynconf.rs b/src/core/dynconf.rs
similarity index 100%
rename from src/utils/dynconf.rs
rename to src/core/dynconf.rs
diff --git a/src/inference/engine.rs b/src/core/engine.rs
similarity index 98%
rename from src/inference/engine.rs
rename to src/core/engine.rs
index 6206f52..f11409a 100644
--- a/src/inference/engine.rs
+++ b/src/core/engine.rs
@@ -46,25 +46,39 @@ impl From for DType {
}
/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions.
+/// ONNX Runtime tensor attributes containing names, data types, and dimensions.
#[derive(Builder, Debug, Clone, Default)]
pub struct OrtTensorAttr {
+ /// Tensor names.
pub names: Vec<String>,
+ /// Tensor data types.
pub dtypes: Vec,
+ /// Tensor dimensions for each tensor.
pub dimss: Vec<Vec<usize>>,
}
+/// ONNX I/O structure containing input/output attributes and session.
#[derive(Debug)]
pub struct OnnxIo {
+ /// Input tensor attributes.
pub inputs: OrtTensorAttr,
+ /// Output tensor attributes.
pub outputs: OrtTensorAttr,
+ /// ONNX Runtime session.
pub session: Session,
+ /// ONNX model protocol buffer.
pub proto: onnx::ModelProto,
}
+/// ONNX Runtime inference engine with configuration and session management.
#[derive(Debug, Builder)]
pub struct Engine {
+ /// Model file path.
pub file: String,
+ /// Model specification string.
pub spec: String,
+ /// Execution device.
pub device: Device,
#[args(inc)]
pub iiixs: Vec<Iiix>,
@@ -72,9 +86,13 @@ pub struct Engine {
pub params: Option,
#[args(aka = "memory")]
pub wbmems: Option,
+ /// Input min-opt-max configurations.
pub inputs_minoptmax: Vec<Vec<MinOptMax>>,
+ /// ONNX I/O structure.
pub onnx: Option<OnnxIo>,
+ /// Timing statistics.
pub ts: Ts,
+ /// Number of dry runs for warmup.
pub num_dry_run: usize,
// global
@@ -158,7 +176,7 @@ impl Default for Engine {
// cann
cann_graph_inference: true,
cann_dump_graphs: false,
- cann_dump_om_model: false,
+ cann_dump_om_model: true,
// nnapi
nnapi_cpu_only: false,
nnapi_disable_cpu: false,
diff --git a/src/io/hub.rs b/src/core/hub.rs
similarity index 100%
rename from src/io/hub.rs
rename to src/core/hub.rs
diff --git a/src/utils/iiix.rs b/src/core/iiix.rs
similarity index 80%
rename from src/utils/iiix.rs
rename to src/core/iiix.rs
index 6db1626..8077032 100644
--- a/src/utils/iiix.rs
+++ b/src/core/iiix.rs
@@ -3,8 +3,11 @@ use crate::MinOptMax;
/// A struct for input composed of the i-th input, the ii-th dimension, and the value.
#[derive(Clone, Debug, Default)]
pub struct Iiix {
+ /// Input index.
pub i: usize,
+ /// Dimension index.
pub ii: usize,
+ /// Min-Opt-Max value specification.
pub x: MinOptMax,
}
diff --git a/src/inference/image.rs b/src/core/image.rs
similarity index 95%
rename from src/inference/image.rs
rename to src/core/image.rs
index 23b113f..cb47135 100644
--- a/src/inference/image.rs
+++ b/src/core/image.rs
@@ -9,6 +9,7 @@ use std::path::{Path, PathBuf};
use crate::{build_resizer_filter, Hub, Location, MediaType, X};
+/// Information about image transformation including source and destination dimensions.
#[derive(Builder, Debug, Clone, Default)]
pub struct ImageTransformInfo {
pub width_src: u32,
@@ -19,6 +20,7 @@ pub struct ImageTransformInfo {
pub width_scale: f32,
}
+/// Image resize modes for different scaling strategies.
#[derive(Debug, Clone, Default)]
pub enum ResizeMode {
/// StretchToFit
@@ -30,6 +32,7 @@ pub enum ResizeMode {
Letterbox,
}
+/// Image wrapper with metadata and transformation capabilities.
#[derive(Builder, Clone)]
pub struct Image {
image: RgbImage,
@@ -376,8 +379,13 @@ impl Image {
}
}
+/// Extension trait for converting between vectors of different image types.
+/// Provides methods to convert between `Vec<Image>` and `Vec<DynamicImage>`.
pub trait ImageVecExt {
+ /// Converts the vector into a vector of `DynamicImage`s.
fn into_dyns(self) -> Vec<DynamicImage>;
+
+ /// Converts the vector into a vector of `Image`s.
fn into_images(self) -> Vec<Image>;
}
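A short usage sketch, assuming the trait is implemented for both `Vec<Image>` and `Vec<DynamicImage>` (as the doc comment says) and that `DataLoader::try_read_n` returns `Vec<Image>`:

```rust
use image::DynamicImage;
use usls::{DataLoader, Image, ImageVecExt};

fn round_trip() -> anyhow::Result<()> {
    let images: Vec<Image> = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
    // Hand images to code that expects the `image` crate's type...
    let dyns: Vec<DynamicImage> = images.into_dyns();
    // ...and convert back when usls types are needed again.
    let _images: Vec<Image> = dyns.into_images();
    Ok(())
}
```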
diff --git a/src/utils/logits_sampler.rs b/src/core/logits_sampler.rs
similarity index 93%
rename from src/utils/logits_sampler.rs
rename to src/core/logits_sampler.rs
index a95435d..947349f 100644
--- a/src/utils/logits_sampler.rs
+++ b/src/core/logits_sampler.rs
@@ -1,9 +1,12 @@
use anyhow::Result;
use rand::distr::{weighted::WeightedIndex, Distribution};
+/// Logits sampler for text generation with temperature and nucleus sampling.
#[derive(Debug, Clone)]
pub struct LogitsSampler {
+ /// Temperature parameter for controlling randomness in sampling.
temperature: f32,
+ /// Top-p parameter for nucleus sampling.
p: f32,
}
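The two fields correspond to the usual temperature plus nucleus (top-p) recipe. The following is a dependency-free illustration of that recipe, not the crate's actual `LogitsSampler` implementation:

```rust
// Illustration of temperature + top-p (nucleus) sampling: scale logits by
// 1/temperature, softmax, then keep the smallest high-probability prefix
// whose cumulative mass reaches p.
fn top_p_candidates(logits: &[f32], temperature: f32, p: f32) -> Vec<(usize, f32)> {
    // Temperature-scaled softmax (subtract the max for numerical stability).
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / temperature).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut probs: Vec<(usize, f32)> = exps.iter().map(|&e| e / sum).enumerate().collect();

    // Sort by probability (descending) and keep the nucleus.
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let mut kept = Vec::new();
    let mut cumulative = 0.0;
    for (i, pr) in probs {
        kept.push((i, pr));
        cumulative += pr;
        if cumulative >= p {
            break;
        }
    }
    // Renormalize the kept candidates so they can be sampled from directly.
    let z: f32 = kept.iter().map(|&(_, pr)| pr).sum();
    kept.into_iter().map(|(i, pr)| (i, pr / z)).collect()
}
```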
diff --git a/src/io/media.rs b/src/core/media.rs
similarity index 96%
rename from src/io/media.rs
rename to src/core/media.rs
index cbbb6fe..7322551 100644
--- a/src/io/media.rs
+++ b/src/core/media.rs
@@ -11,6 +11,7 @@ pub(crate) const STREAM_PROTOCOLS: &[&str] = &[
"rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://",
];
+/// Media location type indicating local or remote source.
#[derive(Debug, Clone, Default, Copy)]
pub enum Location {
#[default]
@@ -18,6 +19,7 @@ pub enum Location {
Remote,
}
+/// Stream type for media content.
#[derive(Debug, Clone, Copy, Default)]
pub enum StreamType {
#[default]
@@ -25,6 +27,7 @@ pub enum StreamType {
Live,
}
+/// Media type classification for different content formats.
#[derive(Debug, Clone, Copy, Default)]
pub enum MediaType {
#[default]
diff --git a/src/utils/min_opt_max.rs b/src/core/min_opt_max.rs
similarity index 100%
rename from src/utils/min_opt_max.rs
rename to src/core/min_opt_max.rs
diff --git a/src/core/mod.rs b/src/core/mod.rs
new file mode 100644
index 0000000..a4d739e
--- /dev/null
+++ b/src/core/mod.rs
@@ -0,0 +1,61 @@
+#[macro_use]
+mod ort_config;
+#[macro_use]
+mod processor_config;
+mod config;
+mod dataloader;
+mod device;
+mod dir;
+mod dtype;
+mod dynconf;
+#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
+mod engine;
+mod hub;
+mod iiix;
+mod image;
+mod logits_sampler;
+mod media;
+mod min_opt_max;
+mod names;
+#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
+#[allow(clippy::all)]
+pub(crate) mod onnx;
+mod ops;
+mod processor;
+mod retry;
+mod scale;
+mod task;
+mod traits;
+mod ts;
+mod utils;
+mod version;
+mod x;
+mod xs;
+
+pub use config::*;
+pub use dataloader::*;
+pub use device::Device;
+pub use dir::*;
+pub use dtype::DType;
+pub use dynconf::DynConf;
+#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
+pub use engine::*;
+pub use hub::*;
+pub(crate) use iiix::Iiix;
+pub use image::*;
+pub use logits_sampler::LogitsSampler;
+pub use media::*;
+pub use min_opt_max::MinOptMax;
+pub use names::*;
+pub use ops::*;
+pub use ort_config::ORTConfig;
+pub use processor::*;
+pub use processor_config::ProcessorConfig;
+pub use scale::Scale;
+pub use task::Task;
+pub use traits::*;
+pub use ts::Ts;
+pub use utils::*;
+pub use version::Version;
+pub use x::X;
+pub use xs::Xs;
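Since the moved modules are re-exported wholesale here, downstream code importing flat names from the crate root should be unaffected by the `src/utils`/`src/io`/`src/inference` to `src/core` reshuffle; a sketch of what keeps compiling (assuming `lib.rs` re-exports this module at the crate root, as the examples' `use usls::{...}` lines in this diff suggest):

```rust
// Flat imports keep working thanks to the `pub use` re-exports above.
use usls::{Annotator, Config, DType, DataLoader, Device};
```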
diff --git a/src/utils/names.rs b/src/core/names.rs
similarity index 96%
rename from src/utils/names.rs
rename to src/core/names.rs
index e4affef..cf25bbc 100644
--- a/src/utils/names.rs
+++ b/src/core/names.rs
@@ -1,3 +1,5 @@
+/// A comprehensive list of 4585 object categories used in the YOLOE object detection model.
pub static NAMES_YOLOE_4585: [&str; 4585] = [
"3D CG rendering",
"3D glasses",
@@ -4586,6 +4588,13 @@ pub static NAMES_YOLOE_4585: [&str; 4585] = [
"zoo",
];
+/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.5 with 16 categories.
+///
+/// DOTA is a large-scale dataset for object detection in aerial images. This version includes:
+/// - Transportation objects (planes, ships, vehicles)
+/// - Sports facilities (courts, fields)
+/// - Infrastructure (bridges, harbors, storage tanks)
+/// - Other aerial view objects
pub const NAMES_DOTA_V1_5_16: [&str; 16] = [
"plane",
"ship",
@@ -4604,6 +4613,13 @@ pub const NAMES_DOTA_V1_5_16: [&str; 16] = [
"swimming pool",
"container crane",
];
+/// Labels for DOTA (Dataset for Object deTection in Aerial images) v1.0 with 15 categories.
+///
+/// Similar to DOTA v1.5 but excludes the "container crane" category. Includes:
+/// - Transportation objects
+/// - Sports facilities
+/// - Infrastructure
+/// - Other aerial view objects
pub const NAMES_DOTA_V1_15: [&str; 15] = [
"plane",
"ship",
@@ -4622,6 +4638,13 @@ pub const NAMES_DOTA_V1_15: [&str; 15] = [
"swimming pool",
];
+/// Labels for document layout analysis using YOLO with 10 categories.
+///
+/// These labels are used to identify different components in document layout analysis:
+/// - Text elements (title, plain text)
+/// - Visual elements (figures and their captions)
+/// - Tabular elements (tables, captions, footnotes)
+/// - Mathematical elements (formulas and captions)
pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [
"title",
"plain text",
@@ -4634,6 +4657,14 @@ pub const NAMES_YOLO_DOCLAYOUT_10: [&str; 10] = [
"isolate_formula",
"formula_caption",
];
+/// Labels for PicoDet document layout analysis with 17 categories.
+///
+/// A comprehensive set of labels for detailed document structure analysis, including:
+/// - Textual elements (titles, paragraphs, abstracts)
+/// - Visual elements (images, figures)
+/// - Structural elements (headers, footers)
+/// - Reference elements (citations, footnotes)
+/// - Special elements (algorithms, seals)
pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [
"paragraph_title",
"image",
@@ -4653,8 +4684,26 @@ pub const NAMES_PICODET_LAYOUT_17: [&str; 17] = [
"footer",
"seal",
];
+/// Simplified PicoDet document layout labels with 3 basic categories.
+///
+/// A minimal set of labels for basic document layout analysis:
+/// - Images
+/// - Tables
+/// - Seals (official stamps or marks)
pub const NAMES_PICODET_LAYOUT_3: [&str; 3] = ["image", "table", "seal"];
+/// Core PicoDet document layout labels with 5 essential categories.
+///
+/// A balanced set of labels for common document layout analysis tasks:
+/// - Textual content (Text, Title, List)
+/// - Visual content (Figure)
+/// - Structured content (Table)
pub const NAMES_PICODET_LAYOUT_5: [&str; 5] = ["Text", "Title", "List", "Table", "Figure"];
+/// COCO dataset keypoint labels for human pose estimation with 17 points.
+///
+/// Standard keypoints for human body pose detection including:
+/// - Facial features (eyes, nose, ears)
+/// - Upper body joints (shoulders, elbows, wrists)
+/// - Lower body joints (hips, knees, ankles)
pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [
"nose",
"left_eye",
@@ -4675,6 +4724,15 @@ pub const NAMES_COCO_KEYPOINTS_17: [&str; 17] = [
"right_ankle",
];
+/// COCO dataset object detection labels with 80 categories.
+///
+/// A widely-used set of object categories for general object detection, including:
+/// - Living beings (person, animals)
+/// - Vehicles and transportation
+/// - Common objects and furniture
+/// - Food and kitchen items
+/// - Sports and recreational equipment
+/// - Electronics and appliances
pub const NAMES_COCO_80: [&str; 80] = [
"person",
"bicycle",
@@ -4758,6 +4816,14 @@ pub const NAMES_COCO_80: [&str; 80] = [
"toothbrush",
];
+/// Extended COCO dataset labels with 91 categories including background and unused slots.
+///
+/// An extended version of COCO labels that includes:
+/// - Background class (index 0)
+/// - All standard COCO categories
+/// - Reserved "unused" slots for dataset compatibility
+///
+/// Note: Indices are preserved for compatibility with pre-trained models, hence the "unused" entries.
pub const NAMES_COCO_91: [&str; 91] = [
"background", // 0
"person", // 1
@@ -4852,6 +4918,13 @@ pub const NAMES_COCO_91: [&str; 91] = [
"toothbrush", // 90
];
+/// Human body parts segmentation labels with 28 categories.
+///
+/// Detailed categorization of human body parts for segmentation tasks:
+/// - Basic body parts (face, neck, hair, torso)
+/// - Left/Right symmetrical parts (hands, arms, legs, feet)
+/// - Clothing regions (apparel, upper/lower clothing)
+/// - Facial features (lips, teeth, tongue)
pub const NAMES_BODY_PARTS_28: [&str; 28] = [
"Background",
"Apparel",
@@ -4883,6 +4956,16 @@ pub const NAMES_BODY_PARTS_28: [&str; 28] = [
"Tongue",
];
+/// ImageNet ILSVRC 1000-class classification labels.
+///
+/// The standard ImageNet 1K classification dataset includes:
+/// - Animals (mammals, birds, fish, insects)
+/// - Objects (tools, vehicles, furniture)
+/// - Foods and plants
+/// - Natural scenes
+/// - Man-made structures
+///
+/// Note: Labels include both common names and scientific nomenclature where applicable.
pub const NAMES_IMAGENET_1K: [&str; 1000] = [
"tench, Tinca tinca",
"goldfish, Carassius auratus",
diff --git a/src/core/onnx.rs b/src/core/onnx.rs
new file mode 100644
index 0000000..1eb6e5e
--- /dev/null
+++ b/src/core/onnx.rs
@@ -0,0 +1,1210 @@
+// This file is @generated by prost-build.
+/// Attributes
+///
+/// A named attribute containing either singular float, integer, string, graph,
+/// and tensor values, or repeated float, integer, string, graph, and tensor values.
+/// An AttributeProto MUST contain the name field, and *only one* of the
+/// following content fields, effectively enforcing a C/C++ union equivalent.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct AttributeProto {
+ /// The name field MUST be present for this version of the IR.
+ ///
+ /// namespace Attribute
+ #[prost(string, tag = "1")]
+ pub name: ::prost::alloc::string::String,
+ /// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+ /// In this case, this AttributeProto does not contain data, and it's a reference of attribute
+ /// in parent scope.
+ /// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+ #[prost(string, tag = "21")]
+ pub ref_attr_name: ::prost::alloc::string::String,
+ /// A human-readable documentation for this attribute. Markdown is allowed.
+ #[prost(string, tag = "13")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// The type field MUST be present for this version of the IR.
+ /// For 0.0.1 versions of the IR, this field was not defined, and
+ /// implementations needed to use has_field heuristics to determine
+ /// which value field was in use. For IR_VERSION 0.0.2 or later, this
+ /// field MUST be set and match the f|i|s|t|... field in use. This
+ /// change was made to accommodate proto3 implementations.
+ ///
+ /// discriminator that indicates which field below is in use
+ #[prost(enumeration = "attribute_proto::AttributeType", tag = "20")]
+ pub r#type: i32,
+ /// Exactly ONE of the following fields must be present for this version of the IR
+ ///
+ /// float
+ #[prost(float, tag = "2")]
+ pub f: f32,
+ /// int
+ #[prost(int64, tag = "3")]
+ pub i: i64,
+ /// UTF-8 string
+ #[prost(bytes = "vec", tag = "4")]
+ pub s: ::prost::alloc::vec::Vec<u8>,
+ /// tensor value
+ #[prost(message, optional, tag = "5")]
+ pub t: ::core::option::Option<TensorProto>,
+ /// graph
+ #[prost(message, optional, tag = "6")]
+ pub g: ::core::option::Option<GraphProto>,
+ /// sparse tensor value
+ #[prost(message, optional, tag = "22")]
+ pub sparse_tensor: ::core::option::Option<SparseTensorProto>,
+ /// Do not use field below, it's deprecated.
+ /// optional ValueProto v = 12; // value - subsumes everything but graph
+ ///
+ /// type proto
+ #[prost(message, optional, tag = "14")]
+ pub tp: ::core::option::Option<TypeProto>,
+ /// list of floats
+ #[prost(float, repeated, tag = "7")]
+ pub floats: ::prost::alloc::vec::Vec<f32>,
+ /// list of ints
+ #[prost(int64, repeated, tag = "8")]
+ pub ints: ::prost::alloc::vec::Vec<i64>,
+ /// list of UTF-8 strings
+ #[prost(bytes = "vec", repeated, tag = "9")]
+ pub strings: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>,
+ /// list of tensors
+ #[prost(message, repeated, tag = "10")]
+ pub tensors: ::prost::alloc::vec::Vec<TensorProto>,
+ /// list of graph
+ #[prost(message, repeated, tag = "11")]
+ pub graphs: ::prost::alloc::vec::Vec<GraphProto>,
+ /// list of sparse tensors
+ #[prost(message, repeated, tag = "23")]
+ pub sparse_tensors: ::prost::alloc::vec::Vec<SparseTensorProto>,
+ /// list of type protos
+ #[prost(message, repeated, tag = "15")]
+ pub type_protos: ::prost::alloc::vec::Vec<TypeProto>,
+}
+/// Nested message and enum types in `AttributeProto`.
+pub mod attribute_proto {
+ /// Note: this enum is structurally identical to the OpSchema::AttrType
+ /// enum defined in schema.h. If you rev one, you likely need to rev the other.
+ #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+ #[repr(i32)]
+ pub enum AttributeType {
+ Undefined = 0,
+ Float = 1,
+ Int = 2,
+ String = 3,
+ Tensor = 4,
+ Graph = 5,
+ SparseTensor = 11,
+ TypeProto = 13,
+ Floats = 6,
+ Ints = 7,
+ Strings = 8,
+ Tensors = 9,
+ Graphs = 10,
+ SparseTensors = 12,
+ TypeProtos = 14,
+ }
+ impl AttributeType {
+ /// String value of the enum field names used in the ProtoBuf definition.
+ ///
+ /// The values are not transformed in any way and thus are considered stable
+ /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+ pub fn as_str_name(&self) -> &'static str {
+ match self {
+ Self::Undefined => "UNDEFINED",
+ Self::Float => "FLOAT",
+ Self::Int => "INT",
+ Self::String => "STRING",
+ Self::Tensor => "TENSOR",
+ Self::Graph => "GRAPH",
+ Self::SparseTensor => "SPARSE_TENSOR",
+ Self::TypeProto => "TYPE_PROTO",
+ Self::Floats => "FLOATS",
+ Self::Ints => "INTS",
+ Self::Strings => "STRINGS",
+ Self::Tensors => "TENSORS",
+ Self::Graphs => "GRAPHS",
+ Self::SparseTensors => "SPARSE_TENSORS",
+ Self::TypeProtos => "TYPE_PROTOS",
+ }
+ }
+ /// Creates an enum from field names used in the ProtoBuf definition.
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+ match value {
+ "UNDEFINED" => Some(Self::Undefined),
+ "FLOAT" => Some(Self::Float),
+ "INT" => Some(Self::Int),
+ "STRING" => Some(Self::String),
+ "TENSOR" => Some(Self::Tensor),
+ "GRAPH" => Some(Self::Graph),
+ "SPARSE_TENSOR" => Some(Self::SparseTensor),
+ "TYPE_PROTO" => Some(Self::TypeProto),
+ "FLOATS" => Some(Self::Floats),
+ "INTS" => Some(Self::Ints),
+ "STRINGS" => Some(Self::Strings),
+ "TENSORS" => Some(Self::Tensors),
+ "GRAPHS" => Some(Self::Graphs),
+ "SPARSE_TENSORS" => Some(Self::SparseTensors),
+ "TYPE_PROTOS" => Some(Self::TypeProtos),
+ _ => None,
+ }
+ }
+ }
+}
+/// Defines information on value, including the name, the type, and
+/// the shape of the value.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ValueInfoProto {
+ /// This field MUST be present in this version of the IR.
+ ///
+ /// namespace Value
+ #[prost(string, tag = "1")]
+ pub name: ::prost::alloc::string::String,
+ /// This field MUST be present in this version of the IR for
+ /// inputs and outputs of the top-level graph.
+ #[prost(message, optional, tag = "2")]
+ pub r#type: ::core::option::Option<TypeProto>,
+ /// A human-readable documentation for this value. Markdown is allowed.
+ #[prost(string, tag = "3")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// Named metadata values; keys should be distinct.
+ #[prost(message, repeated, tag = "4")]
+ pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Nodes
+///
+/// Computation graphs are made up of a DAG of nodes, which represent what is
+/// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+///
+/// For example, it can be a node of type "Conv" that takes in an image, a filter
+/// tensor and a bias tensor, and produces the convolved output.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NodeProto {
+ /// namespace Value
+ #[prost(string, repeated, tag = "1")]
+ pub input: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+ /// namespace Value
+ #[prost(string, repeated, tag = "2")]
+ pub output: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+ /// An optional identifier for this node in a graph.
+ /// This field MAY be absent in this version of the IR.
+ ///
+ /// namespace Node
+ #[prost(string, tag = "3")]
+ pub name: ::prost::alloc::string::String,
+ /// The symbolic identifier of the Operator to execute.
+ ///
+ /// namespace Operator
+ #[prost(string, tag = "4")]
+ pub op_type: ::prost::alloc::string::String,
+ /// The domain of the OperatorSet that specifies the operator named by op_type.
+ ///
+ /// namespace Domain
+ #[prost(string, tag = "7")]
+ pub domain: ::prost::alloc::string::String,
+ /// Overload identifier, used only to map this to a model-local function.
+ #[prost(string, tag = "8")]
+ pub overload: ::prost::alloc::string::String,
+ /// Additional named attributes.
+ #[prost(message, repeated, tag = "5")]
+ pub attribute: ::prost::alloc::vec::Vec<AttributeProto>,
+ /// A human-readable documentation for this node. Markdown is allowed.
+ #[prost(string, tag = "6")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// Named metadata values; keys should be distinct.
+ #[prost(message, repeated, tag = "9")]
+ pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+ /// Configuration of multi-device annotations.
+ #[prost(message, repeated, tag = "10")]
+ pub device_configurations: ::prost::alloc::vec::Vec<NodeDeviceConfigurationProto>,
+}
+/// IntIntListEntryProto follows the pattern for cross-proto-version maps.
+/// See
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct IntIntListEntryProto {
+ #[prost(int64, tag = "1")]
+ pub key: i64,
+ #[prost(int64, repeated, tag = "2")]
+ pub value: ::prost::alloc::vec::Vec<i64>,
+}
+/// Multi-device configuration proto for NodeProto.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NodeDeviceConfigurationProto {
+ /// This field MUST be present for this version of the IR.
+ /// ID of the configuration. MUST match the name of a DeviceConfigurationProto.
+ #[prost(string, tag = "1")]
+ pub configuration_id: ::prost::alloc::string::String,
+ /// Sharding spec for the node.
+ #[prost(message, repeated, tag = "2")]
+ pub sharding_spec: ::prost::alloc::vec::Vec<ShardingSpecProto>,
+ /// Pipeline stage of this node.
+ #[prost(int32, tag = "3")]
+ pub pipeline_stage: i32,
+}
+/// ShardingSpecProto: This describes the sharding spec for a specific
+/// input or output tensor of a node.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ShardingSpecProto {
+ /// This field MUST be present for this version of the IR.
+ /// Identifies the input or output of the node that is being sharded.
+ /// Required to match a name specified in the node's input or output list of ValueInfoProtos.
+ /// It is called `logical tensor` in subsequent descriptions.
+ #[prost(string, tag = "1")]
+ pub tensor_name: ::prost::alloc::string::String,
+ /// The following is the list of devices across which the logical
+ /// tensor is sharded or replicated.
+ #[prost(int64, repeated, tag = "2")]
+ pub device: ::prost::alloc::vec::Vec<i64>,
+ /// Each element v in above field devices may represent either a
+ /// device or a set of devices (when we want the same shard/tensor
+ /// to be replicated across a subset of devices), as indicated by
+ /// the following optional map. If the map contains an entry for v,
+ /// then v represents a device group, and the map indicates the set
+ /// of devices in that group.
+ #[prost(message, repeated, tag = "3")]
+ pub index_to_device_group_map: ::prost::alloc::vec::Vec<IntIntListEntryProto>,
+ /// The following is the sharded-shape of the tensor, consisting of
+ /// the sharding-spec for each axis of the tensor.
+ #[prost(message, repeated, tag = "4")]
+ pub sharded_dim: ::prost::alloc::vec::Vec<ShardedDimProto>,
+}
+/// ShardedDimProto: This describes the sharding spec for a single
+/// axis of a sharded tensor.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ShardedDimProto {
+ /// This field MUST be present for this version of the IR.
+ /// The axis this sharding corresponds to. Must be in the range of
+ /// \[-r, r - 1\], where r is the rank of the tensor. Negative axis values means
+ /// counting from the back.
+ #[prost(int64, tag = "1")]
+ pub axis: i64,
+ /// Describes how the tensor on the provided axis is sharded.
+ /// The common-case is described by a single instance of SimpleShardedDimProto.
+ /// Multiple instances can be used to handle cases where a sharded
+ /// tensor is reshaped, fusing multiple axes into one.
+ #[prost(message, repeated, tag = "2")]
+ pub simple_sharding: ::prost::alloc::vec::Vec<SimpleShardedDimProto>,
+}
+/// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
+/// N is allowed to be symbolic where M is required to be a constant.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SimpleShardedDimProto {
+ /// This field MUST be present for this version of the IR.
+ /// Number of shards to split dim into.
+ #[prost(int64, tag = "3")]
+ pub num_shards: i64,
+ /// Dimension value to be sharded.
+ #[prost(oneof = "simple_sharded_dim_proto::Dim", tags = "1, 2")]
+ pub dim: ::core::option::Option<simple_sharded_dim_proto::Dim>,
+}
+/// Nested message and enum types in `SimpleShardedDimProto`.
+pub mod simple_sharded_dim_proto {
+ /// Dimension value to be sharded.
+ #[derive(Clone, PartialEq, ::prost::Oneof)]
+ pub enum Dim {
+ #[prost(int64, tag = "1")]
+ DimValue(i64),
+ #[prost(string, tag = "2")]
+ DimParam(::prost::alloc::string::String),
+ }
+}
+/// Training information
+/// TrainingInfoProto stores information for training a model.
+/// In particular, this defines two functionalities: an initialization-step
+/// and a training-algorithm-step. Initialization resets the model
+/// back to its original state as if no training has been performed.
+/// Training algorithm improves the model based on input data.
+///
+/// The semantics of the initialization-step is that the initializers
+/// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+/// initialized as specified by the initializers in the graph, and then
+/// updated by the "initialization_binding" in every instance in
+/// ModelProto.training_info.
+///
+/// The field "algorithm" defines a computation graph which represents a
+/// training algorithm's step. After the execution of a
+/// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+/// may be immediately updated. If the targeted training algorithm contains
+/// consecutive update steps (such as block coordinate descent methods),
+/// the user needs to create a TrainingInfoProto for each step.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TrainingInfoProto {
+ /// This field describes a graph to compute the initial tensors
+ /// upon starting the training process. Initialization graph has no input
+ /// and can have multiple outputs. Usually, trainable tensors in neural
+ /// networks are randomly initialized. To achieve that, for each tensor,
+ /// the user can put a random number operator such as RandomNormal or
+ /// RandomUniform in TrainingInfoProto.initialization.node and assign its
+ /// random output to the specific tensor using "initialization_binding".
+ /// This graph can also set the initializers in "algorithm" in the same
+ /// TrainingInfoProto; a use case is resetting the number of training
+ /// iterations to zero.
+ ///
+ /// By default, this field is an empty graph and its evaluation does not
+ /// produce any output. Thus, no initializer would be changed by default.
+ #[prost(message, optional, tag = "1")]
+ pub initialization: ::core::option::Option<GraphProto>,
+ /// This field represents a training algorithm step. Given required inputs,
+ /// it computes outputs to update initializers in its own or inference graph's
+ /// initializer lists. In general, this field contains the loss node, gradient node,
+ /// optimizer node, and an increment of the iteration count.
+ ///
+ /// An execution of the training algorithm step is performed by executing the
+ /// graph obtained by combining the inference graph (namely "ModelProto.graph")
+ /// and the "algorithm" graph. That is, the actual
+ /// input/initializer/output/node/value_info/sparse_initializer list of
+ /// the training graph is the concatenation of
+ /// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ /// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ /// in that order. This combined graph must satisfy the normal ONNX conditions.
+ /// Now, let's provide a visualization of graph combination for clarity.
+ /// Let the inference graph (i.e., "ModelProto.graph") be
+ /// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ /// and the "algorithm" graph be
+ /// tensor_d -> Add -> tensor_e
+ /// The combination process results in
+ /// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ ///
+ /// Notice that an input of a node in the "algorithm" graph may reference the
+ /// output of a node in the inference graph (but not the other way round). Also, an inference
+ /// node cannot reference inputs of "algorithm". With these restrictions, the inference graph
+ /// can always be run independently, without any training information.
+ ///
+ /// By default, this field is an empty graph and its evaluation does not
+ /// produce any output. Evaluating the default training step never
+ /// updates any initializers.
+ #[prost(message, optional, tag = "2")]
+ pub algorithm: ::core::option::Option<GraphProto>,
+ /// This field specifies the bindings from the outputs of "initialization" to
+ /// some initializers in "ModelProto.graph.initializer" and
+ /// the "algorithm.initializer" in the same TrainingInfoProto.
+ /// See "update_binding" below for details.
+ ///
+ /// By default, this field is empty and no initializer would be changed
+ /// by the execution of "initialization".
+ #[prost(message, repeated, tag = "3")]
+ pub initialization_binding: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+ /// Gradient-based training is usually an iterative procedure. In one gradient
+ /// descent iteration, we apply
+ ///
+ /// x = x - r * g
+ ///
+ /// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ /// gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ /// into the training graph, we split the update equation into
+ ///
+ /// y = x - r * g
+ /// x = y
+ ///
+ /// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ /// tell that "y" should be assigned to "x", the field "update_binding" may
+ /// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ /// and "y" (value of StringStringEntryProto).
+ /// For a neural network with multiple trainable (mutable) tensors, there can
+ /// be multiple key-value pairs in "update_binding".
+ ///
+ /// The initializers that appear as keys in "update_binding" are considered
+ /// mutable variables. This implies the behaviors
+ /// described below.
+ ///
+ /// 1. We have only unique keys in all "update_binding"s so that two
+ /// variables may not have the same name. This ensures that one
+ /// variable is assigned up to once.
+ /// 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ /// "TrainingInfoProto.algorithm.initializer".
+ /// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ /// 4. Mutable variables are initialized to the value specified by the
+ /// corresponding initializer, and then potentially updated by
+ /// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ ///
+ /// This field usually contains names of trainable tensors
+ /// (in ModelProto.graph), optimizer states such as momentums in advanced
+ /// stochastic gradient methods (in TrainingInfoProto.graph),
+ /// and number of training iterations (in TrainingInfoProto.graph).
+ ///
+ /// By default, this field is empty and no initializer would be changed
+ /// by the execution of "algorithm".
+ #[prost(message, repeated, tag = "4")]
+ pub update_binding: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
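+// Illustrative sketch (hand-written, not generated): with the update rule
+// `y = x - r * g` stored in `algorithm`, assigning `y` back to the mutable
+// initializer `x` is expressed as a single `update_binding` entry, e.g.
+//
+//     TrainingInfoProto {
+//         update_binding: vec![StringStringEntryProto {
+//             key: "x".to_string(),   // initializer to overwrite
+//             value: "y".to_string(), // output of the "algorithm" graph
+//         }],
+//         ..Default::default()
+//     }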
+/// Models
+///
+/// ModelProto is a top-level file/container format for bundling a ML model and
+/// associating its computation graph with metadata.
+///
+/// The semantics of the model are described by the associated GraphProto's.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ModelProto {
+ /// The version of the IR this model targets. See Version enum above.
+ /// This field MUST be present.
+ #[prost(int64, tag = "1")]
+ pub ir_version: i64,
+ /// The OperatorSets this model relies on.
+ /// All ModelProtos MUST have at least one entry that
+ /// specifies which version of the ONNX OperatorSet is
+ /// being imported.
+ ///
+ /// All nodes in the ModelProto's graph will bind against the operator
+ /// with the same-domain/same-op_type operator with the HIGHEST version
+ /// in the referenced operator sets.
+ #[prost(message, repeated, tag = "8")]
+ pub opset_import: ::prost::alloc::vec::Vec<OperatorSetIdProto>,
+ /// The name of the framework or tool used to generate this model.
+ /// This field SHOULD be present to indicate which implementation/tool/framework
+ /// emitted the model.
+ #[prost(string, tag = "2")]
+ pub producer_name: ::prost::alloc::string::String,
+ /// The version of the framework or tool used to generate this model.
+ /// This field SHOULD be present to indicate which implementation/tool/framework
+ /// emitted the model.
+ #[prost(string, tag = "3")]
+ pub producer_version: ::prost::alloc::string::String,
+ /// Domain name of the model.
+ /// We use reverse domain names as name space indicators. For example:
+ /// `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ ///
+ /// Together with `model_version` and GraphProto.name, this forms the unique identity of
+ /// the graph.
+ #[prost(string, tag = "4")]
+ pub domain: ::prost::alloc::string::String,
+ /// The version of the graph encoded. See Version enum below.
+ #[prost(int64, tag = "5")]
+ pub model_version: i64,
+ /// A human-readable documentation for this model. Markdown is allowed.
+ #[prost(string, tag = "6")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// The parameterized graph that is evaluated to execute the model.
+ #[prost(message, optional, tag = "7")]
+ pub graph: ::core::option::Option<GraphProto>,
+ /// Named metadata values; keys should be distinct.
+ #[prost(message, repeated, tag = "14")]
+ pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+ /// Training-specific information. Sequentially executing all stored
+ /// `TrainingInfoProto.algorithm`s and assigning their outputs following
+ /// the corresponding `TrainingInfoProto.update_binding`s is one training
+ /// iteration. Similarly, to initialize the model
+ /// (as if training hasn't happened), the user should sequentially execute
+ /// all stored `TrainingInfoProto.initialization`s and assign their outputs
+ /// using `TrainingInfoProto.initialization_binding`s.
+ ///
+ /// If this field is empty, the training behavior of the model is undefined.
+ #[prost(message, repeated, tag = "20")]
+ pub training_info: ::prost::alloc::vec::Vec<TrainingInfoProto>,
+ /// A list of function protos local to the model.
+ ///
+ /// The (domain, name, overload) tuple must be unique across the function protos in this list.
+ /// In case of any conflict, the behavior (whether the model-local functions are given higher priority,
+ /// or the standard operator sets are given higher priority, or this is treated as an error) is defined by
+ /// the runtimes.
+ ///
+ /// The operator sets imported by FunctionProto should be compatible with the ones
+ /// imported by ModelProto and other model local FunctionProtos.
+ /// For example, if the same operator set, say 'A', is imported by a FunctionProto and the ModelProto,
+ /// or by two FunctionProtos, then the versions of that operator set may differ, but
+ /// the operator schema returned for the (op_type, domain, version) combination
+ /// must be the same for every node in the function body.
+ ///
+ /// A FunctionProto can reference other FunctionProtos in the model; however, recursive references
+ /// are not allowed.
+ #[prost(message, repeated, tag = "25")]
+ pub functions: ::prost::alloc::vec::Vec<FunctionProto>,
+ /// Describes different target configurations for a multi-device use case.
+ /// A model MAY describe multiple multi-device configurations for execution.
+ #[prost(message, repeated, tag = "26")]
+ pub configuration: ::prost::alloc::vec::Vec<DeviceConfigurationProto>,
+}
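+// Illustrative sketch (hand-written, not generated): the fields every model MUST
+// carry are `ir_version` and at least one `opset_import` entry; a minimal model
+// wrapping an existing `GraphProto` named `graph` might look like
+//
+//     ModelProto {
+//         ir_version: 10,
+//         opset_import: vec![OperatorSetIdProto {
+//             domain: String::new(), // "" selects the default ONNX operator set
+//             version: 21,
+//         }],
+//         graph: Some(graph),
+//         ..Default::default()
+//     }
+//
+// (the ir_version / opset version numbers above are examples, not requirements).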
+/// DeviceConfigurationProto describes a multi-device configuration for a model.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DeviceConfigurationProto {
+ /// This field MUST be present for this version of the IR.
+ /// Name of the configuration.
+ #[prost(string, tag = "1")]
+ pub name: ::prost::alloc::string::String,
+ /// This field MUST be present for this version of the IR.
+ /// Number of devices inside this configuration.
+ #[prost(int32, tag = "2")]
+ pub num_devices: i32,
+ /// Optional names of the devices. MUST be of length num_devices if provided.
+ #[prost(string, repeated, tag = "3")]
+ pub device: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+}
+/// StringStringEntryProto follows the pattern for cross-proto-version maps.
+/// See <https://developers.google.com/protocol-buffers/docs/proto3#maps>
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StringStringEntryProto {
+ #[prost(string, tag = "1")]
+ pub key: ::prost::alloc::string::String,
+ #[prost(string, tag = "2")]
+ pub value: ::prost::alloc::string::String,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorAnnotation {
+ #[prost(string, tag = "1")]
+ pub tensor_name: ::prost::alloc::string::String,
+ /// Key-value pairs used to annotate the tensor specified by `tensor_name` above.
+ /// The keys used in the mapping below must be pre-defined in the ONNX spec.
+ /// For example, for the 8-bit linear quantization case, 'SCALE_TENSOR' and 'ZERO_POINT_TENSOR'
+ /// will be pre-defined as quantization parameter keys.
+ #[prost(message, repeated, tag = "2")]
+ pub quant_parameter_tensor_names: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Graphs
+///
+/// A graph defines the computational logic of a model and is comprised of a parameterized
+/// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+/// This is the equivalent of the "network" or "graph" in many deep learning
+/// frameworks.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GraphProto {
+ /// The nodes in the graph, sorted topologically.
+ #[prost(message, repeated, tag = "1")]
+ pub node: ::prost::alloc::vec::Vec<NodeProto>,
+ /// The name of the graph.
+ ///
+ /// namespace Graph
+ #[prost(string, tag = "2")]
+ pub name: ::prost::alloc::string::String,
+ /// A list of named tensor values, used to specify constant inputs of the graph.
+ /// Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
+ /// The name MUST be unique across both initializer and sparse_initializer,
+ /// but the name MAY also appear in the input list.
+ #[prost(message, repeated, tag = "5")]
+ pub initializer: ::prost::alloc::vec::Vec<TensorProto>,
+ /// Initializers (see above) stored in sparse format.
+ #[prost(message, repeated, tag = "15")]
+ pub sparse_initializer: ::prost::alloc::vec::Vec<SparseTensorProto>,
+ /// A human-readable documentation for this graph. Markdown is allowed.
+ #[prost(string, tag = "10")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// The inputs and outputs of the graph.
+ #[prost(message, repeated, tag = "11")]
+ pub input: ::prost::alloc::vec::Vec<ValueInfoProto>,
+ #[prost(message, repeated, tag = "12")]
+ pub output: ::prost::alloc::vec::Vec<ValueInfoProto>,
+ /// Information for the values in the graph. The ValueInfoProto.name's
+ /// must be distinct. It is optional for a value to appear in the value_info list.
+ #[prost(message, repeated, tag = "13")]
+ pub value_info: ::prost::alloc::vec::Vec<ValueInfoProto>,
+ /// This field carries information to indicate the mapping among a tensor and its
+ /// quantization parameter tensors. For example:
+ /// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ /// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ #[prost(message, repeated, tag = "14")]
+ pub quantization_annotation: ::prost::alloc::vec::Vec<TensorAnnotation>,
+ /// Named metadata values; keys should be distinct.
+ #[prost(message, repeated, tag = "16")]
+ pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
+/// Tensors
+///
+/// A serialized tensor value.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorProto {
+ /// The shape of the tensor.
+ #[prost(int64, repeated, tag = "1")]
+ pub dims: ::prost::alloc::vec::Vec<i64>,
+ /// The data type of the tensor.
+ /// This field MUST have a valid TensorProto.DataType value
+ #[prost(int32, tag = "2")]
+ pub data_type: i32,
+ #[prost(message, optional, tag = "3")]
+ pub segment: ::core::option::Option<tensor_proto::Segment>,
+ /// For float and complex64 values
+ /// Complex64 tensors are encoded as a single array of floats,
+ /// with the real components appearing in odd numbered positions,
+ /// and the corresponding imaginary component appearing in the
+ /// subsequent even numbered position. (e.g., \[1.0 + 2.0i, 3.0 + 4.0i\]
+ /// is encoded as \[1.0, 2.0, 3.0, 4.0\]).
+ /// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ #[prost(float, repeated, tag = "4")]
+ pub float_data: ::prost::alloc::vec::Vec<f32>,
+ /// For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
+ /// - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
+ /// representation before being written to the buffer.
+ /// - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
+ /// The first element is stored in the 4 least significant bits (LSB),
+ /// and the second element is stored in the 4 most significant bits (MSB).
+ ///
+ /// Consequently:
+ /// - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
+ /// - For 4-bit data types, each `int32_data` stores two elements.
+ ///
+ /// When this field is present, the data_type field MUST be
+ /// INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
+ #[prost(int32, repeated, tag = "5")]
+ pub int32_data: ::prost::alloc::vec::Vec<i32>,
+ /// For strings.
+ /// Each element of string_data is a UTF-8 encoded Unicode
+ /// string. No trailing null, no leading BOM. The protobuf "string"
+ /// scalar type is not used to match ML community conventions.
+ /// When this field is present, the data_type field MUST be STRING
+ #[prost(bytes = "vec", repeated, tag = "6")]
+ pub string_data: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec<u8>>,
+ /// For int64.
+ /// When this field is present, the data_type field MUST be INT64
+ #[prost(int64, repeated, tag = "7")]
+ pub int64_data: ::prost::alloc::vec::Vec<i64>,
+ /// Optionally, a name for the tensor.
+ ///
+ /// namespace Value
+ #[prost(string, tag = "8")]
+ pub name: ::prost::alloc::string::String,
+ /// A human-readable documentation for this tensor. Markdown is allowed.
+ #[prost(string, tag = "12")]
+ pub doc_string: ::prost::alloc::string::String,
+ /// Serializations can either use one of the fields above, or use this
+ /// raw bytes field. The only exception is the string case, where one is
+ /// required to store the content in the repeated bytes string_data field.
+ ///
+ /// When this raw_data field is used to store the tensor value, elements MUST
+ /// be stored as fixed-width values in little-endian order.
+ /// Floating-point data types MUST be stored in IEEE 754 format.
+ /// Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ /// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ /// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+ /// uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
+ ///
+ /// Note: the advantage of the type-specific fields over the raw_data field is
+ /// that in some cases (e.g. int data), protobuf does a better packing via
+ /// variable length storage, and may lead to smaller binary footprint.
+ /// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ #[prost(bytes = "vec", tag = "9")]
+ pub raw_data: ::prost::alloc::vec::Vec<u8>,
+ /// Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ /// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ /// external_data stores key-value pairs describing data location. Recognized keys are:
+ /// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ /// protobuf model was stored
+ /// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+ /// Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+ /// - "length" (optional) - number of bytes containing data. Integer stored as string.
+ /// - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
+ #[prost(message, repeated, tag = "13")]
+ pub external_data: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+ /// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ #[prost(enumeration = "tensor_proto::DataLocation", tag = "14")]
+ pub data_location: i32,
+ /// For double
+ /// Complex128 tensors are encoded as a single array of doubles,
+ /// with the real components appearing in odd numbered positions,
+ /// and the corresponding imaginary component appearing in the
+ /// subsequent even numbered position. (e.g., \[1.0 + 2.0i, 3.0 + 4.0i\]
+ /// is encoded as \[1.0, 2.0, 3.0, 4.0\]).
+ /// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ #[prost(double, repeated, tag = "10")]
+ pub double_data: ::prost::alloc::vec::Vec<f64>,
+ /// For uint64 and uint32 values
+ /// When this field is present, the data_type field MUST be
+ /// UINT32 or UINT64
+ #[prost(uint64, repeated, tag = "11")]
+ pub uint64_data: ::prost::alloc::vec::Vec<u64>,
+ /// Named metadata values; keys should be distinct.
+ #[prost(message, repeated, tag = "16")]
+ pub metadata_props: ::prost::alloc::vec::Vec<StringStringEntryProto>,
+}
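+// Illustrative sketch (hand-written, not generated): the 4-bit packing rule used by
+// `int32_data` and `raw_data` above stores the first element in the 4 LSBs and the
+// second in the 4 MSBs of each byte, e.g.
+//
+//     let (a, b) = (0x3u8, 0x9u8);                 // two uint4 element values
+//     let packed = (a & 0x0F) | ((b & 0x0F) << 4); // == 0x93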
+/// Nested message and enum types in `TensorProto`.
+pub mod tensor_proto {
+ /// For very large tensors, we may want to store them in chunks, in which
+ /// case the following fields will specify the segment that is stored in
+ /// the current TensorProto.
+ #[derive(Clone, Copy, PartialEq, ::prost::Message)]
+ pub struct Segment {
+ #[prost(int64, tag = "1")]
+ pub begin: i64,
+ #[prost(int64, tag = "2")]
+ pub end: i64,
+ }
+ #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+ #[repr(i32)]
+ pub enum DataType {
+ Undefined = 0,
+ /// Basic types.
+ ///
+ /// float
+ Float = 1,
+ /// uint8_t
+ Uint8 = 2,
+ /// int8_t
+ Int8 = 3,
+ /// uint16_t
+ Uint16 = 4,
+ /// int16_t
+ Int16 = 5,
+ /// int32_t
+ Int32 = 6,
+ /// int64_t
+ Int64 = 7,
+ /// string
+ String = 8,
+ /// bool
+ Bool = 9,
+ /// IEEE754 half-precision floating-point format (16 bits wide).
+ /// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+ Float16 = 10,
+ Double = 11,
+ Uint32 = 12,
+ Uint64 = 13,
+ /// complex with float32 real and imaginary components
+ Complex64 = 14,
+ /// complex with float64 real and imaginary components
+ Complex128 = 15,
+ /// Non-IEEE floating-point format based on IEEE754 single-precision
+ /// floating-point number truncated to 16 bits.
+ /// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ Bfloat16 = 16,
+ /// Non-IEEE floating-point formats based on the papers
+ /// "FP8 Formats for Deep Learning" and
+ /// "8-bit Numerical Formats For Deep Neural Networks".
+ /// The operators that support FP8 are Cast, CastLike, QuantizeLinear, and DequantizeLinear.
+ /// The computation usually happens inside a block quantize / dequantize
+ /// fused by the runtime.
+ ///
+ /// float 8, mostly used for coefficients, supports nan, not inf
+ Float8e4m3fn = 17,
+ /// float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ Float8e4m3fnuz = 18,
+ /// follows IEEE 754, supports nan, inf, mostly used for gradients
+ Float8e5m2 = 19,
+ /// follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+ Float8e5m2fnuz = 20,
+ /// 4-bit integer data types
+ ///
+ /// Unsigned integer in range \[0, 15\]
+ Uint4 = 21,
+ /// Signed integer in range \[-8, 7\], using two's-complement representation
+ Int4 = 22,
+ /// 4-bit floating point data types
+ Float4e2m1 = 23,
+ }
+ impl DataType {
+ /// String value of the enum field names used in the ProtoBuf definition.
+ ///
+ /// The values are not transformed in any way and thus are considered stable
+ /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+ pub fn as_str_name(&self) -> &'static str {
+ match self {
+ Self::Undefined => "UNDEFINED",
+ Self::Float => "FLOAT",
+ Self::Uint8 => "UINT8",
+ Self::Int8 => "INT8",
+ Self::Uint16 => "UINT16",
+ Self::Int16 => "INT16",
+ Self::Int32 => "INT32",
+ Self::Int64 => "INT64",
+ Self::String => "STRING",
+ Self::Bool => "BOOL",
+ Self::Float16 => "FLOAT16",
+ Self::Double => "DOUBLE",
+ Self::Uint32 => "UINT32",
+ Self::Uint64 => "UINT64",
+ Self::Complex64 => "COMPLEX64",
+ Self::Complex128 => "COMPLEX128",
+ Self::Bfloat16 => "BFLOAT16",
+ Self::Float8e4m3fn => "FLOAT8E4M3FN",
+ Self::Float8e4m3fnuz => "FLOAT8E4M3FNUZ",
+ Self::Float8e5m2 => "FLOAT8E5M2",
+ Self::Float8e5m2fnuz => "FLOAT8E5M2FNUZ",
+ Self::Uint4 => "UINT4",
+ Self::Int4 => "INT4",
+ Self::Float4e2m1 => "FLOAT4E2M1",
+ }
+ }
+ /// Creates an enum from field names used in the ProtoBuf definition.
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+ match value {
+ "UNDEFINED" => Some(Self::Undefined),
+ "FLOAT" => Some(Self::Float),
+ "UINT8" => Some(Self::Uint8),
+ "INT8" => Some(Self::Int8),
+ "UINT16" => Some(Self::Uint16),
+ "INT16" => Some(Self::Int16),
+ "INT32" => Some(Self::Int32),
+ "INT64" => Some(Self::Int64),
+ "STRING" => Some(Self::String),
+ "BOOL" => Some(Self::Bool),
+ "FLOAT16" => Some(Self::Float16),
+ "DOUBLE" => Some(Self::Double),
+ "UINT32" => Some(Self::Uint32),
+ "UINT64" => Some(Self::Uint64),
+ "COMPLEX64" => Some(Self::Complex64),
+ "COMPLEX128" => Some(Self::Complex128),
+ "BFLOAT16" => Some(Self::Bfloat16),
+ "FLOAT8E4M3FN" => Some(Self::Float8e4m3fn),
+ "FLOAT8E4M3FNUZ" => Some(Self::Float8e4m3fnuz),
+ "FLOAT8E5M2" => Some(Self::Float8e5m2),
+ "FLOAT8E5M2FNUZ" => Some(Self::Float8e5m2fnuz),
+ "UINT4" => Some(Self::Uint4),
+ "INT4" => Some(Self::Int4),
+ "FLOAT4E2M1" => Some(Self::Float4e2m1),
+ _ => None,
+ }
+ }
+ }
+ /// Location of the data for this tensor. MUST be one of:
+ /// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ /// - EXTERNAL - data stored in an external location as described by external_data field.
+ #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+ #[repr(i32)]
+ pub enum DataLocation {
+ Default = 0,
+ External = 1,
+ }
+ impl DataLocation {
+ /// String value of the enum field names used in the ProtoBuf definition.
+ ///
+ /// The values are not transformed in any way and thus are considered stable
+ /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+ pub fn as_str_name(&self) -> &'static str {
+ match self {
+ Self::Default => "DEFAULT",
+ Self::External => "EXTERNAL",
+ }
+ }
+ /// Creates an enum from field names used in the ProtoBuf definition.
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+ match value {
+ "DEFAULT" => Some(Self::Default),
+ "EXTERNAL" => Some(Self::External),
+ _ => None,
+ }
+ }
+ }
+}
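+// Illustrative sketch (hand-written, not generated): `as_str_name` / `from_str_name`
+// round-trip the ProtoBuf enum names, e.g.
+//
+//     assert_eq!(tensor_proto::DataType::Float16.as_str_name(), "FLOAT16");
+//     assert_eq!(
+//         tensor_proto::DataType::from_str_name("BFLOAT16"),
+//         Some(tensor_proto::DataType::Bfloat16),
+//     );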
+/// A serialized sparse-tensor value
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SparseTensorProto {
+ /// The sequence of non-default values is encoded as a tensor of shape \[NNZ\].
+ /// The default-value is zero for numeric tensors, and empty-string for string tensors.
+ /// values must have a non-empty name present which serves as a name for SparseTensorProto
+ /// when used in sparse_initializer list.
+ #[prost(message, optional, tag = "1")]
+ pub values: ::core::option::Option<TensorProto>,
+ /// The indices of the non-default values, which may be stored in one of two formats.
+ /// (a) Indices can be a tensor of shape \[NNZ, rank\] with the \[i,j\]-th value
+ /// corresponding to the j-th index of the i-th value (in the values tensor).
+ /// (b) Indices can be a tensor of shape \[NNZ\], in which case the i-th value
+ /// must be the linearized-index of the i-th value (in the values tensor).
+ /// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ /// using the shape provided below.
+ /// The indices must appear in ascending order without duplication.
+ /// In the first format, the ordering is lexicographic-ordering:
+ /// e.g., index-value \[1,4\] must appear before \[2,1\]
+ #[prost(message, optional, tag = "2")]
+ pub indices: ::core::option::Option<TensorProto>,
+ /// The shape of the underlying dense-tensor: \[dim_1, dim_2, ... dim_rank\]
+ #[prost(int64, repeated, tag = "3")]
+ pub dims: ::prost::alloc::vec::Vec<i64>,
+}
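+// Illustrative sketch (hand-written, not generated): for a dense shape of [2, 3],
+// the value at index [1, 1] can be stored either with format (a), an `indices`
+// tensor of shape [NNZ, 2] containing the row [1, 1], or with format (b), an
+// `indices` tensor of shape [NNZ] containing the linearized index 1 * 3 + 1 = 4.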
+/// Defines a tensor shape. A dimension can be either an integer value
+/// or a symbolic variable. A symbolic variable represents an unknown
+/// dimension.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TensorShapeProto {
+ #[prost(message, repeated, tag = "1")]
+ pub dim: ::prost::alloc::vec::Vec<tensor_shape_proto::Dimension>,
+}
+/// Nested message and enum types in `TensorShapeProto`.
+pub mod tensor_shape_proto {
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Dimension {
+ /// Standard denotation can optionally be used to denote tensor
+ /// dimensions with standard semantic descriptions to ensure
+ /// that operations are applied to the correct axis of a tensor.
+ /// Refer to <https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md>
+ /// for pre-defined dimension denotations.
+ #[prost(string, tag = "3")]
+ pub denotation: ::prost::alloc::string::String,
+ #[prost(oneof = "dimension::Value", tags = "1, 2")]
+ pub value: ::core::option::Option,
+ }
+ /// Nested message and enum types in `Dimension`.
+ pub mod dimension {
+ #[derive(Clone, PartialEq, ::prost::Oneof)]
+ pub enum Value {
+ #[prost(int64, tag = "1")]
+ DimValue(i64),
+ /// namespace Shape
+ #[prost(string, tag = "2")]
+ DimParam(::prost::alloc::string::String),
+ }
+ }
+}
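+// Illustrative sketch (hand-written, not generated): a shape such as
+// ("batch", 3) mixes a symbolic and a concrete dimension, e.g.
+//
+//     use tensor_shape_proto::{dimension, Dimension};
+//     TensorShapeProto {
+//         dim: vec![
+//             Dimension {
+//                 denotation: String::new(),
+//                 value: Some(dimension::Value::DimParam("batch".to_string())),
+//             },
+//             Dimension {
+//                 denotation: String::new(),
+//                 value: Some(dimension::Value::DimValue(3)),
+//             },
+//         ],
+//     }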
+/// Types
+///
+/// The standard ONNX data types.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TypeProto {
+ /// An optional denotation can be used to denote the whole
+ /// type with a standard semantic description as to what is
+ /// stored inside. Refer to <https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md>
+ /// for pre-defined type denotations.
+ #[prost(string, tag = "6")]
+ pub denotation: ::prost::alloc::string::String,
+ #[prost(oneof = "type_proto::Value", tags = "1, 4, 5, 9, 8")]
+ pub value: ::core::option::Option<type_proto::Value>,
+}
+/// Nested message and enum types in `TypeProto`.
+pub mod type_proto {
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Tensor {
+ /// This field MUST NOT have the value of UNDEFINED
+ /// This field MUST have a valid TensorProto.DataType value
+ /// This field MUST be present for this version of the IR.
+ #[prost(int32, tag = "1")]
+ pub elem_type: i32,
+ #[prost(message, optional, tag = "2")]
+ pub shape: ::core::option::Option<super::TensorShapeProto>,
+ }
+ /// repeated T
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Sequence {
+ /// The type and optional shape of each element of the sequence.
+ /// This field MUST be present for this version of the IR.
+ #[prost(message, optional, boxed, tag = "1")]
+ pub elem_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+ }
+ /// map<K,V>
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Map {
+ /// This field MUST have a valid TensorProto.DataType value
+ /// This field MUST be present for this version of the IR.
+ /// This field MUST refer to an integral type (\[U\]INT{8|16|32|64}) or STRING
+ #[prost(int32, tag = "1")]
+ pub key_type: i32,
+ /// This field MUST be present for this version of the IR.
+ #[prost(message, optional, boxed, tag = "2")]
+ pub value_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+ }
+ /// wrapper for Tensor, Sequence, or Map
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Optional {
+ /// The type and optional shape of the element wrapped.
+ /// This field MUST be present for this version of the IR.
+ /// Possible values correspond to OptionalProto.DataType enum
+ #[prost(message, optional, boxed, tag = "1")]
+ pub elem_type: ::core::option::Option<::prost::alloc::boxed::Box<super::TypeProto>>,
+ }
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct SparseTensor {
+ /// This field MUST NOT have the value of UNDEFINED
+ /// This field MUST have a valid TensorProto.DataType value
+ /// This field MUST be present for this version of the IR.
+ #[prost(int32, tag = "1")]
+ pub elem_type: i32,
+ #[prost(message, optional, tag = "2")]
+ pub shape: ::core::option::Option<super::TensorShapeProto>,
+ }
+ #[derive(Clone, PartialEq, ::prost::Oneof)]
+ pub enum Value {
+ /// The type of a tensor.
+ #[prost(message, tag = "1")]
+ TensorType(Tensor),
+ /// The type of a sequence.
+ #[prost(message, tag = "4")]
+ SequenceType(::prost::alloc::boxed::Box<Sequence>),
+ /// The type of a map.
+ #[prost(message, tag = "5")]
+ MapType(::prost::alloc::boxed::Box<Map>),