Bump ort from 2.0.0-rc.9 to 2.0.0-rc.10 (#107)

Jamjamjon
2025-06-03 20:53:48 +08:00
committed by GitHub
parent a3a4bf47ed
commit 28f3d18ac3
9 changed files with 291 additions and 169 deletions

View File

@ -22,7 +22,7 @@ jobs:
- name: Install dependencies
run: |
DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
@ -47,7 +47,7 @@ jobs:
- name: Install dependencies
run: |
DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
run: |
DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@nightly
@ -93,7 +93,7 @@ jobs:
- name: Install dependencies
run: |
DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable

View File

@ -38,10 +38,10 @@ fast_image_resize = { version = "5.1.2", features = ["image"] }
ndarray-npy = "0.9.1"
half = { version = "2.3.1" }
prost = "0.13.5"
ort = { version = "2.0.0-rc.9", default-features = false, optional = true, features = [
"ndarray",
ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, features = [
"copy-dylibs",
"half",
"std",
] }
tokenizers = { version = "0.21.1" }
paste = "1.0.15"
@ -55,9 +55,9 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }
[features]
default = [ "ort-download-binaries" ]
video = [ "dep:video-rs" ]
ort-download-binaries = [ "ort", "ort/download-binaries" ]
ort-load-dynamic = [ "ort", "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
trt = [ "ort/tensorrt" ]
mps = ["ort/coreml"]
video = ["dep:video-rs"]
coreml = [ "ort/coreml" ]
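Since the `mps` feature is now spelled `coreml`, downstream code that gated CoreML support on the old feature name needs a matching rename. A minimal sketch of such downstream code, assuming a hypothetical helper in a crate that forwards these features:

// Hypothetical downstream helper: pick a device string based on which
// execution-provider features were enabled at build time.
fn default_device_spec() -> &'static str {
    if cfg!(feature = "coreml") {
        // was `feature = "mps"` before this bump
        "coreml:0"
    } else if cfg!(feature = "cuda") {
        "cuda:0"
    } else {
        "cpu:0"
    }
}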

View File

@ -3,8 +3,5 @@ use std::io::Result;
fn main() -> Result<()> {
prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
Ok(())
}
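With the Apple-specific link flag removed, the build script reduces to the protobuf compilation step alone; the resulting file, reconstructed from the hunk above:

use std::io::Result;

fn main() -> Result<()> {
    prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;
    Ok(())
}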

View File

@ -20,10 +20,12 @@ use crate::{
impl From<TensorElementType> for DType {
fn from(dtype: TensorElementType) -> Self {
match dtype {
TensorElementType::Int4 => Self::Int4,
TensorElementType::Int8 => Self::Int8,
TensorElementType::Int16 => Self::Int16,
TensorElementType::Int32 => Self::Int32,
TensorElementType::Int64 => Self::Int64,
TensorElementType::Uint4 => Self::Uint4,
TensorElementType::Uint8 => Self::Uint8,
TensorElementType::Uint16 => Self::Uint16,
TensorElementType::Uint32 => Self::Uint32,
@ -32,14 +34,19 @@ impl From<TensorElementType> for DType {
TensorElementType::Float32 => Self::Fp32,
TensorElementType::Float64 => Self::Fp64,
TensorElementType::Bfloat16 => Self::Bf16,
TensorElementType::String => Self::String,
TensorElementType::Bool => Self::Bool,
TensorElementType::Float8E4M3FN => Self::Fp8e4m3fn,
TensorElementType::Float8E4M3FNUZ => Self::Fp8e4m3fnuz,
TensorElementType::Float8E5M2 => Self::Fp8e5m2,
TensorElementType::Float8E5M2FNUZ => Self::Fp8e5m2fnuz,
TensorElementType::Complex64 => Self::Complex64,
TensorElementType::Complex128 => Self::Complex128,
_ => todo!(),
}
}
}
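With the extra arms in place, float8 and 4-bit element types convert instead of falling through to the todo!() arm; a small smoke-test sketch, assuming DType and TensorElementType are in scope as above:

fn conversion_smoke_test() {
    // Exotic element types now map to concrete DType variants instead of panicking.
    assert!(matches!(DType::from(TensorElementType::Float8E4M3FN), DType::Fp8e4m3fn));
    assert!(matches!(DType::from(TensorElementType::Uint4), DType::Uint4));
}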
/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions.
#[derive(Builder, Debug, Clone)]
#[derive(Builder, Debug, Clone, Default)]
pub struct OrtTensorAttr {
pub names: Vec<String>,
pub dtypes: Vec<TensorElementType>,
@ -133,7 +140,9 @@ impl Engine {
let param = tensor_proto.dims.iter().product::<i64>() as usize;
params += param;
let param = Ops::make_divisible(param, byte_alignment);
let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize);
let n = Self::get_ort_dtype_from_proto_dtype_id(tensor_proto.data_type)
.map(|x| x.byte_size(1))
.unwrap_or_default();
let wbmem = param * n;
wbmems += wbmem;
}
@ -145,7 +154,10 @@ impl Engine {
let param = tensor.dims.iter().product::<i64>() as usize;
params += param;
let param = Ops::make_divisible(param, byte_alignment);
let n = Self::nbytes_from_onnx_dtype_id(tensor.data_type as usize);
let n = Self::get_ort_dtype_from_proto_dtype_id(tensor.data_type)
.map(|x| x.byte_size(1))
.unwrap_or_default();
let wbmem = param * n;
wbmems += wbmem;
}
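Both accounting loops now derive the element size from ort itself rather than the removed byte-size lookup tables; a free-standing sketch of the same computation (the function name is hypothetical):

use ort::tensor::TensorElementType;

// dims come from an onnx TensorProto; byte_size(1) is the size in bytes
// of a single element of the given dtype.
fn weights_nbytes(dims: &[i64], dtype: TensorElementType) -> usize {
    let n_elems = dims.iter().product::<i64>() as usize;
    n_elems * dtype.byte_size(1)
}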
@ -211,7 +223,7 @@ impl Engine {
// update
pb.set_message(format!(
"{}({}) on {:?}",
"{}({}) on {}",
self.spec,
match self.params {
Some(bytes) if bytes != 0 => {
@ -231,7 +243,7 @@ impl Engine {
pub fn run(&mut self, xs: Xs) -> Result<Xs> {
let mut ys = xs.derive();
if let Some(onnx) = &self.onnx {
if let Some(onnx) = &mut self.onnx {
// alignment
let xs_ = elapsed!(&format!("[{}] ort_preprocessing", self.spec), self.ts, {
let mut xs_ = Vec::new();
@ -267,38 +279,22 @@ impl Engine {
fn preprocess(x: &X, dtype: &TensorElementType) -> Result<DynValue> {
let x = match dtype {
TensorElementType::Float32 => Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float16 => {
Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
TensorElementType::Float32 | TensorElementType::Float64 => {
Value::from_array(x.0.clone())?.into_dyn()
}
TensorElementType::Float64 => Value::from_array(x.view())?.into_dyn(),
TensorElementType::Bfloat16 => {
Value::from_array(x.mapv(bf16::from_f32).view())?.into_dyn()
}
TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn(),
TensorElementType::Int16 => {
Value::from_array(x.mapv(|x_| x_ as i16).view())?.into_dyn()
}
TensorElementType::Int32 => {
Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
}
TensorElementType::Int64 => {
Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
}
TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn(),
TensorElementType::Uint16 => {
Value::from_array(x.mapv(|x_| x_ as u16).view())?.into_dyn()
}
TensorElementType::Uint32 => {
Value::from_array(x.mapv(|x_| x_ as u32).view())?.into_dyn()
}
TensorElementType::Uint64 => {
Value::from_array(x.mapv(|x_| x_ as u64).view())?.into_dyn()
}
TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn(),
TensorElementType::Float16 => Value::from_array(x.mapv(f16::from_f32))?.into_dyn(),
TensorElementType::Bfloat16 => Value::from_array(x.mapv(bf16::from_f32))?.into_dyn(),
TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8))?.into_dyn(),
TensorElementType::Int16 => Value::from_array(x.mapv(|x_| x_ as i16))?.into_dyn(),
TensorElementType::Int32 => Value::from_array(x.mapv(|x_| x_ as i32))?.into_dyn(),
TensorElementType::Int64 => Value::from_array(x.mapv(|x_| x_ as i64))?.into_dyn(),
TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8))?.into_dyn(),
TensorElementType::Uint16 => Value::from_array(x.mapv(|x_| x_ as u16))?.into_dyn(),
TensorElementType::Uint32 => Value::from_array(x.mapv(|x_| x_ as u32))?.into_dyn(),
TensorElementType::Uint64 => Value::from_array(x.mapv(|x_| x_ as u64))?.into_dyn(),
TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.))?.into_dyn(),
_ => unimplemented!(),
};
Ok(x)
}
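In rc.10, Value::from_array takes an owned ndarray rather than a view, which appears to be why the .view() calls were dropped in favor of clone()/mapv(); a condensed sketch of one such conversion, assuming the ndarray and ort versions from this manifest:

use ndarray::ArrayD;
use ort::value::{DynValue, Value};

// Cast an f32 array to i64 and hand the owned result to ort.
fn to_i64_value(x: &ArrayD<f32>) -> ort::Result<DynValue> {
    Ok(Value::from_array(x.mapv(|v| v as i64))?.into_dyn())
}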
@ -307,7 +303,7 @@ impl Engine {
where
T: Clone + 'static + ort::tensor::PrimitiveTensorElementType,
{
match x.try_extract_tensor::<T>() {
match x.try_extract_array::<T>() {
Err(err) => {
debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
Array::zeros(0).into_dyn()
@ -344,7 +340,7 @@ impl Engine {
\nConsider enabling them by passing, e.g., `--features #FEATURE`";
match self.device {
Device::TensorRT(id) => {
Device::TensorRt(id) => {
#[cfg(not(feature = "trt"))]
{
anyhow::bail!(feature_help
@ -431,16 +427,28 @@ impl Engine {
}
}
}
Device::CoreML(id) => {
#[cfg(not(feature = "mps"))]
Device::CoreMl(id) => {
#[cfg(not(feature = "coreml"))]
{
anyhow::bail!(feature_help
.replace("#EP", "CoreML")
.replace("#FEATURE", "mps"));
.replace("#FEATURE", "coreml"));
}
#[cfg(feature = "mps")]
#[cfg(feature = "coreml")]
{
let ep = ort::execution_providers::CoreMLExecutionProvider::default();
let ep = ort::execution_providers::CoreMLExecutionProvider::default()
.with_model_cache_dir(
crate::Dir::Cache
.crate_dir_default_with_subs(&["coreml-cache"])?
.display(),
)
.with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All)
.with_static_input_shapes(false)
.with_subgraphs(true)
.with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram)
.with_specialization_strategy(
ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction,
);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
@ -452,13 +460,14 @@ impl Engine {
}
}
_ => {
let ep = ort::execution_providers::CPUExecutionProvider::default();
let ep = ort::execution_providers::CPUExecutionProvider::default()
.with_arena_allocator(true);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder)
.map_err(|err| anyhow::anyhow!("Failed to register Cpu: {}", err))?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "Cpu")),
_ => unreachable!("CPU EP is not available. This case should ideally not be reached under normal circumstances."),
}
}
}
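The probe-then-register dance is the same for every provider; a condensed sketch of that shared pattern, under the assumption that is_available/register have the signatures used above (the helper name is hypothetical):

// Hypothetical helper: register one execution provider if the probe succeeds.
fn register_ep(
    builder: &mut ort::session::builder::SessionBuilder,
    ep: impl ort::execution_providers::ExecutionProvider,
    name: &str,
) -> anyhow::Result<()> {
    match ep.is_available() {
        Ok(true) => ep
            .register(builder)
            .map_err(|err| anyhow::anyhow!("Failed to register {}: {}", name, err)),
        _ => anyhow::bail!("{} execution provider is not available", name),
    }
}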
@ -532,38 +541,8 @@ impl Engine {
Ok(ys)
}
#[allow(dead_code)]
fn nbytes_from_onnx_dtype_id(x: usize) -> usize {
match x {
7 | 11 | 13 => 8, // i64, f64, u64
1 | 6 | 12 => 4, // f32, i32, u32
10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16
2 | 3 | 9 => 1, // u8, i8, bool
8 => 4, // string(1~4)
_ => 1, // TODO: others
}
}
#[allow(dead_code)]
fn nbytes_from_onnx_dtype(x: &TensorElementType) -> usize {
match x {
TensorElementType::Float64 | TensorElementType::Uint64 | TensorElementType::Int64 => 8, // i64, f64, u64
TensorElementType::Float32
| TensorElementType::Uint32
| TensorElementType::Int32
| TensorElementType::String => 4, // f32, i32, u32, string(1~4)
TensorElementType::Float16
| TensorElementType::Bfloat16
| TensorElementType::Int16
| TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
TensorElementType::Uint8 | TensorElementType::Int8 | TensorElementType::Bool => 1, // u8, i8, bool
}
}
#[allow(dead_code)]
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<TensorElementType> {
fn get_ort_dtype_from_proto_dtype_id(value: i32) -> Option<TensorElementType> {
match value {
0 => None,
1 => Some(TensorElementType::Float32),
2 => Some(TensorElementType::Uint8),
3 => Some(TensorElementType::Int8),
@ -577,10 +556,16 @@ impl Engine {
11 => Some(TensorElementType::Float64),
12 => Some(TensorElementType::Uint32),
13 => Some(TensorElementType::Uint64),
14 => None, // COMPLEX64
15 => None, // COMPLEX128
14 => Some(TensorElementType::Complex64),
15 => Some(TensorElementType::Complex128),
16 => Some(TensorElementType::Bfloat16),
_ => None,
17 => Some(TensorElementType::Float8E4M3FN),
18 => Some(TensorElementType::Float8E4M3FNUZ),
19 => Some(TensorElementType::Float8E5M2),
20 => Some(TensorElementType::Float8E5M2FNUZ),
21 => Some(TensorElementType::Uint4),
22 => Some(TensorElementType::Int4),
_ => None, // 23: Float4e2m1, 0: Undefined
}
}
@ -609,7 +594,7 @@ impl Engine {
_ => continue,
};
let tensor_type = tensor.elem_type;
let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) {
let tensor_type = match Self::get_ort_dtype_from_proto_dtype_id(tensor_type) {
Some(dtype) => dtype,
None => continue,
};
@ -642,24 +627,6 @@ impl Engine {
})
}
// pub fn to_ort(&self) -> TensorElementType {
// match self {
// Self::Int8 => TensorElementType::Int8,
// Self::Int16 => TensorElementType::Int16,
// Self::Int32 => TensorElementType::Int32,
// Self::Int64 => TensorElementType::Int64,
// Self::Uint8 => TensorElementType::Uint8,
// Self::Uint16 => TensorElementType::Uint16,
// Self::Uint32 => TensorElementType::Uint32,
// Self::Uint64 => TensorElementType::Uint64,
// Self::Fp16 => TensorElementType::Float16,
// Self::Fp32 => TensorElementType::Float32,
// Self::Fp64 => TensorElementType::Float64,
// Self::Bf16 => TensorElementType::Bfloat16,
// _ => todo!(),
// }
// }
pub fn load_onnx<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let f = std::fs::read(p.as_ref())?;
onnx::ModelProto::decode(f.as_slice()).map_err(|err| {

View File

@ -1,7 +1,6 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Dim, IntoDimension, Ix2, IxDyn, IxDynImpl};
// use std::ops::Mul;
use crate::{Ops, ResizeMode};

View File

@ -88,9 +88,6 @@ pub struct Hub {
/// Directory to store the downloaded file
to: Dir,
/// Download timeout in seconds
timeout: u64,
/// Time to live (cache duration)
ttl: Duration,
@ -116,7 +113,6 @@ impl Default for Hub {
owner,
repo,
to,
timeout: 3000,
max_attempts: 3,
ttl: Duration::from_secs(10 * 60),
}
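With the timeout field and with_timeout builder gone, max_attempts and ttl are the remaining knobs, and download/tags/fetch_get_response are instance methods now; a minimal usage sketch with a placeholder URL and path, assuming Hub from this crate:

fn fetch_asset() -> anyhow::Result<()> {
    let hub = Hub::default().with_max_attempts(5);
    for tag in hub.tags() {
        println!("release tag: {tag}");
    }
    // Placeholder URL and destination, purely for illustration.
    hub.download("https://example.com/model.onnx", "model.onnx", None)?;
    Ok(())
}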
@ -195,7 +191,7 @@ impl Hub {
.join(&file_name_);
pack = pack.with_url(s).with_tag(&tag_).with_file_name(&file_name_);
if let Some(n) = retry!(self.max_attempts, Self::fetch_get_response(s))?
if let Some(n) = retry!(self.max_attempts, self.fetch_get_response(s))?
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
@ -208,7 +204,7 @@ impl Hub {
// => Default hub
// Fetch releases
let releases = match Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
let releases = match self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
Err(err) => anyhow::bail!(
"Failed to download: No releases found in this repo. Error: {}",
err
@ -286,7 +282,7 @@ impl Hub {
self.max_attempts,
1000,
3000,
Self::download(
self.download(
&pack.url,
&saveout,
Some(&format!("{}/{}", pack.tag, pack.file_name)),
@ -303,7 +299,7 @@ impl Hub {
self.max_attempts,
1000,
3000,
Self::download(
self.download(
&pack.url,
&saveout,
Some(&format!("{}/{}", pack.tag, pack.file_name)),
@ -319,8 +315,8 @@ impl Hub {
}
/// Fetch releases from GitHub and cache them
fn fetch_and_cache_releases(url: &str, cache_path: &Path) -> Result<String> {
let response = retry!(3, Self::fetch_get_response(url))?;
fn fetch_and_cache_releases(&self, url: &str, cache_path: &Path) -> Result<String> {
let response = retry!(self.max_attempts, self.fetch_get_response(url))?;
let body = response
.into_body()
.read_to_string()
@ -351,7 +347,7 @@ impl Hub {
}
pub fn tags(&self) -> Vec<String> {
Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
.unwrap_or_default()
.into_iter()
.map(|x| x.tag_name)
@ -359,7 +355,7 @@ impl Hub {
}
pub fn files(&self, tag: &str) -> Vec<String> {
Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
.unwrap_or_default()
.into_iter()
.find(|r| r.tag_name == tag)
@ -394,11 +390,12 @@ impl Hub {
/// Download a file from a github release to a specified path with a progress bar
pub fn download<P: AsRef<Path> + std::fmt::Debug>(
&self,
src: &str,
dst: P,
message: Option<&str>,
) -> Result<()> {
let resp = Self::fetch_get_response(src)?;
let resp = self.fetch_get_response(src)?;
let ntotal = resp
.headers()
.get(http::header::CONTENT_LENGTH)
@ -412,7 +409,8 @@ impl Hub {
)?;
let mut reader = resp.into_body().into_reader();
let mut buffer = [0; 2048];
const BUFFER_SIZE: usize = 64 * 1024;
let mut buffer = [0; BUFFER_SIZE];
let mut downloaded_bytes = 0usize;
let mut file = std::fs::File::create(&dst)
.with_context(|| format!("Failed to create destination file: {:?}", dst))?;
@ -442,7 +440,7 @@ impl Hub {
Ok(())
}
fn fetch_get_response(url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
fn fetch_get_response(&self, url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
let config = ureq::Agent::config_builder()
.proxy(ureq::Proxy::try_from_env())
.build();
@ -462,10 +460,16 @@ impl Hub {
fn cache_file(owner: &str, repo: &str) -> String {
let safe_owner = owner.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
let safe_repo = repo.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
format!(".releases_{}_{}.json", safe_owner, safe_repo)
format!(".cache-releases-{}-{}.json", safe_owner, safe_repo)
}
fn get_releases(owner: &str, repo: &str, to: &Dir, ttl: &Duration) -> Result<Vec<Release>> {
fn get_releases(
&self,
owner: &str,
repo: &str,
to: &Dir,
ttl: &Duration,
) -> Result<Vec<Release>> {
let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo));
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
let body = if is_file_expired {
@ -473,7 +477,7 @@ impl Hub {
"https://api.github.com/repos/{}/{}/releases?per_page=100",
owner, repo
);
Self::fetch_and_cache_releases(&gh_api_release, &cache)?
self.fetch_and_cache_releases(&gh_api_release, &cache)?
} else {
std::fs::read_to_string(&cache)?
};
@ -518,11 +522,6 @@ impl Hub {
self
}
pub fn with_timeout(mut self, x: u64) -> Self {
self.timeout = x;
self
}
pub fn with_max_attempts(mut self, x: u32) -> Self {
self.max_attempts = x;
self

View File

@ -2,8 +2,8 @@
pub enum Device {
Cpu(usize),
Cuda(usize),
TensorRT(usize),
CoreML(usize),
TensorRt(usize),
CoreMl(usize),
}
impl Default for Device {
@ -15,10 +15,10 @@ impl Default for Device {
impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let x = match self {
Self::Cpu(i) => format!("cpu:{}", i),
Self::Cuda(i) => format!("cuda:{}", i),
Self::CoreML(i) => format!("mps:{}", i),
Self::TensorRT(i) => format!("tensorrt:{}", i),
Self::Cpu(i) => format!("CPU:{}", i),
Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i),
Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i),
Self::CoreMl(i) => format!("CoreML:{}(Apple)", i),
};
write!(f, "{}", x)
}
@ -41,8 +41,9 @@ impl TryFrom<&str> for Device {
match d.to_lowercase().as_str() {
"cpu" => Ok(Self::Cpu(id)),
"cuda" => Ok(Self::Cuda(id)),
"trt" | "tensorrt" => Ok(Self::TensorRT(id)),
"coreml" | "mps" => Ok(Self::CoreML(id)),
"trt" | "tensorrt" => Ok(Self::TensorRt(id)),
"coreml" | "mps" => Ok(Self::CoreMl(id)),
_ => anyhow::bail!("Unsupported device str: {s:?}."),
}
}
@ -51,10 +52,7 @@ impl TryFrom<&str> for Device {
impl Device {
pub fn id(&self) -> usize {
match self {
Device::Cpu(i) => *i,
Device::Cuda(i) => *i,
Device::TensorRT(i) => *i,
Device::CoreML(i) => *i,
Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i,
}
}
}
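After the rename to TensorRt/CoreMl, the lowercase aliases still parse while Display now prints the capitalized, vendor-tagged form; a small sketch, assuming Device from this crate is in scope:

fn pick_device(spec: &str) -> anyhow::Result<()> {
    // "trt"/"tensorrt" and "coreml"/"mps" remain valid spellings.
    let device = Device::try_from(spec)?; // e.g. "tensorrt:0"
    println!("running on {} (ordinal {})", device, device.id());
    Ok(())
}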

View File

@ -7,6 +7,7 @@ pub enum DType {
Int16,
Int32,
Int64,
Uint4,
Uint8,
Uint16,
Uint32,
@ -15,12 +16,17 @@ pub enum DType {
Fp32,
Fp64,
Bf16,
Bool,
String,
Bnb4,
Q4,
Q4f16,
Q8,
Fp8e4m3fn,
Fp8e4m3fnuz,
Fp8e5m2,
Fp8e5m2fnuz,
Fp4e2m1,
Complex64,
Complex128,
}
impl TryFrom<&str> for DType {
@ -29,6 +35,7 @@ impl TryFrom<&str> for DType {
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s.to_lowercase().as_str() {
"auto" | "dyn" => Ok(Self::Auto),
"u4" | "uint4" => Ok(Self::Uint4),
"u8" | "uint8" => Ok(Self::Uint8),
"u16" | "uint16" => Ok(Self::Uint16),
"u32" | "uint32" => Ok(Self::Uint32),
@ -46,6 +53,13 @@ impl TryFrom<&str> for DType {
"q4" => Ok(Self::Q4),
"q8" => Ok(Self::Q8),
"bnb4" => Ok(Self::Bnb4),
"f8e4m3fn" => Ok(Self::Fp8e4m3fn),
"f8e4m3fnuz" => Ok(Self::Fp8e4m3fnuz),
"f8e5m2" => Ok(Self::Fp8e5m2),
"f8e5m2fnuz" => Ok(Self::Fp8e5m2fnuz),
"f4e2m1" => Ok(Self::Fp4e2m1),
"complex64" => Ok(Self::Complex64),
"complex128" => Ok(Self::Complex128),
x => anyhow::bail!("Unsupported DType: {}", x),
}
}
@ -60,6 +74,7 @@ impl std::fmt::Display for DType {
Self::Int16 => "int16",
Self::Int32 => "int32",
Self::Int64 => "int64",
Self::Uint4 => "uint4",
Self::Uint8 => "uint8",
Self::Uint16 => "uint16",
Self::Uint32 => "uint32",
@ -68,12 +83,17 @@ impl std::fmt::Display for DType {
Self::Fp32 => "fp32",
Self::Fp64 => "fp64",
Self::Bf16 => "bf16",
Self::String => "string",
Self::Bool => "bool",
Self::Bnb4 => "bnb4",
Self::Q4 => "q4",
Self::Q4f16 => "q4f16",
Self::Q8 => "q8",
Self::Fp8e4m3fn => "f8e4m3fn",
Self::Fp8e4m3fnuz => "f8e4m3fnuz",
Self::Fp8e5m2 => "f8e5m2",
Self::Fp8e5m2fnuz => "f8e5m2fnuz",
Self::Fp4e2m1 => "f4e2m1",
Self::Complex64 => "complex64",
Self::Complex128 => "complex128",
};
write!(f, "{}", x)
}
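The new float8, 4-bit, and complex variants follow the same lowercase spelling on both the parsing and printing side; a quick round-trip sketch, assuming DType from this crate is in scope:

fn dtype_roundtrip() -> anyhow::Result<()> {
    for s in ["u4", "f8e4m3fn", "f8e5m2fnuz", "complex128"] {
        let dt = DType::try_from(s)?;
        println!("{s} -> {dt}"); // short aliases normalize, e.g. "u4" -> "uint4"
    }
    Ok(())
}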

View File

@ -86,7 +86,7 @@ enum Version {
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add support to allow function body graph to rely on multiple external operator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
@ -106,7 +106,15 @@ enum Version {
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
IR_VERSION_2023_5_5 = 0x0000000000000009;
// IR VERSION 10 published on March 25, 2024
// Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions.
IR_VERSION_2024_3_25 = 0x000000000000000A;
// IR VERSION 11 published on May 12, 2025
// Added FLOAT4E2M1, multi-device protobuf classes.
IR_VERSION = 0x000000000000000B;
}
// Attributes
@ -190,6 +198,8 @@ message ValueInfoProto {
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 4;
}
// Nodes
@ -204,19 +214,101 @@ message NodeProto {
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in ths version of the IR.
// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Overload identifier, used only to map this to a model-local function.
string overload = 8;
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 9;
// Configuration of multi-device annotations.
repeated NodeDeviceConfigurationProto device_configurations = 10;
}
// IntIntListEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message IntIntListEntryProto {
int64 key = 1;
repeated int64 value = 2;
};
// Multi-device configuration proto for NodeProto.
message NodeDeviceConfigurationProto {
// This field MUST be present for this version of the IR.
// ID of the configuration. MUST match the name of a DeviceConfigurationProto.
string configuration_id = 1;
// Sharding spec for the node.
repeated ShardingSpecProto sharding_spec = 2;
// Pipeline stage of this node.
int32 pipeline_stage = 3;
}
// ShardingSpecProto: This describes the sharding spec for a specific
// input or output tensor of a node.
message ShardingSpecProto {
// This field MUST be present for this version of the IR.
// Identifies the input or output of the node that is being sharded.
// Required to match a name specified in the node's input or output list of ValueInfoProtos.
// It is called `logical tensor` in subsequent descriptions.
string tensor_name = 1;
// The following is the list of devices across which the logical
// tensor is sharded or replicated.
repeated int64 device = 2;
// Each element v in above field devices may represent either a
// device or a set of devices (when we want the same shard/tensor
// to be replicated across a subset of devices), as indicated by
// the following optional map. If the map contains an entry for v,
// then v represents a device group, and the map indicates the set
// of devices in that group.
repeated IntIntListEntryProto index_to_device_group_map = 3;
// The following is the sharded-shape of the tensor, consisting of
// the sharding-spec for each axis of the tensor.
repeated ShardedDimProto sharded_dim = 4;
}
// ShardedDimProto: This describes the sharding spec for a single
// axis of a sharded tensor.
message ShardedDimProto {
// This field MUST be present for this version of the IR.
// The axis this sharding corresponds to. Must be in the range of
// [-r, r - 1], where r is the rank of the tensor. Negative axis values means
// counting from the back.
int64 axis = 1;
// Describes how the tensor on the provided axis is sharded.
// The common-case is described by a single instance of SimpleShardedDimProto.
// Multiple instances can be used to handle cases where a sharded
// tensor is reshaped, fusing multiple axes into one.
repeated SimpleShardedDimProto simple_sharding = 2;
}
// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
// N is allowed to be symbolic where M is required to be a constant.
message SimpleShardedDimProto {
// Dimension value to be sharded.
oneof dim {
int64 dim_value = 1;
string dim_param = 2;
}
// This field MUST be present for this version of the IR.
// Number of shards to split dim into.
int64 num_shards = 3;
}
// Training information
@ -401,9 +493,9 @@ message ModelProto {
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// The (domain, name, overload) tuple must be unique across the function protos in this list.
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard operator sets are given higher priotity or this is treated as error) is defined by
// or standard operator sets are given higher priority or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
@ -416,8 +508,24 @@ message ModelProto {
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
// Describes different target configurations for a multi-device use case.
// A model MAY describe multiple multi-device configurations for execution.
repeated DeviceConfigurationProto configuration = 26;
};
// DeviceConfigurationProto describes a multi-device configuration for a model.
message DeviceConfigurationProto {
// This field MUST be present for this version of the IR.
// Name of the configuration.
string name = 1;
// This field MUST be present for this version of the IR.
// Number of devices inside this configuration.
int32 num_devices = 2;
// Optional names of the devices. MUST be length of num_devices if provided.
repeated string device = 3;
}
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
@ -475,6 +583,9 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
@ -520,7 +631,14 @@ message TensorProto {
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
// 4-bit integer data types
UINT4 = 21; // Unsigned integer in range [0, 15]
INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
// 4-bit floating point data types
FLOAT4E2M1 = 23;
// Future extensions go here.
}
@ -555,11 +673,19 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
// - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
// representation before being written to the buffer.
// - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
// The first element is stored in the 4 least significant bits (LSB),
// and the second element is stored in the 4 most significant bits (MSB).
//
// Consequently:
// - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
// - For 4-bit data types, each `int32_data` stores two elements.
//
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
// INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
repeated int32 int32_data = 5 [packed = true];
// For strings.
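The low-nibble-first packing rule for the new 4-bit types applies to both int32_data and raw_data; a small sketch of what that layout means in practice:

// Pack pairs of int4 values (-8..=7) into bytes: first element in the
// 4 least significant bits, second element in the 4 most significant bits.
fn pack_int4(values: &[i8]) -> Vec<u8> {
    values
        .chunks(2)
        .map(|pair| {
            let lo = (pair[0] as u8) & 0x0F;
            let hi = (*pair.get(1).unwrap_or(&0) as u8) & 0x0F;
            (hi << 4) | lo
        })
        .collect()
}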
@ -589,6 +715,7 @@ message TensorProto {
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
// uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
@ -631,6 +758,9 @@ message TensorProto {
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 16;
}
// A serialized sparse-tensor value
@ -777,9 +907,8 @@ enum OperatorStatus {
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
// The name of the function, similar to op_type in NodeProto.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string name = 1;
// Deprecated since IR Version 8
@ -826,9 +955,22 @@ message FunctionProto {
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
// The domain which this function belongs to.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string domain = 10;
// The overload identifier of the function.
// This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
string overload = 13;
// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
}
// For using protobuf-lite