mirror of
https://github.com/mii443/usls.git
synced 2025-12-03 11:08:20 +00:00
upgrade ort to v2.0.0-rc.9 (#52)
This commit is contained in:
@@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
ndarray = { version = "0.16.1", features = ["rayon"] }
|
ndarray = { version = "0.16.1", features = ["rayon"] }
|
||||||
ort = { version = "2.0.0-rc.5", default-features = false}
|
ort = { version = "2.0.0-rc.9", default-features = false }
|
||||||
anyhow = { version = "1.0.75" }
|
anyhow = { version = "1.0.75" }
|
||||||
regex = { version = "1.5.4" }
|
regex = { version = "1.5.4" }
|
||||||
rand = { version = "0.8.5" }
|
rand = { version = "0.8.5" }
|
||||||
@@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
|
|||||||
ab_glyph = "0.2.23"
|
ab_glyph = "0.2.23"
|
||||||
geo = "0.28.0"
|
geo = "0.28.0"
|
||||||
prost = "0.12.4"
|
prost = "0.12.4"
|
||||||
fast_image_resize = { version = "4.2.1", features = ["image"]}
|
fast_image_resize = { version = "4.2.1", features = ["image"] }
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.12.0"
|
||||||
@@ -50,7 +50,6 @@ default = [
|
|||||||
"ort/cuda",
|
"ort/cuda",
|
||||||
"ort/tensorrt",
|
"ort/tensorrt",
|
||||||
"ort/coreml",
|
"ort/coreml",
|
||||||
"ort/operator-libraries"
|
|
||||||
]
|
]
|
||||||
auto = ["ort/download-binaries"]
|
auto = ["ort/download-binaries"]
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ use anyhow::Result;
|
|||||||
use half::f16;
|
use half::f16;
|
||||||
use ndarray::{Array, IxDyn};
|
use ndarray::{Array, IxDyn};
|
||||||
use ort::{
|
use ort::{
|
||||||
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
|
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
|
||||||
|
session::{builder::SessionBuilder, Session},
|
||||||
|
tensor::TensorElementType,
|
||||||
};
|
};
|
||||||
use prost::Message;
|
use prost::Message;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
@@ -88,14 +90,14 @@ impl OrtEngine {
|
|||||||
|
|
||||||
// build
|
// build
|
||||||
ort::init().commit()?;
|
ort::init().commit()?;
|
||||||
let builder = Session::builder()?;
|
let mut builder = Session::builder()?;
|
||||||
let mut device = config.device.to_owned();
|
let mut device = config.device.to_owned();
|
||||||
match device {
|
match device {
|
||||||
Device::Trt(device_id) => {
|
Device::Trt(device_id) => {
|
||||||
Self::build_trt(
|
Self::build_trt(
|
||||||
&inputs_attrs.names,
|
&inputs_attrs.names,
|
||||||
&inputs_minoptmax,
|
&inputs_minoptmax,
|
||||||
&builder,
|
&mut builder,
|
||||||
device_id,
|
device_id,
|
||||||
config.trt_int8_enable,
|
config.trt_int8_enable,
|
||||||
config.trt_fp16_enable,
|
config.trt_fp16_enable,
|
||||||
@@ -103,23 +105,23 @@ impl OrtEngine {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
Device::Cuda(device_id) => {
|
Device::Cuda(device_id) => {
|
||||||
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
|
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
|
||||||
tracing::warn!("{err}, Using cpu");
|
tracing::warn!("{err}, Using cpu");
|
||||||
device = Device::Cpu(0);
|
device = Device::Cpu(0);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
|
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
|
||||||
tracing::warn!("{err}, Using cpu");
|
tracing::warn!("{err}, Using cpu");
|
||||||
device = Device::Cpu(0);
|
device = Device::Cpu(0);
|
||||||
}),
|
}),
|
||||||
Device::Cpu(_) => {
|
Device::Cpu(_) => {
|
||||||
Self::build_cpu(&builder)?;
|
Self::build_cpu(&mut builder)?;
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
}
|
}
|
||||||
|
|
||||||
let session = builder
|
let session = builder
|
||||||
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
|
||||||
.commit_from_file(&config.onnx_path)?;
|
.commit_from_file(&config.onnx_path)?;
|
||||||
|
|
||||||
// summary
|
// summary
|
||||||
@@ -149,7 +151,7 @@ impl OrtEngine {
|
|||||||
fn build_trt(
|
fn build_trt(
|
||||||
names: &[String],
|
names: &[String],
|
||||||
inputs_minoptmax: &[Vec<MinOptMax>],
|
inputs_minoptmax: &[Vec<MinOptMax>],
|
||||||
builder: &SessionBuilder,
|
builder: &mut SessionBuilder,
|
||||||
device_id: usize,
|
device_id: usize,
|
||||||
int8_enable: bool,
|
int8_enable: bool,
|
||||||
fp16_enable: bool,
|
fp16_enable: bool,
|
||||||
@@ -205,8 +207,9 @@ impl OrtEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
|
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
|
||||||
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
|
let ep = ort::execution_providers::CUDAExecutionProvider::default()
|
||||||
|
.with_device_id(device_id as i32);
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
@@ -214,8 +217,8 @@ impl OrtEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
|
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
|
||||||
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
@@ -223,8 +226,8 @@ impl OrtEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
|
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
|
||||||
let ep = ort::CPUExecutionProvider::default();
|
let ep = ort::execution_providers::CPUExecutionProvider::default();
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
@@ -292,28 +295,28 @@ impl OrtEngine {
|
|||||||
let t_pre = std::time::Instant::now();
|
let t_pre = std::time::Instant::now();
|
||||||
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
||||||
let x_ = match &idtype {
|
let x_ = match &idtype {
|
||||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
|
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
|
||||||
TensorElementType::Float16 => {
|
TensorElementType::Float16 => {
|
||||||
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int32 => {
|
TensorElementType::Int32 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int64 => {
|
TensorElementType::Int64 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Uint8 => {
|
TensorElementType::Uint8 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int8 => {
|
TensorElementType::Int8 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Bool => {
|
TensorElementType::Bool => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
|
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
|
||||||
}
|
}
|
||||||
let t_pre = t_pre.elapsed();
|
let t_pre = t_pre.elapsed();
|
||||||
self.ts.add_or_push(0, t_pre);
|
self.ts.add_or_push(0, t_pre);
|
||||||
@@ -451,45 +454,45 @@ impl OrtEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
|
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
|
||||||
match x {
|
match x {
|
||||||
ort::TensorElementType::Float64
|
ort::tensor::TensorElementType::Float64
|
||||||
| ort::TensorElementType::Uint64
|
| ort::tensor::TensorElementType::Uint64
|
||||||
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
|
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||||
ort::TensorElementType::Float32
|
ort::tensor::TensorElementType::Float32
|
||||||
| ort::TensorElementType::Uint32
|
| ort::tensor::TensorElementType::Uint32
|
||||||
| ort::TensorElementType::Int32
|
| ort::tensor::TensorElementType::Int32
|
||||||
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||||
ort::TensorElementType::Float16
|
ort::tensor::TensorElementType::Float16
|
||||||
| ort::TensorElementType::Bfloat16
|
| ort::tensor::TensorElementType::Bfloat16
|
||||||
| ort::TensorElementType::Int16
|
| ort::tensor::TensorElementType::Int16
|
||||||
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||||
ort::TensorElementType::Uint8
|
ort::tensor::TensorElementType::Uint8
|
||||||
| ort::TensorElementType::Int8
|
| ort::tensor::TensorElementType::Int8
|
||||||
| ort::TensorElementType::Bool => 1, // u8, i8, bool
|
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
|
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
|
||||||
match value {
|
match value {
|
||||||
0 => None,
|
0 => None,
|
||||||
1 => Some(ort::TensorElementType::Float32),
|
1 => Some(ort::tensor::TensorElementType::Float32),
|
||||||
2 => Some(ort::TensorElementType::Uint8),
|
2 => Some(ort::tensor::TensorElementType::Uint8),
|
||||||
3 => Some(ort::TensorElementType::Int8),
|
3 => Some(ort::tensor::TensorElementType::Int8),
|
||||||
4 => Some(ort::TensorElementType::Uint16),
|
4 => Some(ort::tensor::TensorElementType::Uint16),
|
||||||
5 => Some(ort::TensorElementType::Int16),
|
5 => Some(ort::tensor::TensorElementType::Int16),
|
||||||
6 => Some(ort::TensorElementType::Int32),
|
6 => Some(ort::tensor::TensorElementType::Int32),
|
||||||
7 => Some(ort::TensorElementType::Int64),
|
7 => Some(ort::tensor::TensorElementType::Int64),
|
||||||
8 => Some(ort::TensorElementType::String),
|
8 => Some(ort::tensor::TensorElementType::String),
|
||||||
9 => Some(ort::TensorElementType::Bool),
|
9 => Some(ort::tensor::TensorElementType::Bool),
|
||||||
10 => Some(ort::TensorElementType::Float16),
|
10 => Some(ort::tensor::TensorElementType::Float16),
|
||||||
11 => Some(ort::TensorElementType::Float64),
|
11 => Some(ort::tensor::TensorElementType::Float64),
|
||||||
12 => Some(ort::TensorElementType::Uint32),
|
12 => Some(ort::tensor::TensorElementType::Uint32),
|
||||||
13 => Some(ort::TensorElementType::Uint64),
|
13 => Some(ort::tensor::TensorElementType::Uint64),
|
||||||
14 => None, // COMPLEX64
|
14 => None, // COMPLEX64
|
||||||
15 => None, // COMPLEX128
|
15 => None, // COMPLEX128
|
||||||
16 => Some(ort::TensorElementType::Bfloat16),
|
16 => Some(ort::tensor::TensorElementType::Bfloat16),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -499,7 +502,7 @@ impl OrtEngine {
|
|||||||
value_info: &[onnx::ValueInfoProto],
|
value_info: &[onnx::ValueInfoProto],
|
||||||
) -> Result<OrtTensorAttr> {
|
) -> Result<OrtTensorAttr> {
|
||||||
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
||||||
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
|
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
|
||||||
let mut names: Vec<String> = Vec::new();
|
let mut names: Vec<String> = Vec::new();
|
||||||
for v in value_info.iter() {
|
for v in value_info.iter() {
|
||||||
if initializer_names.contains(v.name.as_str()) {
|
if initializer_names.contains(v.name.as_str()) {
|
||||||
@@ -569,7 +572,7 @@ impl OrtEngine {
|
|||||||
&self.outputs_attrs.names
|
&self.outputs_attrs.names
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
|
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||||
&self.outputs_attrs.dtypes
|
&self.outputs_attrs.dtypes
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,7 +588,7 @@ impl OrtEngine {
|
|||||||
&self.inputs_attrs.names
|
&self.inputs_attrs.names
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
|
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||||
&self.inputs_attrs.dtypes
|
&self.inputs_attrs.dtypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user