upgrade ort to v2.0.0-rc.9 (#52)

This commit is contained in:
Collide
2024-12-03 19:16:23 +08:00
committed by GitHub
parent 57db14ce5d
commit 2785b090c6
2 changed files with 60 additions and 58 deletions

View File

@@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
[dependencies] [dependencies]
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
ndarray = { version = "0.16.1", features = ["rayon"] } ndarray = { version = "0.16.1", features = ["rayon"] }
ort = { version = "2.0.0-rc.5", default-features = false} ort = { version = "2.0.0-rc.9", default-features = false }
anyhow = { version = "1.0.75" } anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" } regex = { version = "1.5.4" }
rand = { version = "0.8.5" } rand = { version = "0.8.5" }
@@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
ab_glyph = "0.2.23" ab_glyph = "0.2.23"
geo = "0.28.0" geo = "0.28.0"
prost = "0.12.4" prost = "0.12.4"
fast_image_resize = { version = "4.2.1", features = ["image"]} fast_image_resize = { version = "4.2.1", features = ["image"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
tempfile = "3.12.0" tempfile = "3.12.0"
@@ -50,7 +50,6 @@ default = [
"ort/cuda", "ort/cuda",
"ort/tensorrt", "ort/tensorrt",
"ort/coreml", "ort/coreml",
"ort/operator-libraries"
] ]
auto = ["ort/download-binaries"] auto = ["ort/download-binaries"]

View File

@@ -2,7 +2,9 @@ use anyhow::Result;
use half::f16; use half::f16;
use ndarray::{Array, IxDyn}; use ndarray::{Array, IxDyn};
use ort::{ use ort::{
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider, execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
session::{builder::SessionBuilder, Session},
tensor::TensorElementType,
}; };
use prost::Message; use prost::Message;
use std::collections::HashSet; use std::collections::HashSet;
@@ -88,14 +90,14 @@ impl OrtEngine {
// build // build
ort::init().commit()?; ort::init().commit()?;
let builder = Session::builder()?; let mut builder = Session::builder()?;
let mut device = config.device.to_owned(); let mut device = config.device.to_owned();
match device { match device {
Device::Trt(device_id) => { Device::Trt(device_id) => {
Self::build_trt( Self::build_trt(
&inputs_attrs.names, &inputs_attrs.names,
&inputs_minoptmax, &inputs_minoptmax,
&builder, &mut builder,
device_id, device_id,
config.trt_int8_enable, config.trt_int8_enable,
config.trt_fp16_enable, config.trt_fp16_enable,
@@ -103,23 +105,23 @@ impl OrtEngine {
)?; )?;
} }
Device::Cuda(device_id) => { Device::Cuda(device_id) => {
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| { Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu"); tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0); device = Device::Cpu(0);
}) })
} }
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| { Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu"); tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0); device = Device::Cpu(0);
}), }),
Device::Cpu(_) => { Device::Cpu(_) => {
Self::build_cpu(&builder)?; Self::build_cpu(&mut builder)?;
} }
_ => todo!(), _ => todo!(),
} }
let session = builder let session = builder
.with_optimization_level(ort::GraphOptimizationLevel::Level3)? .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
.commit_from_file(&config.onnx_path)?; .commit_from_file(&config.onnx_path)?;
// summary // summary
@@ -149,7 +151,7 @@ impl OrtEngine {
fn build_trt( fn build_trt(
names: &[String], names: &[String],
inputs_minoptmax: &[Vec<MinOptMax>], inputs_minoptmax: &[Vec<MinOptMax>],
builder: &SessionBuilder, builder: &mut SessionBuilder,
device_id: usize, device_id: usize,
int8_enable: bool, int8_enable: bool,
fp16_enable: bool, fp16_enable: bool,
@@ -205,8 +207,9 @@ impl OrtEngine {
} }
} }
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> { fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32); let ep = ort::execution_providers::CUDAExecutionProvider::default()
.with_device_id(device_id as i32);
if ep.is_available()? && ep.register(builder).is_ok() { if ep.is_available()? && ep.register(builder).is_ok() {
Ok(()) Ok(())
} else { } else {
@@ -214,8 +217,8 @@ impl OrtEngine {
} }
} }
fn build_coreml(builder: &SessionBuilder) -> Result<()> { fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only(); let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
if ep.is_available()? && ep.register(builder).is_ok() { if ep.is_available()? && ep.register(builder).is_ok() {
Ok(()) Ok(())
} else { } else {
@@ -223,8 +226,8 @@ impl OrtEngine {
} }
} }
fn build_cpu(builder: &SessionBuilder) -> Result<()> { fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::CPUExecutionProvider::default(); let ep = ort::execution_providers::CPUExecutionProvider::default();
if ep.is_available()? && ep.register(builder).is_ok() { if ep.is_available()? && ep.register(builder).is_ok() {
Ok(()) Ok(())
} else { } else {
@@ -292,28 +295,28 @@ impl OrtEngine {
let t_pre = std::time::Instant::now(); let t_pre = std::time::Instant::now();
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) { for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
let x_ = match &idtype { let x_ = match &idtype {
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(), TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float16 => { TensorElementType::Float16 => {
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
} }
TensorElementType::Int32 => { TensorElementType::Int32 => {
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
} }
TensorElementType::Int64 => { TensorElementType::Int64 => {
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
} }
TensorElementType::Uint8 => { TensorElementType::Uint8 => {
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn() ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
} }
TensorElementType::Int8 => { TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
} }
TensorElementType::Bool => { TensorElementType::Bool => {
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn() ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
} }
_ => todo!(), _ => todo!(),
}; };
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_)); xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
} }
let t_pre = t_pre.elapsed(); let t_pre = t_pre.elapsed();
self.ts.add_or_push(0, t_pre); self.ts.add_or_push(0, t_pre);
@@ -451,45 +454,45 @@ impl OrtEngine {
} }
#[allow(dead_code)] #[allow(dead_code)]
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize { fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
match x { match x {
ort::TensorElementType::Float64 ort::tensor::TensorElementType::Float64
| ort::TensorElementType::Uint64 | ort::tensor::TensorElementType::Uint64
| ort::TensorElementType::Int64 => 8, // i64, f64, u64 | ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
ort::TensorElementType::Float32 ort::tensor::TensorElementType::Float32
| ort::TensorElementType::Uint32 | ort::tensor::TensorElementType::Uint32
| ort::TensorElementType::Int32 | ort::tensor::TensorElementType::Int32
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4) | ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
ort::TensorElementType::Float16 ort::tensor::TensorElementType::Float16
| ort::TensorElementType::Bfloat16 | ort::tensor::TensorElementType::Bfloat16
| ort::TensorElementType::Int16 | ort::tensor::TensorElementType::Int16
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 | ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
ort::TensorElementType::Uint8 ort::tensor::TensorElementType::Uint8
| ort::TensorElementType::Int8 | ort::tensor::TensorElementType::Int8
| ort::TensorElementType::Bool => 1, // u8, i8, bool | ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
} }
} }
#[allow(dead_code)] #[allow(dead_code)]
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> { fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
match value { match value {
0 => None, 0 => None,
1 => Some(ort::TensorElementType::Float32), 1 => Some(ort::tensor::TensorElementType::Float32),
2 => Some(ort::TensorElementType::Uint8), 2 => Some(ort::tensor::TensorElementType::Uint8),
3 => Some(ort::TensorElementType::Int8), 3 => Some(ort::tensor::TensorElementType::Int8),
4 => Some(ort::TensorElementType::Uint16), 4 => Some(ort::tensor::TensorElementType::Uint16),
5 => Some(ort::TensorElementType::Int16), 5 => Some(ort::tensor::TensorElementType::Int16),
6 => Some(ort::TensorElementType::Int32), 6 => Some(ort::tensor::TensorElementType::Int32),
7 => Some(ort::TensorElementType::Int64), 7 => Some(ort::tensor::TensorElementType::Int64),
8 => Some(ort::TensorElementType::String), 8 => Some(ort::tensor::TensorElementType::String),
9 => Some(ort::TensorElementType::Bool), 9 => Some(ort::tensor::TensorElementType::Bool),
10 => Some(ort::TensorElementType::Float16), 10 => Some(ort::tensor::TensorElementType::Float16),
11 => Some(ort::TensorElementType::Float64), 11 => Some(ort::tensor::TensorElementType::Float64),
12 => Some(ort::TensorElementType::Uint32), 12 => Some(ort::tensor::TensorElementType::Uint32),
13 => Some(ort::TensorElementType::Uint64), 13 => Some(ort::tensor::TensorElementType::Uint64),
14 => None, // COMPLEX64 14 => None, // COMPLEX64
15 => None, // COMPLEX128 15 => None, // COMPLEX128
16 => Some(ort::TensorElementType::Bfloat16), 16 => Some(ort::tensor::TensorElementType::Bfloat16),
_ => None, _ => None,
} }
} }
@@ -499,7 +502,7 @@ impl OrtEngine {
value_info: &[onnx::ValueInfoProto], value_info: &[onnx::ValueInfoProto],
) -> Result<OrtTensorAttr> { ) -> Result<OrtTensorAttr> {
let mut dimss: Vec<Vec<usize>> = Vec::new(); let mut dimss: Vec<Vec<usize>> = Vec::new();
let mut dtypes: Vec<ort::TensorElementType> = Vec::new(); let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
let mut names: Vec<String> = Vec::new(); let mut names: Vec<String> = Vec::new();
for v in value_info.iter() { for v in value_info.iter() {
if initializer_names.contains(v.name.as_str()) { if initializer_names.contains(v.name.as_str()) {
@@ -569,7 +572,7 @@ impl OrtEngine {
&self.outputs_attrs.names &self.outputs_attrs.names
} }
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> { pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.outputs_attrs.dtypes &self.outputs_attrs.dtypes
} }
@@ -585,7 +588,7 @@ impl OrtEngine {
&self.inputs_attrs.names &self.inputs_attrs.names
} }
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> { pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.inputs_attrs.dtypes &self.inputs_attrs.dtypes
} }