mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
upgrade ort to v2.0.0-rc.9 (#52)
This commit is contained in:
@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
|
||||
[dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
ndarray = { version = "0.16.1", features = ["rayon"] }
|
||||
ort = { version = "2.0.0-rc.5", default-features = false}
|
||||
ort = { version = "2.0.0-rc.9", default-features = false }
|
||||
anyhow = { version = "1.0.75" }
|
||||
regex = { version = "1.5.4" }
|
||||
rand = { version = "0.8.5" }
|
||||
@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
|
||||
ab_glyph = "0.2.23"
|
||||
geo = "0.28.0"
|
||||
prost = "0.12.4"
|
||||
fast_image_resize = { version = "4.2.1", features = ["image"]}
|
||||
fast_image_resize = { version = "4.2.1", features = ["image"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tempfile = "3.12.0"
|
||||
@ -50,7 +50,6 @@ default = [
|
||||
"ort/cuda",
|
||||
"ort/tensorrt",
|
||||
"ort/coreml",
|
||||
"ort/operator-libraries"
|
||||
]
|
||||
auto = ["ort/download-binaries"]
|
||||
|
||||
|
@ -2,7 +2,9 @@ use anyhow::Result;
|
||||
use half::f16;
|
||||
use ndarray::{Array, IxDyn};
|
||||
use ort::{
|
||||
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
|
||||
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
|
||||
session::{builder::SessionBuilder, Session},
|
||||
tensor::TensorElementType,
|
||||
};
|
||||
use prost::Message;
|
||||
use std::collections::HashSet;
|
||||
@ -88,14 +90,14 @@ impl OrtEngine {
|
||||
|
||||
// build
|
||||
ort::init().commit()?;
|
||||
let builder = Session::builder()?;
|
||||
let mut builder = Session::builder()?;
|
||||
let mut device = config.device.to_owned();
|
||||
match device {
|
||||
Device::Trt(device_id) => {
|
||||
Self::build_trt(
|
||||
&inputs_attrs.names,
|
||||
&inputs_minoptmax,
|
||||
&builder,
|
||||
&mut builder,
|
||||
device_id,
|
||||
config.trt_int8_enable,
|
||||
config.trt_fp16_enable,
|
||||
@ -103,23 +105,23 @@ impl OrtEngine {
|
||||
)?;
|
||||
}
|
||||
Device::Cuda(device_id) => {
|
||||
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
|
||||
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
|
||||
tracing::warn!("{err}, Using cpu");
|
||||
device = Device::Cpu(0);
|
||||
})
|
||||
}
|
||||
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
|
||||
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
|
||||
tracing::warn!("{err}, Using cpu");
|
||||
device = Device::Cpu(0);
|
||||
}),
|
||||
Device::Cpu(_) => {
|
||||
Self::build_cpu(&builder)?;
|
||||
Self::build_cpu(&mut builder)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
|
||||
let session = builder
|
||||
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
||||
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
|
||||
.commit_from_file(&config.onnx_path)?;
|
||||
|
||||
// summary
|
||||
@ -149,7 +151,7 @@ impl OrtEngine {
|
||||
fn build_trt(
|
||||
names: &[String],
|
||||
inputs_minoptmax: &[Vec<MinOptMax>],
|
||||
builder: &SessionBuilder,
|
||||
builder: &mut SessionBuilder,
|
||||
device_id: usize,
|
||||
int8_enable: bool,
|
||||
fp16_enable: bool,
|
||||
@ -205,8 +207,9 @@ impl OrtEngine {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
|
||||
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
|
||||
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
|
||||
let ep = ort::execution_providers::CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id as i32);
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
@ -214,8 +217,8 @@ impl OrtEngine {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
|
||||
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
@ -223,8 +226,8 @@ impl OrtEngine {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CPUExecutionProvider::default();
|
||||
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
|
||||
let ep = ort::execution_providers::CPUExecutionProvider::default();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
@ -292,28 +295,28 @@ impl OrtEngine {
|
||||
let t_pre = std::time::Instant::now();
|
||||
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
||||
let x_ = match &idtype {
|
||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
|
||||
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
|
||||
TensorElementType::Float16 => {
|
||||
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int32 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int64 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Uint8 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int8 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Bool => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||
}
|
||||
_ => todo!(),
|
||||
};
|
||||
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
|
||||
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
|
||||
}
|
||||
let t_pre = t_pre.elapsed();
|
||||
self.ts.add_or_push(0, t_pre);
|
||||
@ -451,45 +454,45 @@ impl OrtEngine {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
|
||||
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
|
||||
match x {
|
||||
ort::TensorElementType::Float64
|
||||
| ort::TensorElementType::Uint64
|
||||
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||
ort::TensorElementType::Float32
|
||||
| ort::TensorElementType::Uint32
|
||||
| ort::TensorElementType::Int32
|
||||
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||
ort::TensorElementType::Float16
|
||||
| ort::TensorElementType::Bfloat16
|
||||
| ort::TensorElementType::Int16
|
||||
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||
ort::TensorElementType::Uint8
|
||||
| ort::TensorElementType::Int8
|
||||
| ort::TensorElementType::Bool => 1, // u8, i8, bool
|
||||
ort::tensor::TensorElementType::Float64
|
||||
| ort::tensor::TensorElementType::Uint64
|
||||
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||
ort::tensor::TensorElementType::Float32
|
||||
| ort::tensor::TensorElementType::Uint32
|
||||
| ort::tensor::TensorElementType::Int32
|
||||
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||
ort::tensor::TensorElementType::Float16
|
||||
| ort::tensor::TensorElementType::Bfloat16
|
||||
| ort::tensor::TensorElementType::Int16
|
||||
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||
ort::tensor::TensorElementType::Uint8
|
||||
| ort::tensor::TensorElementType::Int8
|
||||
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
|
||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
|
||||
match value {
|
||||
0 => None,
|
||||
1 => Some(ort::TensorElementType::Float32),
|
||||
2 => Some(ort::TensorElementType::Uint8),
|
||||
3 => Some(ort::TensorElementType::Int8),
|
||||
4 => Some(ort::TensorElementType::Uint16),
|
||||
5 => Some(ort::TensorElementType::Int16),
|
||||
6 => Some(ort::TensorElementType::Int32),
|
||||
7 => Some(ort::TensorElementType::Int64),
|
||||
8 => Some(ort::TensorElementType::String),
|
||||
9 => Some(ort::TensorElementType::Bool),
|
||||
10 => Some(ort::TensorElementType::Float16),
|
||||
11 => Some(ort::TensorElementType::Float64),
|
||||
12 => Some(ort::TensorElementType::Uint32),
|
||||
13 => Some(ort::TensorElementType::Uint64),
|
||||
1 => Some(ort::tensor::TensorElementType::Float32),
|
||||
2 => Some(ort::tensor::TensorElementType::Uint8),
|
||||
3 => Some(ort::tensor::TensorElementType::Int8),
|
||||
4 => Some(ort::tensor::TensorElementType::Uint16),
|
||||
5 => Some(ort::tensor::TensorElementType::Int16),
|
||||
6 => Some(ort::tensor::TensorElementType::Int32),
|
||||
7 => Some(ort::tensor::TensorElementType::Int64),
|
||||
8 => Some(ort::tensor::TensorElementType::String),
|
||||
9 => Some(ort::tensor::TensorElementType::Bool),
|
||||
10 => Some(ort::tensor::TensorElementType::Float16),
|
||||
11 => Some(ort::tensor::TensorElementType::Float64),
|
||||
12 => Some(ort::tensor::TensorElementType::Uint32),
|
||||
13 => Some(ort::tensor::TensorElementType::Uint64),
|
||||
14 => None, // COMPLEX64
|
||||
15 => None, // COMPLEX128
|
||||
16 => Some(ort::TensorElementType::Bfloat16),
|
||||
16 => Some(ort::tensor::TensorElementType::Bfloat16),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -499,7 +502,7 @@ impl OrtEngine {
|
||||
value_info: &[onnx::ValueInfoProto],
|
||||
) -> Result<OrtTensorAttr> {
|
||||
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
||||
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
|
||||
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
|
||||
let mut names: Vec<String> = Vec::new();
|
||||
for v in value_info.iter() {
|
||||
if initializer_names.contains(v.name.as_str()) {
|
||||
@ -569,7 +572,7 @@ impl OrtEngine {
|
||||
&self.outputs_attrs.names
|
||||
}
|
||||
|
||||
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
|
||||
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||
&self.outputs_attrs.dtypes
|
||||
}
|
||||
|
||||
@ -585,7 +588,7 @@ impl OrtEngine {
|
||||
&self.inputs_attrs.names
|
||||
}
|
||||
|
||||
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
|
||||
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||
&self.inputs_attrs.dtypes
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user