From 28f3d18ac3883970e365039f5ca8d11415aff37b Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Tue, 3 Jun 2025 20:53:48 +0800
Subject: [PATCH] Bump `ort` from 2.0.0-rc.9 to 2.0.0-rc.10 (#107)

---
 .github/workflows/rust-ci.yml |   8 +-
 Cargo.toml                    |  18 ++--
 build.rs                      |   3 -
 src/inference/engine.rs       | 163 +++++++++++++-------------
 src/inference/x.rs            |   1 -
 src/io/hub.rs                 |  45 +++++----
 src/utils/device.rs           |  22 ++---
 src/utils/dtype.rs            |  28 +++++-
 src/utils/onnx.proto3         | 172 +++++++++++++++++++++++++++++++---
 9 files changed, 291 insertions(+), 169 deletions(-)

diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml
index 55a0847..fd49a56 100644
--- a/.github/workflows/rust-ci.yml
+++ b/.github/workflows/rust-ci.yml
@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
 
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
@@ -47,7 +47,7 @@
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
 
       - name: Setup Rust
        uses: dtolnay/rust-toolchain@stable
@@ -67,7 +67,7 @@
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
 
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@nightly
@@ -93,7 +93,7 @@
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
 
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable

diff --git a/Cargo.toml b/Cargo.toml
index 321ae5e..a4ef394 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,10 +38,10 @@ fast_image_resize = { version = "5.1.2", features = ["image"] }
 ndarray-npy = "0.9.1"
 half = { version = "2.3.1" }
 prost = "0.13.5"
-ort = { version = "2.0.0-rc.9", default-features = false, optional = true, features = [
-    "ndarray",
+ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, features = [
     "copy-dylibs",
     "half",
+    "std",
 ] }
 tokenizers = { version = "0.21.1" }
 paste = "1.0.15"
@@ -54,10 +54,10 @@ argh = "0.1.13"
 tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }
 
 [features]
-default = ["ort-download-binaries"]
-ort-download-binaries = ["ort", "ort/download-binaries"]
-ort-load-dynamic = ["ort", "ort/load-dynamic"]
-cuda = ["ort/cuda"]
-trt = ["ort/tensorrt"]
-mps = ["ort/coreml"]
-video = ["dep:video-rs"]
+default = [ "ort-download-binaries" ]
+video = [ "dep:video-rs" ]
+ort-download-binaries = [ "ort", "ort/download-binaries" ]
+ort-load-dynamic = [ "ort", "ort/load-dynamic" ]
+cuda = [ "ort/cuda" ]
+trt = [ "ort/tensorrt" ]
+coreml = [ "ort/coreml" ]
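Three things change in the feature table above: the `ort` dependency is now pinned exactly with `=2.0.0-rc.10`, the `ndarray` feature request is dropped while `std` is added, and the Apple backend feature is renamed from `mps` to `coreml` to match the upstream `ort/coreml` feature. A small sketch of how downstream code might react to the rename at compile time; the helper name is hypothetical:

    // Hypothetical helper: choose a device string from the features this
    // build was compiled with (note `coreml` replaces the old `mps` name).
    fn default_device() -> &'static str {
        if cfg!(feature = "cuda") {
            "cuda:0"
        } else if cfg!(feature = "coreml") {
            "coreml:0"
        } else {
            "cpu:0"
        }
    }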
diff --git a/build.rs b/build.rs
index 4231433..6558285 100644
--- a/build.rs
+++ b/build.rs
@@ -3,8 +3,5 @@ use std::io::Result;
 
 fn main() -> Result<()> {
     prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;
 
-    #[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
-    println!("cargo:rustc-link-arg=-fapple-link-rtlib");
-
     Ok(())
 }
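`build.rs` still compiles the vendored ONNX proto with prost; only the Apple-specific `-fapple-link-rtlib` link workaround is dropped. For context, the generated types are usually brought into a crate with an `include!` along these lines (the `onnx` module name is an assumption based on the proto's package):

    // Sketch: pulling in the prost-generated ONNX types from OUT_DIR.
    pub mod onnx {
        include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
    }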
diff --git a/src/inference/engine.rs b/src/inference/engine.rs
index 2884a7c..2186660 100644
--- a/src/inference/engine.rs
+++ b/src/inference/engine.rs
@@ -20,10 +20,12 @@ use crate::{
 impl From<TensorElementType> for DType {
     fn from(dtype: TensorElementType) -> Self {
         match dtype {
+            TensorElementType::Int4 => Self::Int4,
             TensorElementType::Int8 => Self::Int8,
             TensorElementType::Int16 => Self::Int16,
             TensorElementType::Int32 => Self::Int32,
             TensorElementType::Int64 => Self::Int64,
+            TensorElementType::Uint4 => Self::Uint4,
             TensorElementType::Uint8 => Self::Uint8,
             TensorElementType::Uint16 => Self::Uint16,
             TensorElementType::Uint32 => Self::Uint32,
@@ -32,14 +34,19 @@ impl From<TensorElementType> for DType {
             TensorElementType::Float32 => Self::Fp32,
             TensorElementType::Float64 => Self::Fp64,
             TensorElementType::Bfloat16 => Self::Bf16,
-            TensorElementType::String => Self::String,
-            TensorElementType::Bool => Self::Bool,
+            TensorElementType::Float8E4M3FN => Self::Fp8e4m3fn,
+            TensorElementType::Float8E4M3FNUZ => Self::Fp8e4m3fnuz,
+            TensorElementType::Float8E5M2 => Self::Fp8e5m2,
+            TensorElementType::Float8E5M2FNUZ => Self::Fp8e5m2fnuz,
+            TensorElementType::Complex64 => Self::Complex64,
+            TensorElementType::Complex128 => Self::Complex128,
+            _ => todo!(),
         }
     }
 }
 
 /// A struct for tensor attrs composed of the names, the dtypes, and the dimensions.
-#[derive(Builder, Debug, Clone)]
+#[derive(Builder, Debug, Clone, Default)]
 pub struct OrtTensorAttr {
     pub names: Vec<String>,
     pub dtypes: Vec<TensorElementType>,
@@ -133,7 +140,9 @@ impl Engine {
                     let param = tensor_proto.dims.iter().product::<i64>() as usize;
                     params += param;
                     let param = Ops::make_divisible(param, byte_alignment);
-                    let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize);
+                    let n = Self::get_ort_dtype_from_proto_dtype_id(tensor_proto.data_type)
+                        .map(|x| x.byte_size(1))
+                        .unwrap_or_default();
                     let wbmem = param * n;
                     wbmems += wbmem;
                 }
@@ -145,7 +154,10 @@
                     let param = tensor.dims.iter().product::<i64>() as usize;
                     params += param;
                     let param = Ops::make_divisible(param, byte_alignment);
-                    let n = Self::nbytes_from_onnx_dtype_id(tensor.data_type as usize);
+                    let n = Self::get_ort_dtype_from_proto_dtype_id(tensor.data_type)
+                        .map(|x| x.byte_size(1))
+                        .unwrap_or_default();
+
                     let wbmem = param * n;
                     wbmems += wbmem;
                 }
@@ -211,7 +223,7 @@
 
         // update
         pb.set_message(format!(
-            "{}({}) on {:?}",
+            "{}({}) on {}",
             self.spec,
             match self.params {
                 Some(bytes) if bytes != 0 => {
@@ -231,7 +243,7 @@
 
     pub fn run(&mut self, xs: Xs) -> Result<Xs> {
         let mut ys = xs.derive();
-        if let Some(onnx) = &self.onnx {
+        if let Some(onnx) = &mut self.onnx {
             // alignment
             let xs_ = elapsed!(&format!("[{}] ort_preprocessing", self.spec), self.ts, {
                 let mut xs_ = Vec::new();
@@ -267,38 +279,22 @@
 
     fn preprocess(x: &X, dtype: &TensorElementType) -> Result<Value> {
         let x = match dtype {
-            TensorElementType::Float32 => Value::from_array(x.view())?.into_dyn(),
-            TensorElementType::Float16 => {
-                Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
+            TensorElementType::Float32 | TensorElementType::Float64 => {
+                Value::from_array(x.0.clone())?.into_dyn()
             }
-            TensorElementType::Float64 => Value::from_array(x.view())?.into_dyn(),
-            TensorElementType::Bfloat16 => {
-                Value::from_array(x.mapv(bf16::from_f32).view())?.into_dyn()
-            }
-            TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn(),
-            TensorElementType::Int16 => {
-                Value::from_array(x.mapv(|x_| x_ as i16).view())?.into_dyn()
-            }
-            TensorElementType::Int32 => {
-                Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
-            }
-            TensorElementType::Int64 => {
-                Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
-            }
-            TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn(),
-            TensorElementType::Uint16 => {
-                Value::from_array(x.mapv(|x_| x_ as u16).view())?.into_dyn()
-            }
-            TensorElementType::Uint32 => {
-                Value::from_array(x.mapv(|x_| x_ as u32).view())?.into_dyn()
-            }
-            TensorElementType::Uint64 => {
-                Value::from_array(x.mapv(|x_| x_ as u64).view())?.into_dyn()
-            }
-            TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn(),
+            TensorElementType::Float16 => Value::from_array(x.mapv(f16::from_f32))?.into_dyn(),
+            TensorElementType::Bfloat16 => Value::from_array(x.mapv(bf16::from_f32))?.into_dyn(),
+            TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8))?.into_dyn(),
+            TensorElementType::Int16 => Value::from_array(x.mapv(|x_| x_ as i16))?.into_dyn(),
+            TensorElementType::Int32 => Value::from_array(x.mapv(|x_| x_ as i32))?.into_dyn(),
+            TensorElementType::Int64 => Value::from_array(x.mapv(|x_| x_ as i64))?.into_dyn(),
+            TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8))?.into_dyn(),
+            TensorElementType::Uint16 => Value::from_array(x.mapv(|x_| x_ as u16))?.into_dyn(),
+            TensorElementType::Uint32 => Value::from_array(x.mapv(|x_| x_ as u32))?.into_dyn(),
+            TensorElementType::Uint64 => Value::from_array(x.mapv(|x_| x_ as u64))?.into_dyn(),
+            TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.))?.into_dyn(),
             _ => unimplemented!(),
         };
-
         Ok(x)
     }
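The preprocessing rewrite tracks two rc.10 API changes visible above: `Value::from_array` now takes an owned array rather than a `.view()`, and running the session borrows it mutably (hence `&mut self.onnx`). A minimal sketch of the new input-construction convention, assuming `anyhow` for error handling:

    use ndarray::Array2;
    use ort::value::Value;

    fn main() -> anyhow::Result<()> {
        // rc.10 consumes the array; no `.view()` borrow is involved.
        let x = Array2::<f32>::zeros((1, 3));
        let _input = Value::from_array(x)?.into_dyn();
        Ok(())
    }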
@@ -307,7 +303,7 @@ impl Engine {
     where
         T: Clone + 'static + ort::tensor::PrimitiveTensorElementType,
     {
-        match x.try_extract_tensor::<T>() {
+        match x.try_extract_array::<T>() {
             Err(err) => {
                 debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
                 Array::zeros(0).into_dyn()
@@ -344,7 +340,7 @@
             \nConsider enabling them by passing, e.g., `--features #FEATURE`";
 
         match self.device {
-            Device::TensorRT(id) => {
+            Device::TensorRt(id) => {
                 #[cfg(not(feature = "trt"))]
                 {
                     anyhow::bail!(feature_help
@@ -431,16 +427,28 @@
                 }
             }
-            Device::CoreML(id) => {
-                #[cfg(not(feature = "mps"))]
+            Device::CoreMl(id) => {
+                #[cfg(not(feature = "coreml"))]
                 {
                     anyhow::bail!(feature_help
                         .replace("#EP", "CoreML")
-                        .replace("#FEATURE", "mps"));
+                        .replace("#FEATURE", "coreml"));
                 }
-                #[cfg(feature = "mps")]
+                #[cfg(feature = "coreml")]
                 {
-                    let ep = ort::execution_providers::CoreMLExecutionProvider::default();
+                    let ep = ort::execution_providers::CoreMLExecutionProvider::default()
+                        .with_model_cache_dir(
+                            crate::Dir::Cache
+                                .crate_dir_default_with_subs(&["coreml-cache"])?
+                                .display(),
+                        )
+                        .with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All)
+                        .with_static_input_shapes(false)
+                        .with_subgraphs(true)
+                        .with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram)
+                        .with_specialization_strategy(
+                            ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction,
+                        );
                     match ep.is_available() {
                         Ok(true) => {
                             ep.register(&mut builder).map_err(|err| {
@@ -452,13 +460,14 @@
                 }
             }
             _ => {
-                let ep = ort::execution_providers::CPUExecutionProvider::default();
+                let ep = ort::execution_providers::CPUExecutionProvider::default()
+                    .with_arena_allocator(true);
                 match ep.is_available() {
                     Ok(true) => {
                         ep.register(&mut builder)
                             .map_err(|err| anyhow::anyhow!("Failed to register Cpu: {}", err))?;
                     }
-                    _ => anyhow::bail!(compile_help.replace("#EP", "Cpu")),
+                    _ => unreachable!("CPU EP is not available. This case should ideally not be reached under normal circumstances."),
                 }
             }
         }
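Every provider branch above follows the same probe-then-register shape. A condensed sketch of that pattern with the rc.10 CPU provider and a hypothetical model path:

    use ort::execution_providers::{CPUExecutionProvider, ExecutionProvider};
    use ort::session::Session;

    fn main() -> anyhow::Result<()> {
        let mut builder = Session::builder()?;
        let ep = CPUExecutionProvider::default().with_arena_allocator(true);
        // Probe availability first, then register into the session builder.
        if ep.is_available()? {
            ep.register(&mut builder)?;
        }
        let _session = builder.commit_from_file("model.onnx")?;
        Ok(())
    }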
@@ -532,38 +541,8 @@ impl Engine {
         Ok(ys)
     }
 
-    #[allow(dead_code)]
-    fn nbytes_from_onnx_dtype_id(x: usize) -> usize {
-        match x {
-            7 | 11 | 13 => 8,     // i64, f64, u64
-            1 | 6 | 12 => 4,      // f32, i32, u32
-            10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16
-            2 | 3 | 9 => 1,       // u8, i8, bool
-            8 => 4,               // string(1~4)
-            _ => 1,               // TODO: others
-        }
-    }
-
-    #[allow(dead_code)]
-    fn nbytes_from_onnx_dtype(x: &TensorElementType) -> usize {
-        match x {
-            TensorElementType::Float64 | TensorElementType::Uint64 | TensorElementType::Int64 => 8, // i64, f64, u64
-            TensorElementType::Float32
-            | TensorElementType::Uint32
-            | TensorElementType::Int32
-            | TensorElementType::String => 4, // f32, i32, u32, string(1~4)
-            TensorElementType::Float16
-            | TensorElementType::Bfloat16
-            | TensorElementType::Int16
-            | TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
-            TensorElementType::Uint8 | TensorElementType::Int8 | TensorElementType::Bool => 1, // u8, i8, bool
-        }
-    }
-
-    #[allow(dead_code)]
-    fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<TensorElementType> {
+    fn get_ort_dtype_from_proto_dtype_id(value: i32) -> Option<TensorElementType> {
         match value {
-            0 => None,
             1 => Some(TensorElementType::Float32),
             2 => Some(TensorElementType::Uint8),
             3 => Some(TensorElementType::Int8),
@@ -577,10 +556,16 @@ impl Engine {
             11 => Some(TensorElementType::Float64),
             12 => Some(TensorElementType::Uint32),
             13 => Some(TensorElementType::Uint64),
-            14 => None, // COMPLEX64
-            15 => None, // COMPLEX128
+            14 => Some(TensorElementType::Complex64),
+            15 => Some(TensorElementType::Complex128),
             16 => Some(TensorElementType::Bfloat16),
-            _ => None,
+            17 => Some(TensorElementType::Float8E4M3FN),
+            18 => Some(TensorElementType::Float8E4M3FNUZ),
+            19 => Some(TensorElementType::Float8E5M2),
+            20 => Some(TensorElementType::Float8E5M2FNUZ),
+            21 => Some(TensorElementType::Uint4),
+            22 => Some(TensorElementType::Int4),
+            _ => None, // 23: Float4e2m1, 0: Undefined
         }
     }
@@ -609,7 +594,7 @@ impl Engine {
                 _ => continue,
             };
             let tensor_type = tensor.elem_type;
-            let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) {
+            let tensor_type = match Self::get_ort_dtype_from_proto_dtype_id(tensor_type) {
                 Some(dtype) => dtype,
                 None => continue,
             };
@@ -642,24 +627,6 @@ impl Engine {
         })
     }
 
-    // pub fn to_ort(&self) -> TensorElementType {
-    //     match self {
-    //         Self::Int8 => TensorElementType::Int8,
-    //         Self::Int16 => TensorElementType::Int16,
-    //         Self::Int32 => TensorElementType::Int32,
-    //         Self::Int64 => TensorElementType::Int64,
-    //         Self::Uint8 => TensorElementType::Uint8,
-    //         Self::Uint16 => TensorElementType::Uint16,
-    //         Self::Uint32 => TensorElementType::Uint32,
-    //         Self::Uint64 => TensorElementType::Uint64,
-    //         Self::Fp16 => TensorElementType::Float16,
-    //         Self::Fp32 => TensorElementType::Float32,
-    //         Self::Fp64 => TensorElementType::Float64,
-    //         Self::Bf16 => TensorElementType::Bfloat16,
-    //         _ => todo!(),
-    //     }
-    // }
-
     pub fn load_onnx<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
         let f = std::fs::read(p.as_ref())?;
         onnx::ModelProto::decode(f.as_slice()).map_err(|err| {

diff --git a/src/inference/x.rs b/src/inference/x.rs
index 307cec8..deabc77 100644
--- a/src/inference/x.rs
+++ b/src/inference/x.rs
@@ -1,7 +1,6 @@
 use anyhow::Result;
 use image::DynamicImage;
 use ndarray::{Array, Dim, IntoDimension, Ix2, IxDyn, IxDynImpl};
-// use std::ops::Mul;
 
 use crate::{Ops, ResizeMode};
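With the dead `nbytes_*` helpers gone, element sizes now come from `TensorElementType::byte_size` via the proto-id mapping above. A standalone sketch of that pairing (a subset of the mapping only; the import path follows the `ort::tensor` module referenced above):

    use ort::tensor::TensorElementType;

    fn dtype_of(id: i32) -> Option<TensorElementType> {
        match id {
            1 => Some(TensorElementType::Float32),
            10 => Some(TensorElementType::Float16),
            21 => Some(TensorElementType::Uint4),
            22 => Some(TensorElementType::Int4),
            _ => None,
        }
    }

    fn main() {
        // One FLOAT16 element (ONNX proto id 10) occupies two bytes.
        let n = dtype_of(10).map(|t| t.byte_size(1)).unwrap_or_default();
        assert_eq!(n, 2);
    }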
diff --git a/src/io/hub.rs b/src/io/hub.rs
index e29e754..cbd6bc2 100644
--- a/src/io/hub.rs
+++ b/src/io/hub.rs
@@ -88,9 +88,6 @@ pub struct Hub {
     /// Directory to store the downloaded file
     to: Dir,
 
-    /// Download timeout in seconds
-    timeout: u64,
-
     /// Time to live (cache duration)
     ttl: Duration,
 
@@ -116,7 +113,6 @@ impl Default for Hub {
             owner,
             repo,
             to,
-            timeout: 3000,
             max_attempts: 3,
             ttl: Duration::from_secs(10 * 60),
         }
@@ -195,7 +191,7 @@ impl Hub {
                     .join(&file_name_);
                 pack = pack.with_url(s).with_tag(&tag_).with_file_name(&file_name_);
 
-                if let Some(n) = retry!(self.max_attempts, Self::fetch_get_response(s))?
+                if let Some(n) = retry!(self.max_attempts, self.fetch_get_response(s))?
                     .headers()
                     .get(http::header::CONTENT_LENGTH)
                     .and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
@@ -208,7 +204,7 @@
 
             // => Default hub
             // Fetch releases
-            let releases = match Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
+            let releases = match self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
                 Err(err) => anyhow::bail!(
                     "Failed to download: No releases found in this repo. Error: {}",
                     err
@@ -286,7 +282,7 @@ impl Hub {
                     self.max_attempts,
                     1000,
                     3000,
-                    Self::download(
+                    self.download(
                         &pack.url,
                         &saveout,
                         Some(&format!("{}/{}", pack.tag, pack.file_name)),
@@ -303,7 +299,7 @@
                 self.max_attempts,
                 1000,
                 3000,
-                Self::download(
+                self.download(
                     &pack.url,
                     &saveout,
                     Some(&format!("{}/{}", pack.tag, pack.file_name)),
@@ -319,8 +315,8 @@
     }
 
     /// Fetch releases from GitHub and cache them
-    fn fetch_and_cache_releases(url: &str, cache_path: &Path) -> Result<String> {
-        let response = retry!(3, Self::fetch_get_response(url))?;
+    fn fetch_and_cache_releases(&self, url: &str, cache_path: &Path) -> Result<String> {
+        let response = retry!(self.max_attempts, self.fetch_get_response(url))?;
         let body = response
             .into_body()
             .read_to_string()
@@ -351,7 +347,7 @@
     }
 
     pub fn tags(&self) -> Vec<String> {
-        Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
+        self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
             .unwrap_or_default()
             .into_iter()
             .map(|x| x.tag_name)
@@ -359,7 +355,7 @@
     }
 
     pub fn files(&self, tag: &str) -> Vec<String> {
-        Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
+        self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
             .unwrap_or_default()
             .into_iter()
             .find(|r| r.tag_name == tag)
@@ -394,11 +390,12 @@
 
     /// Download a file from a github release to a specified path with a progress bar
     pub fn download<P: AsRef<Path> + std::fmt::Debug>(
+        &self,
         src: &str,
         dst: P,
         message: Option<&str>,
     ) -> Result<()> {
-        let resp = Self::fetch_get_response(src)?;
+        let resp = self.fetch_get_response(src)?;
         let ntotal = resp
             .headers()
             .get(http::header::CONTENT_LENGTH)
             .and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
@@ -412,7 +409,8 @@
         )?;
 
         let mut reader = resp.into_body().into_reader();
-        let mut buffer = [0; 2048];
+        const BUFFER_SIZE: usize = 64 * 1024;
+        let mut buffer = [0; BUFFER_SIZE];
         let mut downloaded_bytes = 0usize;
         let mut file = std::fs::File::create(&dst)
             .with_context(|| format!("Failed to create destination file: {:?}", dst))?;
@@ -442,7 +440,7 @@
         Ok(())
     }
 
-    fn fetch_get_response(url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
+    fn fetch_get_response(&self, url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
         let config = ureq::Agent::config_builder()
             .proxy(ureq::Proxy::try_from_env())
             .build();
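The buffer bump from 2 KiB to 64 KiB means far fewer `read` calls on large model downloads. The copy loop reduces to this shape, sketched here over any `Read`/`Write` pair:

    use std::io::{Read, Write};

    fn copy_chunked<R: Read, W: Write>(mut src: R, mut dst: W) -> std::io::Result<usize> {
        const BUFFER_SIZE: usize = 64 * 1024;
        let mut buffer = [0u8; BUFFER_SIZE];
        let mut total = 0;
        loop {
            let n = src.read(&mut buffer)?;
            if n == 0 {
                break; // EOF
            }
            dst.write_all(&buffer[..n])?;
            total += n;
        }
        Ok(total)
    }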
format!(".cache-releases-{}-{}.json", safe_owner, safe_repo) } - fn get_releases(owner: &str, repo: &str, to: &Dir, ttl: &Duration) -> Result> { + fn get_releases( + &self, + owner: &str, + repo: &str, + to: &Dir, + ttl: &Duration, + ) -> Result> { let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo)); let is_file_expired = Self::is_file_expired(&cache, ttl)?; let body = if is_file_expired { @@ -473,7 +477,7 @@ impl Hub { "https://api.github.com/repos/{}/{}/releases?per_page=100", owner, repo ); - Self::fetch_and_cache_releases(&gh_api_release, &cache)? + self.fetch_and_cache_releases(&gh_api_release, &cache)? } else { std::fs::read_to_string(&cache)? }; @@ -518,11 +522,6 @@ impl Hub { self } - pub fn with_timeout(mut self, x: u64) -> Self { - self.timeout = x; - self - } - pub fn with_max_attempts(mut self, x: u32) -> Self { self.max_attempts = x; self diff --git a/src/utils/device.rs b/src/utils/device.rs index a8099cf..97530de 100644 --- a/src/utils/device.rs +++ b/src/utils/device.rs @@ -2,8 +2,8 @@ pub enum Device { Cpu(usize), Cuda(usize), - TensorRT(usize), - CoreML(usize), + TensorRt(usize), + CoreMl(usize), } impl Default for Device { @@ -15,10 +15,10 @@ impl Default for Device { impl std::fmt::Display for Device { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let x = match self { - Self::Cpu(i) => format!("cpu:{}", i), - Self::Cuda(i) => format!("cuda:{}", i), - Self::CoreML(i) => format!("mps:{}", i), - Self::TensorRT(i) => format!("tensorrt:{}", i), + Self::Cpu(i) => format!("CPU:{}", i), + Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i), + Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i), + Self::CoreMl(i) => format!("CoreML:{}(Apple)", i), }; write!(f, "{}", x) } @@ -41,8 +41,9 @@ impl TryFrom<&str> for Device { match d.to_lowercase().as_str() { "cpu" => Ok(Self::Cpu(id)), "cuda" => Ok(Self::Cuda(id)), - "trt" | "tensorrt" => Ok(Self::TensorRT(id)), - "coreml" | "mps" => Ok(Self::CoreML(id)), + "trt" | "tensorrt" => Ok(Self::TensorRt(id)), + "coreml" | "mps" => Ok(Self::CoreMl(id)), + _ => anyhow::bail!("Unsupported device str: {s:?}."), } } @@ -51,10 +52,7 @@ impl TryFrom<&str> for Device { impl Device { pub fn id(&self) -> usize { match self { - Device::Cpu(i) => *i, - Device::Cuda(i) => *i, - Device::TensorRT(i) => *i, - Device::CoreML(i) => *i, + Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i, } } } diff --git a/src/utils/dtype.rs b/src/utils/dtype.rs index bd5c890..867ddd7 100644 --- a/src/utils/dtype.rs +++ b/src/utils/dtype.rs @@ -7,6 +7,7 @@ pub enum DType { Int16, Int32, Int64, + Uint4, Uint8, Uint16, Uint32, @@ -15,12 +16,17 @@ pub enum DType { Fp32, Fp64, Bf16, - Bool, - String, Bnb4, Q4, Q4f16, Q8, + Fp8e4m3fn, + Fp8e4m3fnuz, + Fp8e5m2, + Fp8e5m2fnuz, + Fp4e2m1, + Complex64, + Complex128, } impl TryFrom<&str> for DType { @@ -29,6 +35,7 @@ impl TryFrom<&str> for DType { fn try_from(s: &str) -> Result { match s.to_lowercase().as_str() { "auto" | "dyn" => Ok(Self::Auto), + "u4" | "uint4" => Ok(Self::Uint4), "u8" | "uint8" => Ok(Self::Uint8), "u16" | "uint16" => Ok(Self::Uint16), "u32" | "uint32" => Ok(Self::Uint32), @@ -46,6 +53,13 @@ impl TryFrom<&str> for DType { "q4" => Ok(Self::Q4), "q8" => Ok(Self::Q8), "bnb4" => Ok(Self::Bnb4), + "f8e4m3fn" => Ok(Self::Fp8e4m3fn), + "f8e4m3fnuz" => Ok(Self::Fp8e4m3fnuz), + "f8e5m2" => Ok(Self::Fp8e5m2), + "f8e5m2fnuz" => Ok(Self::Fp8e5m2fnuz), + "f4e2m1" => Ok(Self::Fp4e2m1), + "complex64" => Ok(Self::Complex64), + "complex128" => Ok(Self::Complex128), x => 
anyhow::bail!("Unsupported DType: {}", x), } } @@ -60,6 +74,7 @@ impl std::fmt::Display for DType { Self::Int16 => "int16", Self::Int32 => "int32", Self::Int64 => "int64", + Self::Uint4 => "uint4", Self::Uint8 => "uint8", Self::Uint16 => "uint16", Self::Uint32 => "uint32", @@ -68,12 +83,17 @@ impl std::fmt::Display for DType { Self::Fp32 => "fp32", Self::Fp64 => "fp64", Self::Bf16 => "bf16", - Self::String => "string", - Self::Bool => "bool", Self::Bnb4 => "bnb4", Self::Q4 => "q4", Self::Q4f16 => "q4f16", Self::Q8 => "q8", + Self::Fp8e4m3fn => "f8e4m3fn", + Self::Fp8e4m3fnuz => "f8e4m3fnuz", + Self::Fp8e5m2 => "f8e5m2", + Self::Fp8e5m2fnuz => "f8e5m2fnuz", + Self::Fp4e2m1 => "f4e2m1", + Self::Complex64 => "complex64", + Self::Complex128 => "complex128", }; write!(f, "{}", x) } diff --git a/src/utils/onnx.proto3 b/src/utils/onnx.proto3 index f47006f..2809051 100644 --- a/src/utils/onnx.proto3 +++ b/src/utils/onnx.proto3 @@ -86,7 +86,7 @@ enum Version { IR_VERSION_2019_9_19 = 0x0000000000000006; // IR VERSION 7 published on May 8, 2020 - // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add support to allow function body graph to rely on multiple external operator sets. // - Add a list to promote inference graph's initializers to global and // mutable variables. Global variables are visible in all graphs of the // stored models. @@ -106,7 +106,15 @@ enum Version { // IR VERSION 9 published on May 5, 2023 // Added AttributeProto to FunctionProto so that default attribute values can be set. // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. - IR_VERSION = 0x0000000000000009; + IR_VERSION_2023_5_5 = 0x0000000000000009; + + // IR VERSION 10 published on March 25, 2024 + // Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions. + IR_VERSION_2024_3_25 = 0x000000000000000A; + + // IR VERSION 11 published on May 12, 2025 + // Added FLOAT4E2M1, multi-device protobuf classes. + IR_VERSION = 0x000000000000000B; } // Attributes @@ -190,6 +198,8 @@ message ValueInfoProto { TypeProto type = 2; // A human-readable documentation for this value. Markdown is allowed. string doc_string = 3; + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 4; } // Nodes @@ -204,19 +214,101 @@ message NodeProto { repeated string output = 2; // namespace Value // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. + // This field MAY be absent in this version of the IR. string name = 3; // namespace Node // The symbolic identifier of the Operator to execute. string op_type = 4; // namespace Operator // The domain of the OperatorSet that specifies the operator named by op_type. string domain = 7; // namespace Domain + // Overload identifier, used only to map this to a model-local function. + string overload = 8; // Additional named attributes. repeated AttributeProto attribute = 5; // A human-readable documentation for this node. Markdown is allowed. string doc_string = 6; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 9; + + // Configuration of multi-device annotations. + repeated NodeDeviceConfigurationProto device_configurations = 10; +} + +// IntIntListEntryProto follows the pattern for cross-proto-version maps. 
diff --git a/src/utils/onnx.proto3 b/src/utils/onnx.proto3
index f47006f..2809051 100644
--- a/src/utils/onnx.proto3
+++ b/src/utils/onnx.proto3
@@ -86,7 +86,7 @@ enum Version {
   IR_VERSION_2019_9_19 = 0x0000000000000006;
 
   // IR VERSION 7 published on May 8, 2020
-  // - Add support to allow function body graph to rely on multiple external opreator sets.
+  // - Add support to allow function body graph to rely on multiple external operator sets.
   // - Add a list to promote inference graph's initializers to global and
   //   mutable variables. Global variables are visible in all graphs of the
   //   stored models.
@@ -106,7 +106,15 @@
   // IR VERSION 9 published on May 5, 2023
   // Added AttributeProto to FunctionProto so that default attribute values can be set.
   // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
-  IR_VERSION = 0x0000000000000009;
+  IR_VERSION_2023_5_5 = 0x0000000000000009;
+
+  // IR VERSION 10 published on March 25, 2024
+  // Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions.
+  IR_VERSION_2024_3_25 = 0x000000000000000A;
+
+  // IR VERSION 11 published on May 12, 2025
+  // Added FLOAT4E2M1, multi-device protobuf classes.
+  IR_VERSION = 0x000000000000000B;
 }
 
 // Attributes
@@ -190,6 +198,8 @@ message ValueInfoProto {
   TypeProto type = 2;
   // A human-readable documentation for this value. Markdown is allowed.
   string doc_string = 3;
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 4;
 }
 
 // Nodes
@@ -204,19 +214,101 @@ message NodeProto {
   repeated string output = 2;  // namespace Value
 
   // An optional identifier for this node in a graph.
-  // This field MAY be absent in ths version of the IR.
+  // This field MAY be absent in this version of the IR.
   string name = 3;  // namespace Node
 
   // The symbolic identifier of the Operator to execute.
   string op_type = 4;  // namespace Operator
   // The domain of the OperatorSet that specifies the operator named by op_type.
   string domain = 7;  // namespace Domain
+  // Overload identifier, used only to map this to a model-local function.
+  string overload = 8;
 
   // Additional named attributes.
   repeated AttributeProto attribute = 5;
 
   // A human-readable documentation for this node. Markdown is allowed.
   string doc_string = 6;
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 9;
+
+  // Configuration of multi-device annotations.
+  repeated NodeDeviceConfigurationProto device_configurations = 10;
+}
+
+// IntIntListEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message IntIntListEntryProto {
+  int64 key = 1;
+  repeated int64 value = 2;
+};
+
+// Multi-device configuration proto for NodeProto.
+message NodeDeviceConfigurationProto {
+  // This field MUST be present for this version of the IR.
+  // ID of the configuration. MUST match the name of a DeviceConfigurationProto.
+  string configuration_id = 1;
+
+  // Sharding spec for the node.
+  repeated ShardingSpecProto sharding_spec = 2;
+
+  // Pipeline stage of this node.
+  int32 pipeline_stage = 3;
+}
+
+// ShardingSpecProto: This describes the sharding spec for a specific
+// input or output tensor of a node.
+message ShardingSpecProto {
+  // This field MUST be present for this version of the IR.
+  // Identifies the input or output of the node that is being sharded.
+  // Required to match a name specified in the node's input or output list of ValueInfoProtos.
+  // It is called `logical tensor` in subsequent descriptions.
+  string tensor_name = 1;
+
+  // The following is the list of devices across which the logical
+  // tensor is sharded or replicated.
+  repeated int64 device = 2;
+
+  // Each element v in above field devices may represent either a
+  // device or a set of devices (when we want the same shard/tensor
+  // to be replicated across a subset of devices), as indicated by
+  // the following optional map. If the map contains an entry for v,
+  // then v represents a device group, and the map indicates the set
+  // of devices in that group.
+  repeated IntIntListEntryProto index_to_device_group_map = 3;
+
+  // The following is the sharded-shape of the tensor, consisting of
+  // the sharding-spec for each axis of the tensor.
+  repeated ShardedDimProto sharded_dim = 4;
+}
+
+// ShardedDimProto: This describes the sharding spec for a single
+// axis of a sharded tensor.
+message ShardedDimProto {
+  // This field MUST be present for this version of the IR.
+  // The axis this sharding corresponds to. Must be in the range of
+  // [-r, r - 1], where r is the rank of the tensor. Negative axis values means
+  // counting from the back.
+  int64 axis = 1;
+
+  // Describes how the tensor on the provided axis is sharded.
+  // The common-case is described by a single instance of SimpleShardedDimProto.
+  // Multiple instances can be used to handle cases where a sharded
+  // tensor is reshaped, fusing multiple axes into one.
+  repeated SimpleShardedDimProto simple_sharding = 2;
+}
+
+// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
+// N is allowed to be symbolic where M is required to be a constant.
+message SimpleShardedDimProto {
+  // Dimension value to be sharded.
+  oneof dim {
+    int64 dim_value = 1;
+    string dim_param = 2;
+  }
+
+  // This field MUST be present for this version of the IR.
+  // Number of shards to split dim into.
+  int64 num_shards = 3;
+}
 
 // Training information
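These sharding messages compile into Rust structs through the same prost build step as the rest of the file. A hypothetical construction (field names follow prost's snake_case mapping; the `onnx` module path is the one assumed in the build.rs note above):

    // Hypothetical: tensor "a" spread across devices 0 and 1.
    let spec = onnx::ShardingSpecProto {
        tensor_name: "a".to_string(),
        device: vec![0, 1],
        ..Default::default()
    };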
@@ -401,9 +493,9 @@ message ModelProto {
 
   // A list of function protos local to the model.
   //
-  // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+  // The (domain, name, overload) tuple must be unique across the function protos in this list.
   // In case of any conflicts the behavior (whether the model local functions are given higher priority,
-  // or standard operator sets are given higher priotity or this is treated as error) is defined by
+  // or standard operator sets are given higher priority or this is treated as error) is defined by
   // the runtimes.
   //
   // The operator sets imported by FunctionProto should be compatible with the ones
@@ -416,8 +508,24 @@ message ModelProto {
   // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
   // is not allowed.
   repeated FunctionProto functions = 25;
+
+  // Describes different target configurations for a multi-device use case.
+  // A model MAY describe multiple multi-device configurations for execution.
+  repeated DeviceConfigurationProto configuration = 26;
 };
 
+// DeviceConfigurationProto describes a multi-device configuration for a model.
+message DeviceConfigurationProto {
+  // This field MUST be present for this version of the IR.
+  // Name of the configuration.
+  string name = 1;
+
+  // This field MUST be present for this version of the IR.
+  // Number of devices inside this configuration.
+  int32 num_devices = 2;
+
+  // Optional names of the devices. MUST be length of num_devices if provided.
+  repeated string device = 3;
+}
+
 // StringStringEntryProto follows the pattern for cross-proto-version maps.
 // See https://developers.google.com/protocol-buffers/docs/proto3#maps
 message StringStringEntryProto {
@@ -475,6 +583,9 @@ message GraphProto {
   // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
   repeated TensorAnnotation quantization_annotation = 14;
 
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 16;
+
   reserved 3, 4, 6 to 9;
   reserved "ir_version", "producer_version", "producer_tag", "domain";
 }
@@ -520,7 +631,14 @@ message TensorProto {
     FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
     FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
     FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
-    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+
+    // 4-bit integer data types
+    UINT4 = 21;  // Unsigned integer in range [0, 15]
+    INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation
+
+    // 4-bit floating point data types
+    FLOAT4E2M1 = 23;
 
     // Future extensions go here.
   }
@@ -555,11 +673,19 @@ message TensorProto {
   // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
   repeated float float_data = 4 [packed = true];
 
-  // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
-  // float16 and float8 values must be bit-wise converted to an uint16_t prior
-  // to writing to the buffer.
+  // For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
+  // - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
+  //   representation before being written to the buffer.
+  // - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
+  //   The first element is stored in the 4 least significant bits (LSB),
+  //   and the second element is stored in the 4 most significant bits (MSB).
+  //
+  // Consequently:
+  // - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
+  // - For 4-bit data types, each `int32_data` stores two elements.
+  //
   // When this field is present, the data_type field MUST be
-  // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
+  // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
   repeated int32 int32_data = 5 [packed = true];
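The LSB-first nibble layout described above is easy to invert by accident; a tiny sketch of the packing rule for a pair of uint4 values:

    // Pack two 4-bit values into one byte: first element in the low nibble,
    // second element in the high nibble, per the int32_data/raw_data rules.
    fn pack_u4_pair(first: u8, second: u8) -> u8 {
        debug_assert!(first <= 0xF && second <= 0xF);
        (second << 4) | (first & 0x0F)
    }

    fn main() {
        assert_eq!(pack_u4_pair(0x1, 0x2), 0x21);
    }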
 
   // For strings.
@@ -589,6 +715,7 @@ message TensorProto {
   // Complex64 elements must be written as two consecutive FLOAT values, real component first.
   // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
   // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+  // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
   //
   // Note: the advantage of specific field rather than the raw_data field is
   // that in some cases (e.g. int data), protobuf does a better packing via
@@ -631,6 +758,9 @@ message TensorProto {
   // When this field is present, the data_type field MUST be
   // UINT32 or UINT64
   repeated uint64 uint64_data = 11 [packed = true];
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 16;
 }
 
 // A serialized sparse-tensor value
@@ -777,9 +907,8 @@ enum OperatorStatus {
 }
 
 message FunctionProto {
-  // The name of the function, similar usage of op_type in OperatorProto.
-  // Combined with FunctionProto.domain, this forms the unique identity of
-  // the FunctionProto.
+  // The name of the function, similar to op_type in NodeProto.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
   string name = 1;
 
   // Deprecated since IR Version 8
@@ -826,9 +955,22 @@ message FunctionProto {
 
   repeated OperatorSetIdProto opset_import = 9;
 
-  // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
-  // the FunctionProto.
+  // The domain which this function belongs to.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
   string domain = 10;
+
+  // The overload identifier of the function.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
+  string overload = 13;
+
+  // Information for the values in the function. The ValueInfoProto.name's
+  // must be distinct and refer to names in the function (including inputs,
+  // outputs, and intermediate values). It is optional for a value to appear
+  // in value_info list.
+  repeated ValueInfoProto value_info = 12;
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 14;
 }
 
 // For using protobuf-lite
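Since `load_onnx` above decodes these protos with prost, a closing sketch of reading a model and checking the IR version it declares (the `onnx` module and the file path are assumptions):

    use prost::Message;

    fn main() -> anyhow::Result<()> {
        let bytes = std::fs::read("model.onnx")?;
        let model = onnx::ModelProto::decode(bytes.as_slice())?;
        // 0x0B is the IR_VERSION introduced on May 12, 2025 (see above).
        println!("ir_version: {}", model.ir_version);
        Ok(())
    }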