mirror of https://github.com/mii443/usls.git (synced 2025-08-22 15:45:41 +00:00)

Bump ort from 2.0.0-rc.9 to 2.0.0-rc.10 (#107)
.github/workflows/rust-ci.yml
@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler

       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
@@ -47,7 +47,7 @@ jobs:
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler

       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
@@ -67,7 +67,7 @@ jobs:
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler

       - name: Setup Rust
         uses: dtolnay/rust-toolchain@nightly
@@ -93,7 +93,7 @@ jobs:
       - name: Install dependencies
         run: |
           DEBIAN_FRONTEND=noninteractive apt-get update --fix-missing
-          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ca-certificates clang curl pkg-config protobuf-compiler
+          DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential libssl-dev ca-certificates clang curl pkg-config protobuf-compiler

       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
Cargo.toml
@@ -38,10 +38,10 @@ fast_image_resize = { version = "5.1.2", features = ["image"] }
 ndarray-npy = "0.9.1"
 half = { version = "2.3.1" }
 prost = "0.13.5"
-ort = { version = "2.0.0-rc.9", default-features = false, optional = true, features = [
-    "ndarray",
+ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, features = [
     "copy-dylibs",
     "half",
+    "std",
 ] }
 tokenizers = { version = "0.21.1" }
 paste = "1.0.15"
@@ -54,10 +54,10 @@ argh = "0.1.13"
 tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }

 [features]
-default = ["ort-download-binaries"]
-ort-download-binaries = ["ort", "ort/download-binaries"]
-ort-load-dynamic = ["ort", "ort/load-dynamic"]
-cuda = ["ort/cuda"]
-trt = ["ort/tensorrt"]
-mps = ["ort/coreml"]
-video = ["dep:video-rs"]
+default = [ "ort-download-binaries" ]
+video = [ "dep:video-rs" ]
+ort-download-binaries = [ "ort", "ort/download-binaries" ]
+ort-load-dynamic = [ "ort", "ort/load-dynamic" ]
+cuda = [ "ort/cuda" ]
+trt = [ "ort/tensorrt" ]
+coreml = [ "ort/coreml" ]
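Note: besides pinning ort to exactly 2.0.0-rc.10 and adding its `std` feature, this hunk renames the `mps` feature to `coreml` and drops `ndarray` from the ort feature list. A hypothetical downstream check of the renamed gate, assuming usls re-exports `Device` at the crate root:

    // Sketch only: `coreml` is the new feature name (formerly `mps`).
    #[cfg(feature = "coreml")]
    fn coreml_device() -> anyhow::Result<usls::Device> {
        // "mps" still parses as an alias; see the Device diff further down.
        usls::Device::try_from("coreml:0")
    }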
build.rs
@@ -3,8 +3,5 @@ use std::io::Result;
 fn main() -> Result<()> {
     prost_build::compile_protos(&["src/utils/onnx.proto3"], &["src"])?;

-    #[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
-    println!("cargo:rustc-link-arg=-fapple-link-rtlib");
-
     Ok(())
 }
@@ -20,10 +20,12 @@ use crate::{
 impl From<TensorElementType> for DType {
     fn from(dtype: TensorElementType) -> Self {
         match dtype {
+            TensorElementType::Int4 => Self::Int4,
             TensorElementType::Int8 => Self::Int8,
             TensorElementType::Int16 => Self::Int16,
             TensorElementType::Int32 => Self::Int32,
             TensorElementType::Int64 => Self::Int64,
+            TensorElementType::Uint4 => Self::Uint4,
             TensorElementType::Uint8 => Self::Uint8,
             TensorElementType::Uint16 => Self::Uint16,
             TensorElementType::Uint32 => Self::Uint32,
@@ -32,14 +34,19 @@ impl From<TensorElementType> for DType {
             TensorElementType::Float32 => Self::Fp32,
             TensorElementType::Float64 => Self::Fp64,
             TensorElementType::Bfloat16 => Self::Bf16,
-            TensorElementType::String => Self::String,
-            TensorElementType::Bool => Self::Bool,
+            TensorElementType::Float8E4M3FN => Self::Fp8e4m3fn,
+            TensorElementType::Float8E4M3FNUZ => Self::Fp8e4m3fnuz,
+            TensorElementType::Float8E5M2 => Self::Fp8e5m2,
+            TensorElementType::Float8E5M2FNUZ => Self::Fp8e5m2fnuz,
+            TensorElementType::Complex64 => Self::Complex64,
+            TensorElementType::Complex128 => Self::Complex128,
+            _ => todo!(),
         }
     }
 }

 /// A struct for tensor attrs composed of the names, the dtypes, and the dimensions.
-#[derive(Builder, Debug, Clone)]
+#[derive(Builder, Debug, Clone, Default)]
 pub struct OrtTensorAttr {
     pub names: Vec<String>,
     pub dtypes: Vec<TensorElementType>,
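Note: the From impl now covers the 4-bit, float8, and complex element types that ort rc.10 exposes, with a `todo!()` fallback for anything unmapped. A minimal sketch of the conversion, assuming the paths `ort::tensor::TensorElementType` and `usls::DType`:

    let dt: usls::DType = ort::tensor::TensorElementType::Uint4.into();
    assert_eq!(dt.to_string(), "uint4"); // Display arm added later in this commit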
@@ -133,7 +140,9 @@ impl Engine {
             let param = tensor_proto.dims.iter().product::<i64>() as usize;
             params += param;
             let param = Ops::make_divisible(param, byte_alignment);
-            let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize);
+            let n = Self::get_ort_dtype_from_proto_dtype_id(tensor_proto.data_type)
+                .map(|x| x.byte_size(1))
+                .unwrap_or_default();
             let wbmem = param * n;
             wbmems += wbmem;
         }
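Note: the hand-rolled byte-width tables are gone; ort rc.10's `TensorElementType::byte_size` (used verbatim in this hunk) now supplies the element size, and unmapped dtypes contribute 0 via `unwrap_or_default()`. A standalone sketch of the same arithmetic, with a hypothetical function name and assuming `make_divisible` rounds up to the alignment:

    use ort::tensor::TensorElementType;

    fn wbmem(dims: &[i64], dtype: Option<TensorElementType>, byte_alignment: usize) -> usize {
        let param = dims.iter().product::<i64>() as usize;
        // round up to a multiple of the alignment, as Ops::make_divisible presumably does
        let aligned = param.div_ceil(byte_alignment) * byte_alignment;
        aligned * dtype.map(|x| x.byte_size(1)).unwrap_or_default()
    }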
@@ -145,7 +154,10 @@ impl Engine {
             let param = tensor.dims.iter().product::<i64>() as usize;
             params += param;
             let param = Ops::make_divisible(param, byte_alignment);
-            let n = Self::nbytes_from_onnx_dtype_id(tensor.data_type as usize);
+            let n = Self::get_ort_dtype_from_proto_dtype_id(tensor.data_type)
+                .map(|x| x.byte_size(1))
+                .unwrap_or_default();
+
             let wbmem = param * n;
             wbmems += wbmem;
         }
@@ -211,7 +223,7 @@ impl Engine {

         // update
         pb.set_message(format!(
-            "{}({}) on {:?}",
+            "{}({}) on {}",
             self.spec,
             match self.params {
                 Some(bytes) if bytes != 0 => {
@@ -231,7 +243,7 @@ impl Engine {

     pub fn run(&mut self, xs: Xs) -> Result<Xs> {
         let mut ys = xs.derive();
-        if let Some(onnx) = &self.onnx {
+        if let Some(onnx) = &mut self.onnx {
             // alignment
             let xs_ = elapsed!(&format!("[{}] ort_preprocessing", self.spec), self.ts, {
                 let mut xs_ = Vec::new();
@@ -267,38 +279,22 @@ impl Engine {

     fn preprocess(x: &X, dtype: &TensorElementType) -> Result<DynValue> {
         let x = match dtype {
-            TensorElementType::Float32 => Value::from_array(x.view())?.into_dyn(),
-            TensorElementType::Float16 => {
-                Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
-            }
-            TensorElementType::Float64 => Value::from_array(x.view())?.into_dyn(),
-            TensorElementType::Bfloat16 => {
-                Value::from_array(x.mapv(bf16::from_f32).view())?.into_dyn()
-            }
-            TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn(),
-            TensorElementType::Int16 => {
-                Value::from_array(x.mapv(|x_| x_ as i16).view())?.into_dyn()
-            }
-            TensorElementType::Int32 => {
-                Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
-            }
-            TensorElementType::Int64 => {
-                Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
-            }
-            TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn(),
-            TensorElementType::Uint16 => {
-                Value::from_array(x.mapv(|x_| x_ as u16).view())?.into_dyn()
-            }
-            TensorElementType::Uint32 => {
-                Value::from_array(x.mapv(|x_| x_ as u32).view())?.into_dyn()
-            }
-            TensorElementType::Uint64 => {
-                Value::from_array(x.mapv(|x_| x_ as u64).view())?.into_dyn()
-            }
-            TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn(),
+            TensorElementType::Float32 | TensorElementType::Float64 => {
+                Value::from_array(x.0.clone())?.into_dyn()
+            }
+            TensorElementType::Float16 => Value::from_array(x.mapv(f16::from_f32))?.into_dyn(),
+            TensorElementType::Bfloat16 => Value::from_array(x.mapv(bf16::from_f32))?.into_dyn(),
+            TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8))?.into_dyn(),
+            TensorElementType::Int16 => Value::from_array(x.mapv(|x_| x_ as i16))?.into_dyn(),
+            TensorElementType::Int32 => Value::from_array(x.mapv(|x_| x_ as i32))?.into_dyn(),
+            TensorElementType::Int64 => Value::from_array(x.mapv(|x_| x_ as i64))?.into_dyn(),
+            TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8))?.into_dyn(),
+            TensorElementType::Uint16 => Value::from_array(x.mapv(|x_| x_ as u16))?.into_dyn(),
+            TensorElementType::Uint32 => Value::from_array(x.mapv(|x_| x_ as u32))?.into_dyn(),
+            TensorElementType::Uint64 => Value::from_array(x.mapv(|x_| x_ as u64))?.into_dyn(),
+            TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.))?.into_dyn(),
             _ => unimplemented!(),
         };

         Ok(x)
     }

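Note: the driver for this rewrite is that `Value::from_array` in ort rc.10 consumes an owned array rather than a `.view()`: the f32/f64 arm clones the inner ndarray (`x.0.clone()`) and every cast arm passes the `mapv` result by value. A minimal sketch under those assumptions (type paths per ort rc.10 as used here):

    // Hypothetical helper; DynValue is ort's type-erased value.
    fn to_dyn_value(a: ndarray::ArrayD<f32>) -> anyhow::Result<ort::value::DynValue> {
        Ok(ort::value::Value::from_array(a)?.into_dyn())
    }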
@@ -307,7 +303,7 @@ impl Engine {
     where
         T: Clone + 'static + ort::tensor::PrimitiveTensorElementType,
     {
-        match x.try_extract_tensor::<T>() {
+        match x.try_extract_array::<T>() {
             Err(err) => {
                 debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
                 Array::zeros(0).into_dyn()
@@ -344,7 +340,7 @@ impl Engine {
             \nConsider enabling them by passing, e.g., `--features #FEATURE`";

         match self.device {
-            Device::TensorRT(id) => {
+            Device::TensorRt(id) => {
                 #[cfg(not(feature = "trt"))]
                 {
                     anyhow::bail!(feature_help
@@ -431,16 +427,28 @@ impl Engine {
                     }
                 }
             }
-            Device::CoreML(id) => {
-                #[cfg(not(feature = "mps"))]
+            Device::CoreMl(id) => {
+                #[cfg(not(feature = "coreml"))]
                 {
                     anyhow::bail!(feature_help
                         .replace("#EP", "CoreML")
-                        .replace("#FEATURE", "mps"));
+                        .replace("#FEATURE", "coreml"));
                 }
-                #[cfg(feature = "mps")]
+                #[cfg(feature = "coreml")]
                 {
-                    let ep = ort::execution_providers::CoreMLExecutionProvider::default();
+                    let ep = ort::execution_providers::CoreMLExecutionProvider::default()
+                        .with_model_cache_dir(
+                            crate::Dir::Cache
+                                .crate_dir_default_with_subs(&["coreml-cache"])?
+                                .display(),
+                        )
+                        .with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All)
+                        .with_static_input_shapes(false)
+                        .with_subgraphs(true)
+                        .with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram)
+                        .with_specialization_strategy(
+                            ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction,
+                        );
                     match ep.is_available() {
                         Ok(true) => {
                             ep.register(&mut builder).map_err(|err| {
@@ -452,13 +460,14 @@ impl Engine {
                         }
                     }
                     _ => {
-                        let ep = ort::execution_providers::CPUExecutionProvider::default();
+                        let ep = ort::execution_providers::CPUExecutionProvider::default()
+                            .with_arena_allocator(true);
                         match ep.is_available() {
                             Ok(true) => {
                                 ep.register(&mut builder)
                                     .map_err(|err| anyhow::anyhow!("Failed to register Cpu: {}", err))?;
                             }
-                            _ => anyhow::bail!(compile_help.replace("#EP", "Cpu")),
+                            _ => unreachable!("CPU EP is not available. This case should ideally not be reached under normal circumstances."),
                         }
                     }
                 }
@@ -532,38 +541,8 @@ impl Engine {
         Ok(ys)
     }

-    #[allow(dead_code)]
-    fn nbytes_from_onnx_dtype_id(x: usize) -> usize {
-        match x {
-            7 | 11 | 13 => 8,     // i64, f64, u64
-            1 | 6 | 12 => 4,      // f32, i32, u32
-            10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16
-            2 | 3 | 9 => 1,       // u8, i8, bool
-            8 => 4,               // string(1~4)
-            _ => 1,               // TODO: others
-        }
-    }
-
-    #[allow(dead_code)]
-    fn nbytes_from_onnx_dtype(x: &TensorElementType) -> usize {
-        match x {
-            TensorElementType::Float64 | TensorElementType::Uint64 | TensorElementType::Int64 => 8, // i64, f64, u64
-            TensorElementType::Float32
-            | TensorElementType::Uint32
-            | TensorElementType::Int32
-            | TensorElementType::String => 4, // f32, i32, u32, string(1~4)
-            TensorElementType::Float16
-            | TensorElementType::Bfloat16
-            | TensorElementType::Int16
-            | TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
-            TensorElementType::Uint8 | TensorElementType::Int8 | TensorElementType::Bool => 1, // u8, i8, bool
-        }
-    }
-
-    #[allow(dead_code)]
-    fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<TensorElementType> {
+    fn get_ort_dtype_from_proto_dtype_id(value: i32) -> Option<TensorElementType> {
         match value {
-            0 => None,
             1 => Some(TensorElementType::Float32),
             2 => Some(TensorElementType::Uint8),
             3 => Some(TensorElementType::Int8),
@@ -577,10 +556,16 @@ impl Engine {
             11 => Some(TensorElementType::Float64),
             12 => Some(TensorElementType::Uint32),
             13 => Some(TensorElementType::Uint64),
-            14 => None, // COMPLEX64
-            15 => None, // COMPLEX128
+            14 => Some(TensorElementType::Complex64),
+            15 => Some(TensorElementType::Complex128),
             16 => Some(TensorElementType::Bfloat16),
-            _ => None,
+            17 => Some(TensorElementType::Float8E4M3FN),
+            18 => Some(TensorElementType::Float8E4M3FNUZ),
+            19 => Some(TensorElementType::Float8E5M2),
+            20 => Some(TensorElementType::Float8E5M2FNUZ),
+            21 => Some(TensorElementType::Uint4),
+            22 => Some(TensorElementType::Int4),
+            _ => None, // 23: Float4e2m1, 0: Undefined
         }
     }

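Note: the ids mirror `TensorProto.DataType` in onnx.proto3 (see the proto diff below); 23 (FLOAT4E2M1) and 0 (UNDEFINED) deliberately stay `None`. An illustrative unit test for the new mapping (the helper is private to `Engine`, so this would live beside it):

    #[test]
    fn proto_dtype_id_mapping() {
        assert_eq!(Engine::get_ort_dtype_from_proto_dtype_id(21), Some(TensorElementType::Uint4));
        assert_eq!(Engine::get_ort_dtype_from_proto_dtype_id(23), None); // FLOAT4E2M1: unmapped
        assert_eq!(Engine::get_ort_dtype_from_proto_dtype_id(0), None); // UNDEFINED
    }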
@@ -609,7 +594,7 @@ impl Engine {
                 _ => continue,
             };
             let tensor_type = tensor.elem_type;
-            let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) {
+            let tensor_type = match Self::get_ort_dtype_from_proto_dtype_id(tensor_type) {
                 Some(dtype) => dtype,
                 None => continue,
             };
@@ -642,24 +627,6 @@ impl Engine {
         })
     }

-    // pub fn to_ort(&self) -> TensorElementType {
-    //     match self {
-    //         Self::Int8 => TensorElementType::Int8,
-    //         Self::Int16 => TensorElementType::Int16,
-    //         Self::Int32 => TensorElementType::Int32,
-    //         Self::Int64 => TensorElementType::Int64,
-    //         Self::Uint8 => TensorElementType::Uint8,
-    //         Self::Uint16 => TensorElementType::Uint16,
-    //         Self::Uint32 => TensorElementType::Uint32,
-    //         Self::Uint64 => TensorElementType::Uint64,
-    //         Self::Fp16 => TensorElementType::Float16,
-    //         Self::Fp32 => TensorElementType::Float32,
-    //         Self::Fp64 => TensorElementType::Float64,
-    //         Self::Bf16 => TensorElementType::Bfloat16,
-    //         _ => todo!(),
-    //     }
-    // }
-
     pub fn load_onnx<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
         let f = std::fs::read(p.as_ref())?;
         onnx::ModelProto::decode(f.as_slice()).map_err(|err| {
@@ -1,7 +1,6 @@
 use anyhow::Result;
 use image::DynamicImage;
 use ndarray::{Array, Dim, IntoDimension, Ix2, IxDyn, IxDynImpl};
-// use std::ops::Mul;

 use crate::{Ops, ResizeMode};

@@ -88,9 +88,6 @@ pub struct Hub {
     /// Directory to store the downloaded file
     to: Dir,

-    /// Download timeout in seconds
-    timeout: u64,
-
     /// Time to live (cache duration)
     ttl: Duration,

|
|||||||
owner,
|
owner,
|
||||||
repo,
|
repo,
|
||||||
to,
|
to,
|
||||||
timeout: 3000,
|
|
||||||
max_attempts: 3,
|
max_attempts: 3,
|
||||||
ttl: Duration::from_secs(10 * 60),
|
ttl: Duration::from_secs(10 * 60),
|
||||||
}
|
}
|
||||||
@@ -195,7 +191,7 @@ impl Hub {
                 .join(&file_name_);

             pack = pack.with_url(s).with_tag(&tag_).with_file_name(&file_name_);
-            if let Some(n) = retry!(self.max_attempts, Self::fetch_get_response(s))?
+            if let Some(n) = retry!(self.max_attempts, self.fetch_get_response(s))?
                 .headers()
                 .get(http::header::CONTENT_LENGTH)
                 .and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
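Note: `retry!` is usls's own macro and is not shown in this diff; the visible change is that the retried request now goes through `self.fetch_get_response`, so every retry site shares the Hub's `max_attempts` field. A plausible minimal shape for the two-argument form, purely for orientation and not the crate's actual definition:

    // Assumed sketch; the real macro also has a four-argument form
    // retry!(attempts, min_delay_ms, max_delay_ms, expr) used further down.
    macro_rules! retry {
        ($n:expr, $f:expr) => {{
            let mut res = $f;
            for _ in 1..$n {
                if res.is_ok() {
                    break;
                }
                res = $f; // re-evaluate the call on each attempt
            }
            res
        }};
    }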
@@ -208,7 +204,7 @@ impl Hub {
             // => Default hub

             // Fetch releases
-            let releases = match Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
+            let releases = match self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl) {
                 Err(err) => anyhow::bail!(
                     "Failed to download: No releases found in this repo. Error: {}",
                     err
@@ -286,7 +282,7 @@ impl Hub {
                 self.max_attempts,
                 1000,
                 3000,
-                Self::download(
+                self.download(
                     &pack.url,
                     &saveout,
                     Some(&format!("{}/{}", pack.tag, pack.file_name)),
@@ -303,7 +299,7 @@ impl Hub {
                 self.max_attempts,
                 1000,
                 3000,
-                Self::download(
+                self.download(
                     &pack.url,
                     &saveout,
                     Some(&format!("{}/{}", pack.tag, pack.file_name)),
@@ -319,8 +315,8 @@ impl Hub {
     }

     /// Fetch releases from GitHub and cache them
-    fn fetch_and_cache_releases(url: &str, cache_path: &Path) -> Result<String> {
-        let response = retry!(3, Self::fetch_get_response(url))?;
+    fn fetch_and_cache_releases(&self, url: &str, cache_path: &Path) -> Result<String> {
+        let response = retry!(self.max_attempts, self.fetch_get_response(url))?;
         let body = response
             .into_body()
             .read_to_string()
@@ -351,7 +347,7 @@ impl Hub {
     }

     pub fn tags(&self) -> Vec<String> {
-        Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
+        self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
             .unwrap_or_default()
             .into_iter()
             .map(|x| x.tag_name)
@@ -359,7 +355,7 @@ impl Hub {
     }

     pub fn files(&self, tag: &str) -> Vec<String> {
-        Self::get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
+        self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl)
             .unwrap_or_default()
             .into_iter()
             .find(|r| r.tag_name == tag)
@@ -394,11 +390,12 @@ impl Hub {

     /// Download a file from a github release to a specified path with a progress bar
     pub fn download<P: AsRef<Path> + std::fmt::Debug>(
+        &self,
         src: &str,
         dst: P,
         message: Option<&str>,
     ) -> Result<()> {
-        let resp = Self::fetch_get_response(src)?;
+        let resp = self.fetch_get_response(src)?;
         let ntotal = resp
             .headers()
             .get(http::header::CONTENT_LENGTH)
@@ -412,7 +409,8 @@ impl Hub {
         )?;

         let mut reader = resp.into_body().into_reader();
-        let mut buffer = [0; 2048];
+        const BUFFER_SIZE: usize = 64 * 1024;
+        let mut buffer = [0; BUFFER_SIZE];
         let mut downloaded_bytes = 0usize;
         let mut file = std::fs::File::create(&dst)
             .with_context(|| format!("Failed to create destination file: {:?}", dst))?;
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_get_response(url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
|
fn fetch_get_response(&self, url: &str) -> anyhow::Result<http::Response<ureq::Body>> {
|
||||||
let config = ureq::Agent::config_builder()
|
let config = ureq::Agent::config_builder()
|
||||||
.proxy(ureq::Proxy::try_from_env())
|
.proxy(ureq::Proxy::try_from_env())
|
||||||
.build();
|
.build();
|
||||||
@@ -462,10 +460,16 @@ impl Hub {
     fn cache_file(owner: &str, repo: &str) -> String {
         let safe_owner = owner.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
         let safe_repo = repo.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
-        format!(".releases_{}_{}.json", safe_owner, safe_repo)
+        format!(".cache-releases-{}-{}.json", safe_owner, safe_repo)
     }

-    fn get_releases(owner: &str, repo: &str, to: &Dir, ttl: &Duration) -> Result<Vec<Release>> {
+    fn get_releases(
+        &self,
+        owner: &str,
+        repo: &str,
+        to: &Dir,
+        ttl: &Duration,
+    ) -> Result<Vec<Release>> {
         let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo));
         let is_file_expired = Self::is_file_expired(&cache, ttl)?;
         let body = if is_file_expired {
@@ -473,7 +477,7 @@ impl Hub {
                 "https://api.github.com/repos/{}/{}/releases?per_page=100",
                 owner, repo
             );
-            Self::fetch_and_cache_releases(&gh_api_release, &cache)?
+            self.fetch_and_cache_releases(&gh_api_release, &cache)?
         } else {
             std::fs::read_to_string(&cache)?
         };
@@ -518,11 +522,6 @@ impl Hub {
         self
     }

-    pub fn with_timeout(mut self, x: u64) -> Self {
-        self.timeout = x;
-        self
-    }
-
     pub fn with_max_attempts(mut self, x: u32) -> Self {
         self.max_attempts = x;
         self
@@ -2,8 +2,8 @@
 pub enum Device {
     Cpu(usize),
     Cuda(usize),
-    TensorRT(usize),
-    CoreML(usize),
+    TensorRt(usize),
+    CoreMl(usize),
 }

 impl Default for Device {
@@ -15,10 +15,10 @@ impl Default for Device {
 impl std::fmt::Display for Device {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         let x = match self {
-            Self::Cpu(i) => format!("cpu:{}", i),
-            Self::Cuda(i) => format!("cuda:{}", i),
-            Self::CoreML(i) => format!("mps:{}", i),
-            Self::TensorRT(i) => format!("tensorrt:{}", i),
+            Self::Cpu(i) => format!("CPU:{}", i),
+            Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i),
+            Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i),
+            Self::CoreMl(i) => format!("CoreML:{}(Apple)", i),
         };
         write!(f, "{}", x)
     }
@@ -41,8 +41,9 @@ impl TryFrom<&str> for Device {
         match d.to_lowercase().as_str() {
             "cpu" => Ok(Self::Cpu(id)),
             "cuda" => Ok(Self::Cuda(id)),
-            "trt" | "tensorrt" => Ok(Self::TensorRT(id)),
-            "coreml" | "mps" => Ok(Self::CoreML(id)),
+            "trt" | "tensorrt" => Ok(Self::TensorRt(id)),
+            "coreml" | "mps" => Ok(Self::CoreMl(id)),
+
             _ => anyhow::bail!("Unsupported device str: {s:?}."),
         }
     }
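Note: the variant names move to UpperCamelCase (`TensorRt`, `CoreMl`) per Rust naming conventions, while the accepted strings are unchanged. A round-trip sketch, assuming usls re-exports `Device` at the crate root:

    let d = usls::Device::try_from("mps:0")?; // "mps" remains an alias for CoreML
    assert!(matches!(d, usls::Device::CoreMl(0)));
    assert_eq!(d.to_string(), "CoreML:0(Apple)"); // new Display format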
@@ -51,10 +52,7 @@ impl TryFrom<&str> for Device {
 impl Device {
     pub fn id(&self) -> usize {
         match self {
-            Device::Cpu(i) => *i,
-            Device::Cuda(i) => *i,
-            Device::TensorRT(i) => *i,
-            Device::CoreML(i) => *i,
+            Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i,
         }
     }
 }
@@ -7,6 +7,7 @@ pub enum DType {
     Int16,
     Int32,
     Int64,
+    Uint4,
     Uint8,
     Uint16,
     Uint32,
@@ -15,12 +16,17 @@ pub enum DType {
     Fp32,
     Fp64,
     Bf16,
-    Bool,
-    String,
     Bnb4,
     Q4,
     Q4f16,
     Q8,
+    Fp8e4m3fn,
+    Fp8e4m3fnuz,
+    Fp8e5m2,
+    Fp8e5m2fnuz,
+    Fp4e2m1,
+    Complex64,
+    Complex128,
 }

 impl TryFrom<&str> for DType {
@@ -29,6 +35,7 @@ impl TryFrom<&str> for DType {
     fn try_from(s: &str) -> Result<Self, Self::Error> {
         match s.to_lowercase().as_str() {
             "auto" | "dyn" => Ok(Self::Auto),
+            "u4" | "uint4" => Ok(Self::Uint4),
             "u8" | "uint8" => Ok(Self::Uint8),
             "u16" | "uint16" => Ok(Self::Uint16),
             "u32" | "uint32" => Ok(Self::Uint32),
@@ -46,6 +53,13 @@ impl TryFrom<&str> for DType {
             "q4" => Ok(Self::Q4),
             "q8" => Ok(Self::Q8),
             "bnb4" => Ok(Self::Bnb4),
+            "f8e4m3fn" => Ok(Self::Fp8e4m3fn),
+            "f8e4m3fnuz" => Ok(Self::Fp8e4m3fnuz),
+            "f8e5m2" => Ok(Self::Fp8e5m2),
+            "f8e5m2fnuz" => Ok(Self::Fp8e5m2fnuz),
+            "f4e2m1" => Ok(Self::Fp4e2m1),
+            "complex64" => Ok(Self::Complex64),
+            "complex128" => Ok(Self::Complex128),
             x => anyhow::bail!("Unsupported DType: {}", x),
         }
     }
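Note: parsing stays case-insensitive via `to_lowercase`, so the new float8/float4/complex spellings accept any casing. For example, assuming usls re-exports `DType`:

    let dt = usls::DType::try_from("F8E5M2")?;
    assert!(matches!(dt, usls::DType::Fp8e5m2));
    assert_eq!(dt.to_string(), "f8e5m2");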
@@ -60,6 +74,7 @@ impl std::fmt::Display for DType {
             Self::Int16 => "int16",
             Self::Int32 => "int32",
             Self::Int64 => "int64",
+            Self::Uint4 => "uint4",
             Self::Uint8 => "uint8",
             Self::Uint16 => "uint16",
             Self::Uint32 => "uint32",
@@ -68,12 +83,17 @@ impl std::fmt::Display for DType {
             Self::Fp32 => "fp32",
             Self::Fp64 => "fp64",
             Self::Bf16 => "bf16",
-            Self::String => "string",
-            Self::Bool => "bool",
             Self::Bnb4 => "bnb4",
             Self::Q4 => "q4",
             Self::Q4f16 => "q4f16",
             Self::Q8 => "q8",
+            Self::Fp8e4m3fn => "f8e4m3fn",
+            Self::Fp8e4m3fnuz => "f8e4m3fnuz",
+            Self::Fp8e5m2 => "f8e5m2",
+            Self::Fp8e5m2fnuz => "f8e5m2fnuz",
+            Self::Fp4e2m1 => "f4e2m1",
+            Self::Complex64 => "complex64",
+            Self::Complex128 => "complex128",
         };
         write!(f, "{}", x)
     }
src/utils/onnx.proto3

@@ -86,7 +86,7 @@ enum Version {
   IR_VERSION_2019_9_19 = 0x0000000000000006;

   // IR VERSION 7 published on May 8, 2020
-  // - Add support to allow function body graph to rely on multiple external opreator sets.
+  // - Add support to allow function body graph to rely on multiple external operator sets.
   // - Add a list to promote inference graph's initializers to global and
   //   mutable variables. Global variables are visible in all graphs of the
   //   stored models.
@@ -106,7 +106,15 @@ enum Version {
   // IR VERSION 9 published on May 5, 2023
   // Added AttributeProto to FunctionProto so that default attribute values can be set.
   // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
-  IR_VERSION = 0x0000000000000009;
+  IR_VERSION_2023_5_5 = 0x0000000000000009;
+
+  // IR VERSION 10 published on March 25, 2024
+  // Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions.
+  IR_VERSION_2024_3_25 = 0x000000000000000A;
+
+  // IR VERSION 11 published on May 12, 2025
+  // Added FLOAT4E2M1, multi-device protobuf classes.
+  IR_VERSION = 0x000000000000000B;
 }

 // Attributes
@@ -190,6 +198,8 @@ message ValueInfoProto {
   TypeProto type = 2;
   // A human-readable documentation for this value. Markdown is allowed.
   string doc_string = 3;
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 4;
 }

 // Nodes
@@ -204,19 +214,101 @@ message NodeProto {
   repeated string output = 2; // namespace Value

   // An optional identifier for this node in a graph.
-  // This field MAY be absent in ths version of the IR.
+  // This field MAY be absent in this version of the IR.
   string name = 3; // namespace Node

   // The symbolic identifier of the Operator to execute.
   string op_type = 4; // namespace Operator
   // The domain of the OperatorSet that specifies the operator named by op_type.
   string domain = 7; // namespace Domain
+  // Overload identifier, used only to map this to a model-local function.
+  string overload = 8;

   // Additional named attributes.
   repeated AttributeProto attribute = 5;

   // A human-readable documentation for this node. Markdown is allowed.
   string doc_string = 6;
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 9;
+
+  // Configuration of multi-device annotations.
+  repeated NodeDeviceConfigurationProto device_configurations = 10;
+}
+
+// IntIntListEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message IntIntListEntryProto {
+  int64 key = 1;
+  repeated int64 value = 2;
+};
+
+// Multi-device configuration proto for NodeProto.
+message NodeDeviceConfigurationProto {
+  // This field MUST be present for this version of the IR.
+  // ID of the configuration. MUST match the name of a DeviceConfigurationProto.
+  string configuration_id = 1;
+  // Sharding spec for the node.
+  repeated ShardingSpecProto sharding_spec = 2;
+  // Pipeline stage of this node.
+  int32 pipeline_stage = 3;
+}
+
+// ShardingSpecProto: This describes the sharding spec for a specific
+// input or output tensor of a node.
+message ShardingSpecProto {
+  // This field MUST be present for this version of the IR.
+  // Identifies the input or output of the node that is being sharded.
+  // Required to match a name specified in the node's input or output list of ValueInfoProtos.
+  // It is called `logical tensor` in subsequent descriptions.
+  string tensor_name = 1;

+  // The following is the list of devices across which the logical
+  // tensor is sharded or replicated.
+  repeated int64 device = 2;

+  // Each element v in above field devices may represent either a
+  // device or a set of devices (when we want the same shard/tensor
+  // to be replicated across a subset of devices), as indicated by
+  // the following optional map. If the map contains an entry for v,
+  // then v represents a device group, and the map indicates the set
+  // of devices in that group.
+  repeated IntIntListEntryProto index_to_device_group_map = 3;

+  // The following is the sharded-shape of the tensor, consisting of
+  // the sharding-spec for each axis of the tensor.
+  repeated ShardedDimProto sharded_dim = 4;
+}
+
+// ShardedDimProto: This describes the sharding spec for a single
+// axis of a sharded tensor.
+message ShardedDimProto {
+  // This field MUST be present for this version of the IR.
+  // The axis this sharding corresponds to. Must be in the range of
+  // [-r, r - 1], where r is the rank of the tensor. Negative axis values means
+  // counting from the back.
+  int64 axis = 1;

+  // Describes how the tensor on the provided axis is sharded.
+  // The common-case is described by a single instance of SimpleShardedDimProto.
+  // Multiple instances can be used to handle cases where a sharded
+  // tensor is reshaped, fusing multiple axes into one.
+  repeated SimpleShardedDimProto simple_sharding = 2;
+}
+
+// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
+// N is allowed to be symbolic where M is required to be a constant.
+message SimpleShardedDimProto {
+  // Dimension value to be sharded.
+  oneof dim {
+    int64 dim_value = 1;
+    string dim_param = 2;
+  }

+  // This field MUST be present for this version of the IR.
+  // Number of shards to split dim into.
+  int64 num_shards = 3;
 }

 // Training information
|
|||||||
|
|
||||||
// A list of function protos local to the model.
|
// A list of function protos local to the model.
|
||||||
//
|
//
|
||||||
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
|
// The (domain, name, overload) tuple must be unique across the function protos in this list.
|
||||||
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
|
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
|
||||||
// or standard operator sets are given higher priotity or this is treated as error) is defined by
|
// or standard operator sets are given higher priority or this is treated as error) is defined by
|
||||||
// the runtimes.
|
// the runtimes.
|
||||||
//
|
//
|
||||||
// The operator sets imported by FunctionProto should be compatible with the ones
|
// The operator sets imported by FunctionProto should be compatible with the ones
|
||||||
@@ -416,8 +508,24 @@ message ModelProto {
   // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
   // is not allowed.
   repeated FunctionProto functions = 25;
+
+  // Describes different target configurations for a multi-device use case.
+  // A model MAY describe multiple multi-device configurations for execution.
+  repeated DeviceConfigurationProto configuration = 26;
 };
+
+// DeviceConfigurationProto describes a multi-device configuration for a model.
+message DeviceConfigurationProto {
+  // This field MUST be present for this version of the IR.
+  // Name of the configuration.
+  string name = 1;
+  // This field MUST be present for this version of the IR.
+  // Number of devices inside this configuration.
+  int32 num_devices = 2;
+  // Optional names of the devices. MUST be length of num_devices if provided.
+  repeated string device = 3;
+}

 // StringStringEntryProto follows the pattern for cross-proto-version maps.
 // See https://developers.google.com/protocol-buffers/docs/proto3#maps
 message StringStringEntryProto {
@@ -475,6 +583,9 @@ message GraphProto {
   // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
   repeated TensorAnnotation quantization_annotation = 14;

+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 16;
+
   reserved 3, 4, 6 to 9;
   reserved "ir_version", "producer_version", "producer_tag", "domain";
 }
@@ -520,7 +631,14 @@ message TensorProto {
     FLOAT8E4M3FN = 17;   // float 8, mostly used for coefficients, supports nan, not inf
     FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
     FLOAT8E5M2 = 19;     // follows IEEE 754, supports nan, inf, mostly used for gradients
-    FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+    FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+
+    // 4-bit integer data types
+    UINT4 = 21; // Unsigned integer in range [0, 15]
+    INT4 = 22;  // Signed integer in range [-8, 7], using two's-complement representation
+
+    // 4-bit floating point data types
+    FLOAT4E2M1 = 23;

     // Future extensions go here.
   }
@@ -555,11 +673,19 @@ message TensorProto {
   // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
   repeated float float_data = 4 [packed = true];

-  // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
-  // float16 and float8 values must be bit-wise converted to an uint16_t prior
-  // to writing to the buffer.
+  // For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
+  // - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
+  //   representation before being written to the buffer.
+  // - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
+  //   The first element is stored in the 4 least significant bits (LSB),
+  //   and the second element is stored in the 4 most significant bits (MSB).
+  //
+  // Consequently:
+  // - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
+  // - For 4-bit data types, each `int32_data` stores two elements.
+  //
   // When this field is present, the data_type field MUST be
-  // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
+  // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
   repeated int32 int32_data = 5 [packed = true];

   // For strings.
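Note: the 4-bit packing rule just specified (first element in the low nibble, second in the high nibble) is easy to get wrong; a small Rust sketch of what a writer and reader of INT4 data must do:

    fn pack_int4_pair(first: i8, second: i8) -> u8 {
        // each value is 4-bit two's complement in [-8, 7]
        ((first as u8) & 0x0F) | (((second as u8) & 0x0F) << 4)
    }

    fn unpack_int4_pair(b: u8) -> (i8, i8) {
        // sign-extend each nibble from 4 bits
        let ext = |n: u8| ((n << 4) as i8) >> 4;
        (ext(b & 0x0F), ext(b >> 4))
    }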
@@ -589,6 +715,7 @@ message TensorProto {
   // Complex64 elements must be written as two consecutive FLOAT values, real component first.
   // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
   // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+  // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
   //
   // Note: the advantage of specific field rather than the raw_data field is
   // that in some cases (e.g. int data), protobuf does a better packing via
@@ -631,6 +758,9 @@ message TensorProto {
   // When this field is present, the data_type field MUST be
   // UINT32 or UINT64
   repeated uint64 uint64_data = 11 [packed = true];
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 16;
 }

 // A serialized sparse-tensor value
@@ -777,9 +907,8 @@ enum OperatorStatus {
 }

 message FunctionProto {
-  // The name of the function, similar usage of op_type in OperatorProto.
-  // Combined with FunctionProto.domain, this forms the unique identity of
-  // the FunctionProto.
+  // The name of the function, similar to op_type in NodeProto.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
   string name = 1;

   // Deprecated since IR Version 8
@@ -826,9 +955,22 @@ message FunctionProto {

   repeated OperatorSetIdProto opset_import = 9;

-  // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
-  // the FunctionProto.
+  // The domain which this function belongs to.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
   string domain = 10;
+
+  // The overload identifier of the function.
+  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.
+  string overload = 13;
+
+  // Information for the values in the function. The ValueInfoProto.name's
+  // must be distinct and refer to names in the function (including inputs,
+  // outputs, and intermediate values). It is optional for a value to appear
+  // in value_info list.
+  repeated ValueInfoProto value_info = 12;
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 14;
 }

 // For using protobuf-lite